diff --git a/Pytorch/Project/CartPole/.idea/CartPole.iml b/Pytorch/Project/CartPole/.idea/CartPole.iml
new file mode 100644
index 0000000..b910500
diff --git a/Pytorch/Project/CartPole/.idea/inspectionProfiles/Project_Default.xml b/Pytorch/Project/CartPole/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..92645b2
diff --git a/Pytorch/Project/CartPole/.idea/inspectionProfiles/profiles_settings.xml b/Pytorch/Project/CartPole/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
diff --git a/Pytorch/Project/CartPole/.idea/misc.xml b/Pytorch/Project/CartPole/.idea/misc.xml
new file mode 100644
index 0000000..15d27bc
diff --git a/Pytorch/Project/CartPole/.idea/modules.xml b/Pytorch/Project/CartPole/.idea/modules.xml
new file mode 100644
index 0000000..58f1aab
diff --git a/Pytorch/Project/CartPole/.idea/workspace.xml b/Pytorch/Project/CartPole/.idea/workspace.xml
new file mode 100644
index 0000000..36bc551
diff --git a/Pytorch/Project/CartPole/AddBuffer.py b/Pytorch/Project/CartPole/AddBuffer.py
new file mode 100644
index 0000000..36eac2b
--- /dev/null
+++ b/Pytorch/Project/CartPole/AddBuffer.py
@@ -0,0 +1,114 @@
+from collections import deque
+import numpy as np
+import random
+import torch
+import torch.optim as optim
+import torch.nn as nn
+import gymnasium as gym
+import torch.nn.functional as F
+
+
+class DQN(nn.Module):
+    def __init__(self, state_size, action_size):
+        super(DQN, self).__init__()
+        self.l1 = nn.Linear(state_size, 128)
+        self.l2 = nn.Linear(128, 32)
+        self.l4 = nn.Linear(32, action_size)
+
+    def forward(self, x):
+        x = F.relu(self.l1(x))
+        x = F.relu(self.l2(x))
+        x = self.l4(x)
+        return x
+
+# Hyperparameters
+lr = 0.001
+gamma = 0.99
+epsilon = 1.0
+memory = deque(maxlen=1000000)
+batch_size = 64
+epsilon_decay = 0.995
+epsilon_min = 0.01
+
+env = gym.make('CartPole-v1')
+state_size = env.observation_space.shape[0]
+action_size = env.action_space.n
+
+
+ez = DQN(state_size, action_size)
+optimizer = optim.Adam(ez.parameters(), lr=lr)
+criterion = nn.MSELoss()
+
+def replay():
+    if len(memory) < batch_size:
+        return
+
+    # Sample a random minibatch of 64 transitions to train on
+    batch = random.sample(memory, batch_size)
+    states, actions, rewards, next_states, dones = zip(*batch)
+    states = torch.from_numpy(np.stack(states))  # stack the batch into one array, then convert to a tensor
+    next_states = torch.from_numpy(np.stack(next_states))
+    rewards = torch.from_numpy(np.array(rewards, dtype=np.float32))
+    actions = torch.from_numpy(np.array(actions, dtype=np.int64))
+    dones = torch.from_numpy(np.array(dones, dtype=np.float32))
+
+    q_values = ez(states).gather(1, actions.unsqueeze(1)).squeeze()
+    next_q_values = ez(next_states).max(1)[0]
+    target = rewards + gamma * next_q_values * (1 - dones)
+
+    loss = criterion(q_values, target.detach())
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+
+for e in range(350):
+    state, _ = env.reset()
+    done = False
+    total_reward = 0
+    while not done:
+        # epsilon-greedy action selection
+        if random.random() < epsilon:
+            action = env.action_space.sample()
+        else:
+            action = ez(torch.tensor(state, dtype=torch.float32)).argmax().item()
+
+        next_state, reward, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+        # Store the transition in the replay buffer
+        memory.append((np.array(state, dtype=np.float32),
+                       action,
+                       reward,
+                       np.array(next_state, dtype=np.float32),
+                       done))
+        state = next_state
+        total_reward += reward
+        replay()
+
+    epsilon = max(epsilon * epsilon_decay, epsilon_min)
+    print(f"Episode {e + 1}: Reward = {total_reward}")
+
+torch.save(ez.state_dict(), "cartpole_dqn.pth")
+print("Model saved to cartpole_dqn.pth")
+
+# Test function
+def test_agent(env, model, episodes=10):
+    env = gym.make('CartPole-v1', render_mode='human')
+    total_rewards = []
+    for ep in range(episodes):
+        state, _ = env.reset()
+        done = False
+        total_reward = 0.0
+        while not done:
+            # Disable gradient tracking for faster inference
+            with torch.no_grad():
+                action = model(torch.tensor(state, dtype=torch.float32)).argmax().item()
+            next_state, reward, terminated, truncated, _ = env.step(action)
+            done = terminated or truncated
+            state = next_state
+            total_reward += reward
+        total_rewards.append(total_reward)
+        print(f"Test Episode {ep+1}: reward = {total_reward}")
+
+    avg_reward = sum(total_rewards) / episodes
+    print(f"Average reward over {episodes} episodes: {avg_reward}")
+
+# Run the test
+test_agent(env, ez, episodes=10)
\ No newline at end of file
diff --git a/Pytorch/Project/CartPole/NoBuffer.py b/Pytorch/Project/CartPole/NoBuffer.py
new file mode 100644
index 0000000..af1792e
--- /dev/null
+++ b/Pytorch/Project/CartPole/NoBuffer.py
@@ -0,0 +1,114 @@
+import gymnasium as gym
+import random
+import torch
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.nn as nn
+import time
+
+# Create the environment
+#env = gym.make("CartPole-v1", render_mode="human")  # human mode opens an on-screen render window
+env = gym.make('CartPole-v1')
+# Reset the environment
+observation, info = env.reset()
+print("Initial observation:", observation)
+
+# Remember: DQN learns and outputs Q*, not an action; the action is chosen from the Q* values and fed back to the environment
+class DQN(nn.Module):
+    def __init__(self, state_size, action_size):
+        super(DQN, self).__init__()
+        self.l1 = nn.Linear(state_size, 128)
+        self.l3 = nn.Linear(128, 64)
+        self.l4 = nn.Linear(64, 32)
+        self.l5 = nn.Linear(32, action_size)
+
+    def forward(self, x):
+        x = F.relu(self.l1(x))
+        x = F.relu(self.l3(x))
+        x = F.relu(self.l4(x))
+        x = self.l5(x)
+        return x
+
+
+# Hyperparameters
+epsilon = 0.95
+state_size = env.observation_space.shape[0]
+action_size = env.action_space.n
+gamma = 0.99
+lrs = 0.005
+epsilon_decay = 0.995
+epsilon_min = 0.01
+batch_size = 64
+#memory = deque(maxlen=10000)
+num_episodes = 500
+
+# Initialization
+ez = DQN(state_size, action_size)
+optimizer = optim.Adam(ez.parameters(), lr=lrs)
+criterion = nn.MSELoss()
+
+# Training
+for i in range(num_episodes):
+    state, _ = env.reset()  # gym >= 0.26 returns (obs, info)
+    total_reward = 0.0
+    done = False
+
+    while not done:
+        # epsilon-greedy action selection
+        if random.random() < epsilon:
+            action = env.action_space.sample()
+        else:
+            action = ez(torch.tensor(state, dtype=torch.float32)).argmax().item()
+
+        # Interact with the environment
+        next_state, reward, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+        total_reward += reward
+
+        # Compute Q(s, a)
+        q_values = ez(torch.tensor(state, dtype=torch.float32))
+        now_Q = q_values[action]
+
+        # Compute the TD target
+        next_q_value = ez(torch.tensor(next_state, dtype=torch.float32)).max().item()
+        target_Q = reward + gamma * next_q_value * (0 if done else 1)
+
+        # Compute the loss
+        target_Q = torch.tensor(target_Q, dtype=torch.float32)
+        loss = criterion(now_Q, target_Q)
+
+        # Update the network
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # Move to the next state
+        state = next_state
+
+    # Decay epsilon after each episode
+    epsilon = max(epsilon * epsilon_decay, epsilon_min)
+    print(f"Episode {i}, total_reward = {total_reward}")
+
+# Test function
+def test_agent(env, model, episodes=10):
+    total_rewards = []
+    for ep in range(episodes):
+        state, _ = env.reset()
+        done = False
+        total_reward = 0.0
+        while not done:
+            # Disable gradient tracking for faster inference
+            with torch.no_grad():
+                action = model(torch.tensor(state, dtype=torch.float32)).argmax().item()
+            next_state, reward, terminated, truncated, _ = env.step(action)
+            done = terminated or truncated
+            state = next_state
+            total_reward += reward
+        total_rewards.append(total_reward)
+        print(f"Test Episode {ep+1}: reward = {total_reward}")
+
+    avg_reward = sum(total_rewards) / episodes
+    print(f"Average reward over {episodes} episodes: {avg_reward}")
+
+# Run the test
+test_agent(env, ez, episodes=10)
+
+# Best result so far: average reward around 150
\ No newline at end of file
diff --git a/Pytorch/Project/CartPole/TestEnv.py b/Pytorch/Project/CartPole/TestEnv.py
new file mode 100644
index 0000000..f1a7a5e
--- /dev/null
+++ b/Pytorch/Project/CartPole/TestEnv.py
@@ -0,0 +1,22 @@
+import gymnasium as gym
+import time
+
+# Create the environment
+env = gym.make("CartPole-v1", render_mode="human")  # human mode opens an on-screen render window
+
+# Reset the environment
+observation, info = env.reset()
+print("Initial observation:", observation)
+print("Info:", info)
+
+# Step through the environment for a few steps (always pushing right)
+for step in range(100):
+    observation, reward, terminated, truncated, info = env.step(1)
+    print(f"Step {step+1}: observation={observation}, reward={reward}, done={terminated}")
+
+    if terminated or truncated:
+        observation, info = env.reset()
+
+time.sleep(1)
+# Close the environment
+env.close()
diff --git a/Pytorch/Project/CartPole/TestPthModule.py b/Pytorch/Project/CartPole/TestPthModule.py
new file mode 100644
index 0000000..e69de29
diff --git a/Pytorch/Project/CartPole/cartpole_dqn_success.pth b/Pytorch/Project/CartPole/cartpole_dqn_success.pth
new file mode 100644
index 0000000..66b081c
Binary files /dev/null and b/Pytorch/Project/CartPole/cartpole_dqn_success.pth differ
diff --git a/Pytorch/Project/ez0/HWCNN.png b/Pytorch/Project/ez0/HWCNN.png
new file mode 100644
index 0000000..2a32317
Binary files /dev/null and b/Pytorch/Project/ez0/HWCNN.png differ
diff --git a/Pytorch/Project/ez0/HWMLP.png b/Pytorch/Project/ez0/HWMLP.png
new file mode 100644
index 0000000..c2ef621
Binary files /dev/null and b/Pytorch/Project/ez0/HWMLP.png differ
diff --git a/Pytorch/Project/ez0/noiseGraph.png b/Pytorch/Project/ez0/noiseGraph.png
new file mode 100644
index 0000000..c77187c
Binary files /dev/null and b/Pytorch/Project/ez0/noiseGraph.png differ
diff --git a/Pytorch/Project/ez0/noiseLoss.png b/Pytorch/Project/ez0/noiseLoss.png
new file mode 100644
index 0000000..00115ee
Binary files /dev/null and b/Pytorch/Project/ez0/noiseLoss.png differ
diff --git a/Pytorch/Project/ez0/x^2 + 2x + 1.png b/Pytorch/Project/ez0/x^2 + 2x + 1.png
new file mode 100644
index 0000000..b98428b
Binary files /dev/null and b/Pytorch/Project/ez0/x^2 + 2x + 1.png differ
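
Note: TestPthModule.py is committed as an empty file in this patch. A minimal sketch of what it might eventually contain is shown below: loading the weights saved by AddBuffer.py (cartpole_dqn.pth, or the committed cartpole_dqn_success.pth) and running a few greedy episodes. This is only a guess at the file's intended purpose; the DQN architecture is copied from AddBuffer.py and everything else is illustrative, not part of the commit.

# Hypothetical contents for TestPthModule.py, a sketch rather than the author's code.
# Assumes the same DQN architecture as AddBuffer.py (128 -> 32 hidden units).
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F

class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.l1 = nn.Linear(state_size, 128)
        self.l2 = nn.Linear(128, 32)
        self.l4 = nn.Linear(32, action_size)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.l4(x)

env = gym.make("CartPole-v1", render_mode="human")
model = DQN(env.observation_space.shape[0], env.action_space.n)
# Load the weights saved by AddBuffer.py (swap in cartpole_dqn_success.pth if preferred)
model.load_state_dict(torch.load("cartpole_dqn.pth"))
model.eval()

for ep in range(5):
    state, _ = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        # Greedy action from the trained Q-network
        with torch.no_grad():
            action = model(torch.tensor(state, dtype=torch.float32)).argmax().item()
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_reward += reward
    print(f"Episode {ep + 1}: reward = {total_reward}")
env.close()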