# python/Pytorch/Project/CartPole/AddBuffer.py
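# DQN agent with an experience replay buffer for Gymnasium's CartPole-v1:
# trains for 350 episodes, saves the weights, then runs a rendered evaluation.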
from collections import deque
import numpy as np
import random
import torch
import torch.optim as optim
import torch.nn as nn
import gymnasium as gym
import torch.nn.functional as F
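

# Q-network: two hidden layers (128 and 32 units) that map a CartPole state
# vector to one Q-value per action.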
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.l1 = nn.Linear(state_size, 128)
        self.l2 = nn.Linear(128, 32)
        self.l4 = nn.Linear(32, action_size)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.l4(x)
        return x
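
# Training hyperparameters: learning rate, discount factor, epsilon-greedy
# schedule, replay buffer capacity, and minibatch size.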
lr = 0.001
gamma = 0.99
epsilon = 1.0
memory = deque(maxlen=1000000)
batch_size = 64
epsilon_decay = 0.995
epsilon_min = 0.01
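
# Environment, Q-network (ez), Adam optimizer, and MSE loss.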
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
ez = DQN(state_size, action_size)
optimizer = optim.Adam(ez.parameters(), lr=lr)
criterion = nn.MSELoss()
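

# Sample a random minibatch from the replay buffer and take one gradient step
# on the TD error between the predicted Q-values and the bootstrapped targets.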
def replay():
    if len(memory) < batch_size:
        return
    # Randomly sample a minibatch of batch_size transitions for training
    batch = random.sample(memory, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)
    states = torch.from_numpy(np.stack(states))  # stack the batch into one array, then convert to a tensor
    next_states = torch.from_numpy(np.stack(next_states))
    rewards = torch.from_numpy(np.array(rewards, dtype=np.float32))
    actions = torch.from_numpy(np.array(actions, dtype=np.int64))
    dones = torch.from_numpy(np.array(dones, dtype=np.float32))
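    # Q-values of the taken actions and the one-step TD targets
    # r + gamma * max_a' Q(s', a'), with the bootstrap term zeroed on terminal transitions.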
    q_values = ez(states).gather(1, actions.unsqueeze(1)).squeeze()
    next_q_values = ez(next_states).max(1)[0]
    target = rewards + gamma * next_q_values * (1 - dones)
    loss = criterion(q_values, target.detach())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
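

# Training loop: epsilon-greedy exploration, store each transition in the replay
# buffer, learn from a sampled minibatch every step, and decay epsilon per episode.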
for e in range(350):
    state, _ = env.reset()
    done = False
    total_reward = 0
    while not done:
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = ez(torch.tensor(state, dtype=torch.float32)).argmax().item()
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Store the transition in the replay buffer as float32 arrays
        memory.append((np.array(state, dtype=np.float32),
                       action,
                       reward,
                       np.array(next_state, dtype=np.float32),
                       done))
        state = next_state
        total_reward += reward
        replay()
    epsilon = max(epsilon * epsilon_decay, epsilon_min)
    print(f"Episode {e + 1}: Reward = {total_reward}")
torch.save(ez.state_dict(), "cartpole_dqn.pth")
print("Model saved to cartpole_dqn.pth")

# Evaluate the trained agent; a new env with on-screen rendering replaces the one passed in
def test_agent(env, model, episodes=10):
    env = gym.make('CartPole-v1', render_mode='human')
    total_rewards = []
    for ep in range(episodes):
        state, _ = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            # Disable gradient computation to speed up inference
            with torch.no_grad():
                action = model(torch.tensor(state, dtype=torch.float32)).argmax().item()
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = next_state
            total_reward += reward
        total_rewards.append(total_reward)
        print(f"Test Episode {ep+1}: reward = {total_reward}")
    avg_reward = sum(total_rewards) / episodes
    print(f"Average reward over {episodes} episodes: {avg_reward}")
# Run the evaluation
test_agent(env, ez, episodes=10)