python/Pytorch/Project/CartPole/NoBuffer.py
2025-09-10 20:52:02 +08:00 · 114 lines · 3.2 KiB · Python

import gymnasium as gym
import random
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import time

# Create the environment
#env = gym.make("CartPole-v1", render_mode="human")  # "human" mode opens a render window
env = gym.make('CartPole-v1')

# Reset the environment
observation, info = env.reset()
print("Initial observation:", observation)

# Remember: DQN trains Q* and its output is Q*, not an action; the action is
# derived from Q* (argmax) and fed back to the environment.
class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQN, self).__init__()
        self.l1 = nn.Linear(state_size, 128)
        self.l3 = nn.Linear(128, 64)
        self.l4 = nn.Linear(64, 32)
        self.l5 = nn.Linear(32, action_size)

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l3(x))
        x = F.relu(self.l4(x))
        x = self.l5(x)
        return x
# Hyperparameters
epsilon = 0.95
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
gamma = 0.99
lrs = 0.005
epsilon_decay = 0.995
epsilon_min = 0.01
batch_size = 64  # not used in this no-buffer variant
#memory = deque(maxlen=10000)
num_episodes = 500

# Initialization
ez = DQN(state_size, action_size)
optimizer = optim.Adam(ez.parameters(), lr=lrs)
criterion = nn.MSELoss()
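# Illustrative sanity check (not in the original script): the network maps a
# 4-dimensional CartPole state to one Q-value per action, and the greedy
# action is the argmax over those Q-values.
with torch.no_grad():
    demo_q = ez(torch.tensor(observation, dtype=torch.float32))
print("Q-values for the initial observation:", demo_q.tolist(),
      "-> greedy action:", int(demo_q.argmax().item()))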
# Training
for i in range(num_episodes):
    state, _ = env.reset()  # gym >= 0.26 returns (obs, info)
    total_reward = 0.0
    done = False
    while not done:
        # epsilon-greedy action selection
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            action = ez(torch.tensor(state, dtype=torch.float32)).argmax().item()
        # Step the environment
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_reward += reward
        # Compute Q(s, a) for the action actually taken
        q_values = ez(torch.tensor(state, dtype=torch.float32))
        now_Q = q_values[action]
        # TD target: r + gamma * max_a' Q(s', a'), zeroed on terminal states
        next_q_value = ez(torch.tensor(next_state, dtype=torch.float32)).max().item()
        target_Q = reward + gamma * next_q_value * (0 if done else 1)
        # Loss
        target_Q = torch.tensor(target_Q, dtype=torch.float32)
        loss = criterion(now_Q, target_Q)
        # Update the network
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Move to the next state
        state = next_state
    # Decay exploration after each episode (uses epsilon_decay/epsilon_min defined above)
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    print(f"Episode {i}, total_reward = {total_reward}")
# Test routine
def test_agent(env, model, episodes=10):
    total_rewards = []
    for ep in range(episodes):
        state, _ = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            # Greedy policy; no_grad skips gradient tracking for speed
            with torch.no_grad():
                action = model(torch.tensor(state, dtype=torch.float32)).argmax().item()
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            state = next_state
            total_reward += reward
        total_rewards.append(total_reward)
        print(f"Test Episode {ep+1}: reward = {total_reward}")
    avg_reward = sum(total_rewards) / episodes
    print(f"Average reward over {episodes} episodes: {avg_reward}")

# Run the test
test_agent(env, ez, episodes=10)
# Best result so far: average reward around 150
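
# --- Replay-buffer variant (sketch, not part of the original script) ---
# The commented-out `memory = deque(maxlen=10000)` and the unused `batch_size`
# above point at an experience-replay version of this script. Below is a
# minimal sketch of the per-step update under that assumption; `replay_step`
# is an illustrative name, and the transition layout (s, a, r, s', done) is
# assumed.
import numpy as np
from collections import deque

memory = deque(maxlen=10000)

def replay_step(model, optimizer, criterion, memory, batch_size, gamma):
    """One gradient step on a random minibatch of stored transitions."""
    if len(memory) < batch_size:
        return None
    batch = random.sample(memory, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)
    states = torch.tensor(np.array(states), dtype=torch.float32)
    actions = torch.tensor(actions, dtype=torch.int64).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32)
    next_states = torch.tensor(np.array(next_states), dtype=torch.float32)
    dones = torch.tensor(dones, dtype=torch.float32)
    # Q(s, a) for the actions actually taken
    q_sa = model(states).gather(1, actions).squeeze(1)
    # TD target: r + gamma * max_a' Q(s', a') * (1 - done)
    with torch.no_grad():
        next_q = model(next_states).max(dim=1).values
    target = rewards + gamma * next_q * (1.0 - dones)
    loss = criterion(q_sa, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# Usage inside the step loop (after `done = terminated or truncated`):
#   memory.append((state, action, reward, next_state, float(done)))
#   replay_step(ez, optimizer, criterion, memory, batch_size, gamma)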