CNN-Renew
This commit is contained in:
49
Pytorch/Project/ez0/MyFirstModule.py
Normal file
49
Pytorch/Project/ez0/MyFirstModule.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
ez = nn.Sequential(
|
||||
nn.Linear(2, 8),
|
||||
nn.ReLU(),
|
||||
nn.Linear(8, 8),
|
||||
nn.ReLU(),
|
||||
nn.Linear(8, 1)
|
||||
)
|
||||
# y = w1 @ x + w2 @ x^2 + b
|
||||
# 两个特征,第一个是x,第二个是x^2
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
x = torch.linspace(-10, 10, 10000).unsqueeze(1)
|
||||
X = torch.cat([x, x**2], dim=1)
|
||||
y = x**2 + 2 * x + 1
|
||||
|
||||
inx = torch.tensor([[2, 4], [1, 1], [3, 9], [4, 16]], dtype=torch.float32)
|
||||
target = torch.tensor([[9], [4], [16], [25]], dtype=torch.float32)
|
||||
|
||||
lr = 0.00001
|
||||
for i in range(10000):
|
||||
#向前传播
|
||||
out = ez(X)
|
||||
loss = criterion(out, y)
|
||||
|
||||
#反向传播
|
||||
loss.backward()
|
||||
|
||||
with torch.no_grad():
|
||||
# 3. 手动更新参数
|
||||
with torch.no_grad(): # 禁止 autograd 追踪
|
||||
for param in ez.parameters():
|
||||
param -= lr * param.grad # 梯度下降更新参数
|
||||
|
||||
#梯度清零
|
||||
ez.zero_grad()
|
||||
|
||||
if i % 1000 == 0:
|
||||
print(f"Epoch {i}: loss={loss.item():.4f}")
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
y_pred = ez(X).detach()
|
||||
plt.scatter(x.numpy(), x.numpy()**2 + 2*x.numpy() + 1, label='real y = x^2 + 2x + 1')
|
||||
plt.plot(x.numpy(), y_pred.numpy(), color='r', label='Predicted y = x^2 + 2x + 1')
|
||||
plt.legend()
|
||||
plt.show()
|
Reference in New Issue
Block a user