112 lines
2.9 KiB
Python
112 lines
2.9 KiB
Python
import torch
|
|
import matplotlib.pyplot as plt
|
|
from torch import nn
|
|
import torch.optim as optim
|
|
from torchvision import datasets, transforms
|
|
from torch.utils.data import DataLoader
|
|
|
|
from main import optimizer
|
|
|
|
# Hyperparameters.
batch_size = 64

# Preprocessing pipeline applied to every MNIST image:
#   1. ToTensor  - converts the image to a tensor with values in [0, 1].
#   2. Normalize - standardizes with the commonly used MNIST statistics
#                  (mean 0.1307, standard deviation 0.3081).
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))
transform = transforms.Compose([_to_tensor, _normalize])
|
|
|
|
# Both MNIST splits share the storage location, the download-if-missing
# behavior and the preprocessing pipeline; only the `train` flag differs.
_mnist_common = {
    "root": "./data",        # directory where the dataset files are stored
    "download": True,        # download the files if they are not present
    "transform": transform,  # apply the preprocessing pipeline
}

# Training split.
train_dataset = datasets.MNIST(train=True, **_mnist_common)

# Test split.
test_dataset = datasets.MNIST(train=False, **_mnist_common)
|
|
|
|
# Wrap both splits in DataLoaders that yield mini-batches of `batch_size`.
# The training data is shuffled (suited to training); the test data keeps
# its original order because shuffling is not needed for evaluation.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
|
|
|
# Quick sanity check: report how many samples each split contains.
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")

# Pull ONE batch from the training loader and report its shapes. The same
# batch can be handed to the optional preview helper below, so the loader is
# not iterated a second time just for plotting.
images, labels = next(iter(train_loader))
print(f"图片批次维度: {images.shape}")  # [batch_size, 1, 28, 28]
print(f"标签批次维度: {labels.shape}")  # [batch_size]


def show_sample_grid(images, labels, rows=3, cols=3):
    """Plot the first ``rows * cols`` images of a batch with their labels.

    This replaces the plotting snippet that used to be disabled inside a
    string literal. It is not called automatically; invoke
    ``show_sample_grid(images, labels)`` to preview a batch.

    Args:
        images: Batch of image tensors; each item is squeezed before
            plotting (e.g. [1, 28, 28] -> [28, 28]).
        labels: Batch of scalar label tensors matching ``images``.
        rows: Number of grid rows.
        cols: Number of grid columns.
    """
    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid.
    _, axes = plt.subplots(rows, cols, figsize=(6, 6), squeeze=False)

    # Hide every axis up front so unused cells stay blank when the batch
    # holds fewer images than the grid has cells.
    for ax in axes.flat:
        ax.axis("off")

    # zip stops at the shortest input, so a short batch cannot raise IndexError.
    for ax, img, label in zip(axes.flat, images, labels):
        ax.imshow(img.squeeze().numpy(), cmap="gray")
        ax.set_title(f"Label: {label.item()}")

    plt.tight_layout()
    plt.show()
|
|
|
|
# Fully connected classifier for 28x28 inputs: 784 -> 256 -> 100 -> 10.
# nn.Flatten() collapses every dimension after the batch dimension, so the
# network accepts image batches shaped [batch, 1, 28, 28] as well as input
# that is already flattened to [batch, 784] (for which it is a no-op).
# NOTE: adding Flatten shifts the indices of the following layers by one.
ez = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 100),
    nn.ReLU(),
    nn.Linear(100, 10),  # raw class scores (logits); no softmax is applied
)

# Cross-entropy loss over the 10 output scores.
criterion = nn.CrossEntropyLoss()

# NOTE(review): this rebinds `optimizer`, which the file also imports via
# `from main import optimizer` at the top. That import looks unintended
# (possibly an IDE auto-import) - confirm and remove it.
optimizer = optim.Adam(ez.parameters(), lr=0.002)
|
|
|
|
# Train for a single pass over the training set.
ez.train()  # make sure the model is in training mode before updating weights
for images, labels in train_loader:
    # Flatten each image into a vector: [batch_size, 28*28].
    images = images.view(images.size(0), -1)
    out = ez(images)
    loss = criterion(out, labels)

    # Backward pass: clear stale gradients, backpropagate, update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # loss.item() yields a plain Python float, so the log shows a number
    # rather than the tensor repr (with its grad_fn). The label replaces the
    # "<UNK>" placeholder that was printed before.
    print(f"训练损失: {loss.item():.4f}")
|
|
|
|
# Training is finished: switch to evaluation mode, which turns off
# training-only behavior such as dropout / batch-norm updates.
ez.eval()

correct = 0
total = 0

# Gradients are not needed for evaluation, so skip tracking them to save memory.
with torch.no_grad():
    for images, labels in test_loader:
        # Flatten each image into a vector before the fully connected layers.
        images = images.reshape(images.size(0), -1)
        outputs = ez(images)  # [batch_size, 10]

        # The predicted class is the index of the largest score in each row.
        _, predicted = outputs.max(dim=1)

        total += labels.size(0)
        correct += int((predicted == labels).sum())

print(f"测试集准确率: {correct}/{total} = {correct/total*100:.2f}%")
|