"""Replay a trained CartPole DQN checkpoint with on-screen rendering.

Loads the weights stored at ``MODEL_PATH`` into a :class:`DQN`, plays
``NUM_EPISODES`` greedy episodes of ``CartPole-v1`` and prints the total
reward of each episode.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# Checkpoint to evaluate and number of evaluation episodes.
MODEL_PATH = "cartpole_dqn_success.pth"
NUM_EPISODES = 5


class DQN(nn.Module):
    """Three-layer fully connected Q-network.

    The structure must stay the same as the network used for training.
    The attribute names ``l1``, ``l2`` and ``l4`` are also the parameter
    name prefixes in the state dict, so renaming them would make
    ``load_state_dict`` reject the checkpoint.

    Args:
        state_size: Number of input features (observation dimension).
        action_size: Number of outputs (one Q-value per action).
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        self.l1 = nn.Linear(state_size, 128)
        self.l2 = nn.Linear(128, 32)
        self.l4 = nn.Linear(32, action_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the output of the final linear layer for input ``x``."""
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.l4(x)


def load_model(path, state_size, action_size):
    """Build a :class:`DQN` and load the state dict stored at ``path``.

    Args:
        path: File containing a state dict saved with ``torch.save``.
        state_size: Observation dimension passed to :class:`DQN`.
        action_size: Number of actions passed to :class:`DQN`.

    Returns:
        The model with loaded weights, switched to evaluation mode.
    """
    model = DQN(state_size, action_size)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host;
    # weights_only restricts unpickling to tensors and plain containers.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def run_episode(env, model):
    """Play one episode, always taking the highest-scoring action.

    Args:
        env: Object exposing ``reset()`` and ``step(action)`` with the
            five-value ``step`` return used below.
        model: Callable mapping a float32 state tensor to per-action scores.

    Returns:
        The sum of the rewards collected until the episode terminated or
        was truncated.
    """
    state, _ = env.reset()
    total_reward = 0
    done = False
    while not done:
        with torch.no_grad():
            scores = model(torch.as_tensor(state, dtype=torch.float32))
        action = scores.argmax().item()
        # No explicit env.render() here: with render_mode="human" the
        # environment is documented to render during step() on its own.
        state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        total_reward += reward
    return total_reward


def main():
    """Create the rendered environment, load the model and run the test."""
    # Imported here so that DQN and the helpers above can be imported
    # without the environment/rendering dependency being installed.
    import gymnasium as gym

    # Create the environment (with rendering).
    env = gym.make("CartPole-v1", render_mode="human")
    try:
        # Initialise the model and load the weights.
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.n
        model = load_model(MODEL_PATH, state_size, action_size)
        print(f"已加载模型 {MODEL_PATH}")

        # Evaluate for NUM_EPISODES episodes.
        for ep in range(NUM_EPISODES):
            total_reward = run_episode(env, model)
            print(f"Episode {ep+1} reward = {total_reward}")
    finally:
        # Always release the render window, even if loading or a step fails.
        env.close()


if __name__ == "__main__":
    main()