27 lines
720 B
Python
27 lines
720 B
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
|
||
# 1. 定义网络
|
||
model = nn.Sequential(
|
||
nn.Linear(10, 20), # 输入10维 → 输出20维
|
||
nn.ReLU(),
|
||
nn.Linear(20, 2) # 输出2维(比如2分类)
|
||
)
|
||
|
||
# 2. 定义损失函数 & 优化器
|
||
criterion = nn.CrossEntropyLoss()
|
||
optimizer = optim.SGD(model.parameters(), lr=0.01)
|
||
|
||
# 3. 模拟训练一步
|
||
x = torch.randn(5, 10) # batch=5, feature=10
|
||
y = torch.tensor([0, 1, 0, 1, 1]) # 分类标签
|
||
|
||
out = model(x) # forward
|
||
loss = criterion(out, y) # loss
|
||
optimizer.zero_grad() # 梯度清零
|
||
loss.backward() # backward
|
||
optimizer.step() # 更新参数
|
||
|
||
print("loss =", loss.item())
|