From 5d7bdc0e391e14208f605a808ee0869fb805bc84 Mon Sep 17 00:00:00 2001 From: e2hang <2099307493@qq.com> Date: Thu, 11 Sep 2025 12:35:02 +0800 Subject: [PATCH] CartPole --- Pytorch/Project/CartPole/.idea/workspace.xml | 26 +++++----- Pytorch/Project/CartPole/TestPthModule.py | 46 ++++++++++++++++++ .../__pycache__/AddBuffer.cpython-313.pyc | Bin 0 -> 6600 bytes 3 files changed, 60 insertions(+), 12 deletions(-) create mode 100644 Pytorch/Project/CartPole/__pycache__/AddBuffer.cpython-313.pyc diff --git a/Pytorch/Project/CartPole/.idea/workspace.xml b/Pytorch/Project/CartPole/.idea/workspace.xml index 36bc551..365859e 100644 --- a/Pytorch/Project/CartPole/.idea/workspace.xml +++ b/Pytorch/Project/CartPole/.idea/workspace.xml @@ -17,24 +17,26 @@ - + { + "associatedIndex": 1 +} - { + "keyToString": { + "ModuleVcsDetector.initialDetectionPerformed": "true", + "Python.AddBuffer.executor": "Run", + "Python.NoBuffer.executor": "Run", + "Python.TestEnv.executor": "Run", + "Python.TestPthModule.executor": "Run", + "Python.main.executor": "Run", + "RunOnceActivity.ShowReadmeOnStart": "true", + "settings.editor.selected.configurable": "com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" } -}]]> +} diff --git a/Pytorch/Project/CartPole/TestPthModule.py b/Pytorch/Project/CartPole/TestPthModule.py index e69de29..95466e7 100644 --- a/Pytorch/Project/CartPole/TestPthModule.py +++ b/Pytorch/Project/CartPole/TestPthModule.py @@ -0,0 +1,46 @@ +import torch +import gymnasium as gym +import torch.nn as nn +import torch.nn.functional as F + +# 定义和训练时一样的网络结构 +class DQN(nn.Module): + def __init__(self, state_size, action_size): + super(DQN, self).__init__() + self.l1 = nn.Linear(state_size, 128) + self.l2 = nn.Linear(128, 32) + self.l4 = nn.Linear(32, action_size) + + def forward(self, x): + x = F.relu(self.l1(x)) + x = F.relu(self.l2(x)) + x = self.l4(x) + return x + + +# 创建环境 (带渲染) +env = gym.make("CartPole-v1", render_mode="human") + +# 初始化模型并加载权重 +state_size = env.observation_space.shape[0] +action_size = env.action_space.n +model = DQN(state_size, action_size) +model.load_state_dict(torch.load("cartpole_dqn_success.pth")) +model.eval() +print("已加载模型 cartpole_dqn.pth") + +# 测试 +for ep in range(5): # 测试 5 回合 + state, _ = env.reset() + done = False + total_reward = 0 + while not done: + with torch.no_grad(): + action = model(torch.tensor(state, dtype=torch.float32)).argmax().item() + state, reward, terminated, truncated, _ = env.step(action) + done = terminated or truncated + total_reward += reward + env.render() + print(f"Episode {ep+1} reward = {total_reward}") + +env.close() diff --git a/Pytorch/Project/CartPole/__pycache__/AddBuffer.cpython-313.pyc b/Pytorch/Project/CartPole/__pycache__/AddBuffer.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ff3e694c9da0bb9e3d807b62638e223a53998b9 GIT binary patch literal 6600 zcmbt2ZEPFIm9yk7zduBZluX-_XxXt%Ss%4y$&wt|mOmsxw*0{@9XHy9wn)2}z2z{cx0X>Wwl4G7M(xr@Kb&iYbgn-PjXG>rgA= z&Fc}*3ifC-UD{~7%{3U;@f`N?7VPJ(xL#mjl#RFdF*v|GaFBQ62Hu4m1!weI##jmg zue+OBo+l)PqPr=@Xrny_Od%d|0OE}>qMzc7Hu{PIQwX;Rb*)I)D718&YB(GgtkHv% zZnVG^jvg`k6zX0E1Rw9|GvUpkl@Ura+EzsQ@JqaRJ;GasZLrHay32Yr%#3t^?At+h z-WW$AI(izU#BF#7@5Ajp!W}~U_GOmjeA>UndmN~62%%Voj*Eh;3!0CxC`V>QLNDlU3BJY zIC$KRd+;9Ii+2cn2ACDSfX2c|=*8_qAgUN+DI}Pqc?uh?z8~|~?MLVmdcJmCH&DJt z4?_EHj!FzNb>lXnL+I&ngX|r`R^yz2W}GX?KSs`WFUV9cG<3jDg~k8nUZg#vrP}8k=xQ(nS*_8So8Q2P z)}uEWykFQ4^Mr&jzm8wuXAYxDQ)J@;+2Zk(oXW@Jnn{*5%dwO!iC8nGcfxNM{B}oJ zja8)dgyvB4VqS_XsaZ*LiivzGE7v;qcs!976(t^5h$6$`a6#)m+%F`uNl6$^KH?Y|b9BiB!(pW5|r!mBNguSGygDPUeK4>)Nfp2HprWYC$kI zFvnCPF6Q$%b#^*0f&FRjWm8N@>2y3EfusSL6_ODg(08lbyj~GAQ^+>|^gj9@6U%ix zb+Mf9A%aSmS!bI3-_IW189RG&|K!hx|2P}?*?;X{4mJ<>rw$zS5*KC3h!@h5{NMmg zHCg#lZop9iuD1hZ!1wn5*-4GOBAdVQYNl1It$V?0rf<9&w`#Q) z?Ts6U%&u09vx_!SoYC?SVxlcZpF$9@mMxcBS9BAl87&Q+vSrnJE9V;BY;aP@n3p#b zmn&Q0%(hT!%t~?^kcsPnw+(?wCL`O%w^5EWQ_%>8jh4zmWtHvYZIl`%Q>-bJKa67b zXxADdbPf8dd3OFaJ_yO|@y|$plb^{+4~XtC@yU^pW(H>hPDRRO@r>p;E9MhZwJ2-B zqMXcTG>al;a%o93&!%!3r$@1bh(qIa;Xz)KHYl34mNEcbsF$F6Vvaa-BdQ8;q{<(_}xmiDUbw+ch0R~NhQ zyZse+(;auy^|Lpou1zhRFE@9V-Cc#D`~Jp?zx9s4^~bw@+WW)a#iQlTy=8x2;plx& zeZ|v!$J2Zr-?(t?!lL!(_}1*rS+(V0*>k8c{5eGOWlOiZrF%&%Z|N<2`zqc+ z)jL@B9x9C74>XkC{^Ud@uuTnYTV%?Cw#A)lprbHWbqBsbJ%3ScY+quQj@-1Xjs2>- zzcBQbGf+BneadHmYpnT{cu@Z6A!iJNCU3fc)V2yT^aW2K2a? zdIPjNgo%P@ZAWq`1(Lil-a)#~{%}-3+VDVlwk`qLT@JFlc=DXwnafYjwu}L5I;@E2 zrDQmt4XzA^fm@MikYJ%>swP|Haa^zokEnzTVpJ5BTg4yp!S&U*%#>8 zz1sV8RlG3MSz@LR$X{g7s<|976U^d_d7hnR5^X~5Xyp{a7A8z#)B-0IM#tYDFt?xy zCgM7dw}KQJH#MCRWk}VWSdxHD(8J21nM(GbnJeqlOCZn+#@hwT?z`0M7MpC3a-bDnlLdL6mR!d%zl`O`4 z2-X5#kho_SymAH~uAy={5LM%H82r0tR;Du%uV$9y^BPBlPir>1tme*V^I|$qJ+Mit z#frS7F>#G0E?9G|W`Xteho0sze6PUxURJU>ZSAde5~ zrJr%uFjmROK|^vL%fKYSPcg&MfWvmU=eA$oUK&`OR2@BqL)Ab-VGMj|Y2+WWMN_rD z;cDMUeV^>TWGix2Z=lFl8`l?IfABYah$4=#R+rQzYNQ`v}@9<&;!9D`Rq_aVJ5!mr?Pglz?Q!NxNVw*P$)#EQXzqHstn0jxjc(H70DwwP8 zx{AA1b+=aB+g11WMW5>40o)!Rq%x|fx#9_{o^ZkW$bu~P;`-8fncD=bxO{V?mq$zZ z>cx*Pe)8SLU8<|IV6VD71bgLJ!T#TDfC)ZoM-FdMnOk2Nta|-(GnZzry#4M-;Ye|3 z)mc|KT&)ijj{U*ouR48m1D6M`?4N(N?A%ffuB!xJQiCs5g6(Rsy&UYkH1x+{<2-l6 zb`ObV-cZH6b^bd=_Dh>*&V9)Z!e99K!UDG#{pZ&{drfWZtu{94{~;akSZcl1 zcC$@w9H=&KcxY!Eo1P%n-u%#kd>g>y*aGtw)z(~e-LrW}=GS1@5CDVMgjKaF<{2%?ABzXV; literal 0 HcmV?d00001