python/Pytorch/Project/ez0/HandWriteCNN.py

import torch
import matplotlib.pyplot as plt
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Hyperparameters
batch_size = 64

# Preprocessing pipeline
transform = transforms.Compose([
    transforms.ToTensor(),                      # convert to tensor, scaled to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # normalize with MNIST's empirical mean and std
])
# Load the training set
train_dataset = datasets.MNIST(
    root='./data',        # where the data is stored
    train=True,           # training split
    download=True,        # download if not already present
    transform=transform   # apply the preprocessing
)

# Load the test set
test_dataset = datasets.MNIST(
    root='./data',
    train=False,          # test split
    download=True,
    transform=transform
)
# Build the DataLoaders
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True    # shuffle the data, appropriate for training
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False   # no need to shuffle the test set
)
# Quick sanity check
print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")

# Fetch one batch and inspect its shape
images, labels = next(iter(train_loader))
print(f"Image batch shape: {images.shape}")   # [batch_size, 1, 28, 28]
print(f"Label batch shape: {labels.shape}")   # [batch_size]
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.c1 = nn.Conv2d(1, 16, 3, padding=1)
        self.c2 = nn.Conv2d(16, 32, 3, padding=1)
        self.c3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.linear0 = nn.Linear(64 * 7 * 7, 128)  # each pool halves the size: 28 -> 14 -> 7; with only one pool this would be 14 * 14
        self.linear1 = nn.Linear(128, 64)
        self.linear2 = nn.Linear(64, 32)
        self.linear3 = nn.Linear(32, 10)
        self.drop = nn.Dropout(p=0.31)  # dropout probability

    def forward(self, x):
        x = F.relu(self.c1(x))
        x = self.pool(F.relu(self.c2(x)))  # [batch, 32, 28, 28] -> pool -> [batch, 32, 14, 14]
        x = self.pool(F.relu(self.c3(x)))  # [batch, 64, 14, 14] -> pool -> [batch, 64, 7, 7]
        x = x.view(x.size(0), -1)          # flatten
        x = F.relu(self.linear0(x))
        x = F.relu(self.linear1(x))
        x = self.drop(x)
        x = F.relu(self.linear2(x))
        x = self.linear3(x)                # raw logits; CrossEntropyLoss applies softmax internally
        return x
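
# A minimal sanity check (an addition, not in the original script): run a
# dummy batch through a throwaway CNN instance to confirm the flattened size
# really matches linear0's input of 64 * 7 * 7.
with torch.no_grad():
    _probe = CNN()(torch.randn(1, 1, 28, 28))  # one fake 28x28 grayscale image
print(f"Dummy forward output shape: {_probe.shape}")  # expected: [1, 10]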
ez = CNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(ez.parameters(), lr=0.001)

for i in range(10):
    ez.train()
    for images, labels in train_loader:
        out = ez(images)
        loss = criterion(out, labels)
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {i + 1}, loss: {loss.item():.4f}")
# Training finished; evaluate on the test set
ez.eval()  # switch off dropout/batchnorm training behavior
correct = 0
total = 0
with torch.no_grad():  # no gradients needed at test time; saves memory
    for images, labels in test_loader:
        outputs = ez(images)  # [batch_size, 10]
        # Take the index of the max value in each row as the predicted class
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test accuracy: {correct}/{total} = {correct / total * 100:.2f}%")