diff --git a/Pytorch/Project/CartPole/.idea/workspace.xml b/Pytorch/Project/CartPole/.idea/workspace.xml
index 36bc551..365859e 100644
--- a/Pytorch/Project/CartPole/.idea/workspace.xml
+++ b/Pytorch/Project/CartPole/.idea/workspace.xml
@@ -17,24 +17,26 @@
-
+ {
+ "associatedIndex": 1
+}
- {
+ "keyToString": {
+ "ModuleVcsDetector.initialDetectionPerformed": "true",
+ "Python.AddBuffer.executor": "Run",
+ "Python.NoBuffer.executor": "Run",
+ "Python.TestEnv.executor": "Run",
+ "Python.TestPthModule.executor": "Run",
+ "Python.main.executor": "Run",
+ "RunOnceActivity.ShowReadmeOnStart": "true",
+ "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable"
}
-}]]>
+}
diff --git a/Pytorch/Project/CartPole/TestPthModule.py b/Pytorch/Project/CartPole/TestPthModule.py
index e69de29..95466e7 100644
--- a/Pytorch/Project/CartPole/TestPthModule.py
+++ b/Pytorch/Project/CartPole/TestPthModule.py
@@ -0,0 +1,46 @@
+import torch
+import gymnasium as gym
+import torch.nn as nn
+import torch.nn.functional as F
+
+# Same network architecture that was used during training.
+class DQN(nn.Module):
+    """Fully connected Q-network: state_size -> 128 -> 32 -> action_size.
+
+    The layer attribute names (l1, l2, l4) are kept unchanged because they
+    form the parameter keys of the module's state_dict.
+    """
+
+    def __init__(self, state_size, action_size):
+        """Create the three linear layers.
+
+        Args:
+            state_size: Number of input features per observation.
+            action_size: Number of output values (one per action).
+        """
+        super().__init__()
+        self.l1 = nn.Linear(state_size, 128)
+        self.l2 = nn.Linear(128, 32)
+        self.l4 = nn.Linear(32, action_size)
+
+    def forward(self, x):
+        """Return the output of l4 after ReLU-activated l1 and l2."""
+        hidden = F.relu(self.l1(x))
+        hidden = F.relu(self.l2(hidden))
+        return self.l4(hidden)
+
+
+# Checkpoint to evaluate and number of evaluation episodes.
+MODEL_PATH = "cartpole_dqn_success.pth"
+NUM_EPISODES = 5
+
+# Create the environment with on-screen rendering.
+env = gym.make("CartPole-v1", render_mode="human")
+
+try:
+    # Build the model and load the trained weights.
+    state_size = env.observation_space.shape[0]
+    action_size = env.action_space.n
+    model = DQN(state_size, action_size)
+    # map_location lets a checkpoint saved on a GPU load on a CPU-only
+    # machine; weights_only limits unpickling to tensors and plain containers.
+    model.load_state_dict(
+        torch.load(MODEL_PATH, map_location="cpu", weights_only=True)
+    )
+    model.eval()
+    # Report the file that was actually loaded (the message previously named
+    # a different file than the one passed to torch.load).
+    print(f"已加载模型 {MODEL_PATH}")
+
+    # Evaluation: pick the arg-max action of the network at every step.
+    for ep in range(NUM_EPISODES):
+        state, _ = env.reset()
+        done = False
+        total_reward = 0
+        while not done:
+            with torch.no_grad():
+                q_values = model(torch.tensor(state, dtype=torch.float32))
+            action = q_values.argmax().item()
+            state, reward, terminated, truncated, _ = env.step(action)
+            done = terminated or truncated
+            total_reward += reward
+            # No explicit env.render() call: with render_mode="human" the
+            # environment is drawn as part of step().
+        print(f"Episode {ep+1} reward = {total_reward}")
+finally:
+    # Always release the render window, even if loading or stepping fails.
+    env.close()
diff --git a/Pytorch/Project/CartPole/__pycache__/AddBuffer.cpython-313.pyc b/Pytorch/Project/CartPole/__pycache__/AddBuffer.cpython-313.pyc
new file mode 100644
index 0000000..1ff3e69
Binary files /dev/null and b/Pytorch/Project/CartPole/__pycache__/AddBuffer.cpython-313.pyc differ