CNN-Renew
This commit is contained in:
26
Pytorch/Project/ez0/main.py
Normal file
26
Pytorch/Project/ez0/main.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
# 1. 定义网络
|
||||
model = nn.Sequential(
|
||||
nn.Linear(10, 20), # 输入10维 → 输出20维
|
||||
nn.ReLU(),
|
||||
nn.Linear(20, 2) # 输出2维(比如2分类)
|
||||
)
|
||||
|
||||
# 2. 定义损失函数 & 优化器
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||||
|
||||
# 3. 模拟训练一步
|
||||
x = torch.randn(5, 10) # batch=5, feature=10
|
||||
y = torch.tensor([0, 1, 0, 1, 1]) # 分类标签
|
||||
|
||||
out = model(x) # forward
|
||||
loss = criterion(out, y) # loss
|
||||
optimizer.zero_grad() # 梯度清零
|
||||
loss.backward() # backward
|
||||
optimizer.step() # 更新参数
|
||||
|
||||
print("loss =", loss.item())
|
Reference in New Issue
Block a user