CNN-Renew
This commit is contained in:
55
Pytorch/Project/ez0/noise.py
Normal file
55
Pytorch/Project/ez0/noise.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from torch import optim
|
||||
|
||||
# x 张量
|
||||
x = torch.linspace(-10, 10, 100).unsqueeze(1)
|
||||
|
||||
# 特征矩阵
|
||||
X = torch.cat([x, x**2, x**3, torch.sin(2*x)], dim=1)
|
||||
|
||||
# 高斯噪声
|
||||
noise = torch.from_numpy(np.random.normal(0, 3, size=x.shape)).float()
|
||||
|
||||
# 带噪声的 y
|
||||
y_noisy = 0.5 * x**3 - 2 * x**2 + 3*x + 5 + 4 * torch.sin(2*x) + noise
|
||||
|
||||
# 可视化
|
||||
'''
|
||||
plt.scatter(x, y_noisy, label="Noisy data")
|
||||
plt.plot(x, 0.5 * x**3 - 2 * x**2 + 3*x + 5 + 4 * torch.sin(2*x), color='red', label="Original function")
|
||||
plt.legend()
|
||||
plt.show()
|
||||
'''
|
||||
ez = nn.Sequential(
|
||||
nn.Linear(4, 8),
|
||||
nn.ReLU(),
|
||||
nn.Linear(8, 4),
|
||||
nn.ReLU(),
|
||||
nn.Linear(4, 1),
|
||||
)
|
||||
|
||||
criterion = nn.MSELoss()
|
||||
optimizer = optim.Adam(ez.parameters(), lr=0.0065)
|
||||
|
||||
for i in range(10000):
|
||||
#向前传播
|
||||
out = ez(X)
|
||||
loss = criterion(out, y_noisy)
|
||||
#反向传播
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if i % 100 == 0:
|
||||
print("Step:", i, "loss =", loss.item())
|
||||
|
||||
with torch.no_grad():
|
||||
y_pred = ez(X)
|
||||
|
||||
plt.scatter(x.numpy(), y_noisy.numpy(), alpha=0.3, label="Noisy data")
|
||||
plt.plot(x.numpy(), y_pred.numpy(), color='red', label="NN prediction")
|
||||
plt.legend()
|
||||
plt.show()
|
Reference in New Issue
Block a user