"""Replay a trained DQN policy on CartPole-v1 with on-screen rendering.

Loads network weights from ``MODEL_PATH``, then plays ``NUM_EPISODES``
episodes, always picking the action with the highest network output, and
prints the total reward of each episode.
"""

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F

# Checkpoint to evaluate; also used in the "model loaded" log line so the
# message can never drift from the file that is actually read.
MODEL_PATH = "cartpole_dqn_success.pth"
ENV_ID = "CartPole-v1"
NUM_EPISODES = 5


class DQN(nn.Module):
    """Fully connected network mapping a state vector to one value per action.

    The architecture has to match the one used when the checkpoint was
    trained. The attribute names ``l1``/``l2``/``l4`` are the keys of the
    saved ``state_dict``, so they must not be renamed.
    """

    def __init__(self, state_size: int, action_size: int) -> None:
        """Build the layers: state_size -> 128 -> 32 -> action_size."""
        super().__init__()
        self.l1 = nn.Linear(state_size, 128)
        self.l2 = nn.Linear(128, 32)
        self.l4 = nn.Linear(32, action_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-action outputs for input ``x``."""
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.l4(x)


def load_model(path: str, state_size: int, action_size: int) -> DQN:
    """Create a ``DQN``, load its weights from ``path`` and set eval mode."""
    model = DQN(state_size, action_size)
    # map_location="cpu" lets the checkpoint load even if it was saved from
    # a GPU device and this machine has no GPU.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you
    # trust (consider weights_only=True on PyTorch versions that support it).
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


def run_episode(env: gym.Env, model: nn.Module) -> float:
    """Play one episode greedily and return the accumulated reward."""
    state, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        with torch.no_grad():
            outputs = model(torch.tensor(state, dtype=torch.float32))
        action = outputs.argmax().item()
        # With render_mode="human" Gymnasium draws the frame inside step(),
        # so no explicit env.render() call is needed here.
        state, reward, terminated, truncated, _ = env.step(action)
        total_reward += reward
        done = terminated or truncated
    return total_reward


def main() -> None:
    """Create the rendered environment, load the model and run the episodes."""
    env = gym.make(ENV_ID, render_mode="human")
    try:
        state_size = env.observation_space.shape[0]
        action_size = int(env.action_space.n)
        model = load_model(MODEL_PATH, state_size, action_size)
        print(f"已加载模型 {MODEL_PATH}")

        for ep in range(NUM_EPISODES):
            total_reward = run_episode(env, model)
            print(f"Episode {ep+1} reward = {total_reward}")
    finally:
        # Always release the render window, even on an error or Ctrl-C.
        env.close()


if __name__ == "__main__":
    main()