Files
python/Pytorch/Project/ez0/main.py
2025-09-10 10:18:27 +08:00

27 lines
720 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.optim as optim
# Minimal end-to-end PyTorch example: build a small MLP classifier, run a
# single SGD training step on a random mini-batch, and report the loss.


def _build_model():
    """Return a 10 -> 20 -> 2 MLP (Linear, ReLU, Linear) producing 2 logits."""
    return nn.Sequential(
        nn.Linear(10, 20),  # 10 input features -> 20 hidden units
        nn.ReLU(),
        nn.Linear(20, 2),  # 20 hidden units -> 2 logits, e.g. binary classification
    )


# 1. Network. Building the layers draws their initial weights from the global
#    RNG, so this stays ahead of the random input below to keep the same
#    random stream.
model = _build_model()

# 2. Loss function and optimizer.
criterion = nn.CrossEntropyLoss()  # takes raw logits plus integer class indices
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

# 3. One simulated training step on a random batch.
x = torch.randn(5, 10)  # batch of 5 samples, 10 features each
y = torch.tensor([0, 1, 0, 1, 1], dtype=torch.long)  # one class label per sample

optimizer.zero_grad()  # clear any gradients left over from a previous step
out = model(x)  # forward pass -> logits of shape (5, 2)
loss = criterion(out, y)  # scalar (mean-reduced) cross-entropy
loss.backward()  # backpropagate into each parameter's .grad
optimizer.step()  # plain SGD update: p -= lr * p.grad

loss_value = loss.item()  # loss measured before the parameter update
print("loss =", loss_value)