CNN-Renew

This commit is contained in:
e2hang
2025-09-10 10:18:27 +08:00
parent a8d78878fc
commit 8db8502dba
21 changed files with 1171 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
import torch
import torch.nn as nn
import torch.optim as optim
# 1. 定义网络
model = nn.Sequential(
nn.Linear(10, 20), # 输入10维 → 输出20维
nn.ReLU(),
nn.Linear(20, 2) # 输出2维比如2分类
)
# 2. 定义损失函数 & 优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 3. 模拟训练一步
x = torch.randn(5, 10) # batch=5, feature=10
y = torch.tensor([0, 1, 0, 1, 1]) # 分类标签
out = model(x) # forward
loss = criterion(out, y) # loss
optimizer.zero_grad() # 梯度清零
loss.backward() # backward
optimizer.step() # 更新参数
print("loss =", loss.item())