Files
python/Pytorch/Project/ez0/MyFirstModule.py
2025-09-10 10:18:27 +08:00

50 lines
1.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
# Small MLP over two input features -> one output.
# Module indices 0..4 are: Linear, ReLU, Linear, ReLU, Linear.
ez = nn.Sequential(
    nn.Linear(2, 8), nn.ReLU(),  # 2 input features -> hidden width 8
    nn.Linear(8, 8), nn.ReLU(),  # hidden 8 -> hidden 8
    nn.Linear(8, 1),             # hidden 8 -> single output value
)
# Training setup. Each sample feeds the network two features, x and x^2;
# the regression target is y = x^2 + 2x + 1 (i.e. (x + 1)^2).
criterion = nn.MSELoss()
x = torch.linspace(-10, 10, 10000)[:, None]  # column vector, shape (10000, 1)
X = torch.hstack((x, x**2))                  # features (x, x^2), shape (10000, 2)
y = x**2 + 2 * x + 1
# Hand-picked check points: each row is (v, v^2); expected output is (v + 1)^2.
inx = torch.tensor([[v, v**2] for v in (2, 1, 3, 4)], dtype=torch.float32)
target = torch.tensor([[(v + 1) ** 2] for v in (2, 1, 3, 4)], dtype=torch.float32)
lr = 1e-5
# Full-batch gradient descent with a hand-written parameter update
# (no torch.optim), so every step of one iteration is explicit.
for i in range(10000):
    # Forward pass over the whole dataset.
    out = ez(X)
    loss = criterion(out, y)

    # Backward pass: fills param.grad for each parameter of ez.
    loss.backward()

    # Manual update: param <- param - lr * grad. This runs under no_grad
    # (one context is enough) so autograd does not track the in-place
    # update. Without it, PyTorch rejects in-place ops on leaf tensors
    # that require grad.
    with torch.no_grad():
        for param in ez.parameters():
            param -= lr * param.grad

    # Clear gradients, because backward() accumulates into .grad rather
    # than overwriting it.
    ez.zero_grad()

    if i % 1000 == 0:
        print(f"Epoch {i}: loss={loss.item():.4f}")
import matplotlib.pyplot as plt

# Inference only: no_grad skips building an autograd graph for these
# forward passes, so the outputs can be converted to numpy directly.
with torch.no_grad():
    y_pred = ez(X)

    # Spot-check the trained model on the hand-picked points in `inx`
    # against their expected outputs in `target`.
    print("sample predictions:", ez(inx).squeeze(1).tolist())
    print("sample targets:    ", target.squeeze(1).tolist())

# Ground truth: reuse the same `y` tensor the model was trained on.
plt.scatter(x.numpy(), y.numpy(), label='real y = x^2 + 2x + 1')
plt.plot(x.numpy(), y_pred.numpy(), color='r', label='Predicted y = x^2 + 2x + 1')
plt.legend()
plt.show()