116 lines
3.3 KiB
Python
116 lines
3.3 KiB
Python
import torch
|
||
import matplotlib.pyplot as plt
|
||
from torch import nn
|
||
import torch.optim as optim
|
||
import torch.nn.functional as F
|
||
from torchvision import datasets, transforms
|
||
from torch.utils.data import DataLoader
|
||
|
||
from main import optimizer
|
||
|
||
# 设置超参数
|
||
batch_size = 64
|
||
|
||
# 定义预处理步骤
|
||
transform = transforms.Compose([
|
||
transforms.ToTensor(), # 转换为张量,范围 [0,1]
|
||
transforms.Normalize((0.1307,), (0.3081,)) # 标准化:均值、方差是 MNIST 的经验值
|
||
])
|
||
|
||
# 加载训练集
|
||
train_dataset = datasets.MNIST(
|
||
root='./data', # 数据存放路径
|
||
train=True, # 训练集
|
||
download=True, # 如果没有就下载
|
||
transform=transform # 应用预处理
|
||
)
|
||
|
||
# 加载测试集
|
||
test_dataset = datasets.MNIST(
|
||
root='./data',
|
||
train=False, # 测试集
|
||
download=True,
|
||
transform=transform
|
||
)
|
||
|
||
# 构建 DataLoader
|
||
train_loader = DataLoader(
|
||
dataset=train_dataset,
|
||
batch_size=batch_size,
|
||
shuffle=True # 打乱数据,适合训练
|
||
)
|
||
|
||
test_loader = DataLoader(
|
||
dataset=test_dataset,
|
||
batch_size=batch_size,
|
||
shuffle=False # 测试集不需要打乱
|
||
)
|
||
|
||
# 简单测试一下
|
||
print(f"训练集大小: {len(train_dataset)}")
|
||
print(f"测试集大小: {len(test_dataset)}")
|
||
|
||
# 取一个 batch 看看形状
|
||
images, labels = next(iter(train_loader))
|
||
print(f"图片批次维度: {images.shape}") # [batch_size, 1, 28, 28]
|
||
print(f"标签批次维度: {labels.shape}") # [batch_size]
|
||
|
||
class CNN(nn.Module):
|
||
def __init__(self):
|
||
super(CNN, self).__init__()
|
||
self.c1 = nn.Conv2d(1, 16, 3, padding=1)
|
||
self.c2 = nn.Conv2d(16, 32, 3, padding=1)
|
||
self.c3 = nn.Conv2d(32, 64, 3, padding=1)
|
||
self.pool = nn.MaxPool2d(2, 2)
|
||
self.linear0 = nn.Linear(64 * 7 * 7, 128) # 注意这里是14*14,如果只池化一次,池化一次减半
|
||
self.linear1 = nn.Linear(128, 64)
|
||
self.linear2 = nn.Linear(64, 32)
|
||
self.linear3 = nn.Linear(32, 10)
|
||
self.drop = nn.Dropout(p=0.31) # 丢弃概率
|
||
|
||
def forward(self, x):
|
||
x = F.relu(self.c1(x))
|
||
x = self.pool(F.relu(self.c2(x))) # [batch,32,14,14] → pool → [batch,32,7,7]
|
||
x = self.pool(F.relu(self.c3(x)))
|
||
x = x.view(x.size(0), -1) # flatten
|
||
x = F.relu(self.linear0(x))
|
||
x = F.relu(self.linear1(x))
|
||
x = self.drop(x)
|
||
x = F.relu(self.linear2(x))
|
||
x = self.linear3(x)
|
||
return x
|
||
|
||
ez = CNN()
|
||
|
||
criterion = nn.CrossEntropyLoss()
|
||
optimizer = optim.AdamW(ez.parameters(), lr=0.001)
|
||
|
||
for i in range(10):
|
||
ez.train()
|
||
for images, labels in train_loader:
|
||
out = ez(images)
|
||
loss = criterion(out, labels)
|
||
#反向传播
|
||
optimizer.zero_grad()
|
||
loss.backward()
|
||
optimizer.step()
|
||
print(f"<UNK>: {loss}")
|
||
|
||
#训练结束
|
||
ez.eval() # 关闭 dropout/batchnorm 等训练特性
|
||
|
||
correct = 0
|
||
total = 0
|
||
|
||
with torch.no_grad(): # 测试不需要计算梯度,节省显存
|
||
for images, labels in test_loader:
|
||
outputs = ez(images) # [batch_size, 10]
|
||
|
||
# 取每行最大值对应的索引作为预测类别
|
||
_, predicted = torch.max(outputs, 1)
|
||
|
||
total += labels.size(0)
|
||
correct += (predicted == labels).sum().item()
|
||
|
||
print(f"测试集准确率: {correct}/{total} = {correct/total*100:.2f}%")
|