""" mccfr_trainer.py — MCCFR 自我博弈训练流水线 基于外部采样 (External Sampling MCCFR) 的 Deep CFR 训练器。 核心流程: 1. traverse(state, traversing_player) 递归遍历博弈树, 计算每个信息集的反事实遗憾值 (Counterfactual Regret) 2. 将 [info_state, legal_mask, regrets, strategy] 存入 CFRBuffer 3. 从 Buffer 采样 mini-batch,训练 CFRNetwork 的 Regret Head 和 Policy Head === 多进程并行数据生成 === 阶段 A (Data Generation) 已改造为多进程版本: - 使用 ProcessPoolExecutor(mp_context=spawn) 并发执行 - spawn 启动方式避免 fork 继承 CUDA Context 导致死锁 - 进程池跨 iteration 复用,通过函数参数直接传递权重(内存安全) - 每个 Worker 批量处理多局,大幅减少 IPC 消息数量 - Worker 返回 (utility_list, experience_list),经验以 Python 基本类型传输 === 外部采样 MCCFR 核心逻辑 === MCCFR 是一种近似求解大规模扩展式博弈纳什均衡的算法。 "外部采样" 指的是:在对方回合,根据当前策略采样 1 个动作继续; 在遍历者回合,遍历所有合法动作以精确计算 regret。 设 v(I, a) 为在信息集 I 采取动作 a 后的期望收益, v(I) = sum_a(σ(I,a) * v(I,a)) 为加权期望收益, 则反事实遗憾值: regret(I, a) = v(I, a) - v(I) 当遍历者反复经历相同信息集时,累计 regret 越大的动作 越应该被优先选择——这就是 Regret Matching 的核心直觉。 """ # ── 必须在 import torch 之前,锁死所有 C++ 线性代数库多线程,防止 spawn 模式下 OpenMP 死锁 ── import os os.environ["OMP_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" import random import sys import os import copy import glob import uuid import atexit import shutil import multiprocessing as mp from typing import List, Tuple, Dict, Any from concurrent.futures import ProcessPoolExecutor, as_completed import torch import torch.nn as nn import torch.nn.functional as F import pyspiel torch.set_num_threads(1) # ── 将 poker/ 根目录加入 sys.path,方便导入同目录模块 ── _POKER_DIR = os.path.dirname(os.path.abspath(__file__)) if _POKER_DIR not in sys.path: sys.path.insert(0, _POKER_DIR) # ── 本地临时目录:将经验落盘文件从系统 /tmp (tmpfs) 移出,防止磁盘空间不足 ── LOCAL_TEMP_DIR = os.path.join(_POKER_DIR, "local_temp_exp") os.makedirs(LOCAL_TEMP_DIR, exist_ok=True) def cleanup_temp_files(): """清空本地临时目录,粉碎上一次遗留的僵尸文件。""" if os.path.exists(LOCAL_TEMP_DIR): try: shutil.rmtree(LOCAL_TEMP_DIR) except: pass atexit.register(cleanup_temp_files) # ── 将项目根目录加入 sys.path,方便导入 card_model 包 ── _PROJECT_DIR = os.path.dirname(_POKER_DIR) if _PROJECT_DIR not in sys.path: sys.path.insert(0, _PROJECT_DIR) from env_adapter import ( HUNL_FULLGAME_STRING, HUNL_SB_BB_STRING, HUNL_BB_SB_STRING, BetTranslator, extract_env_state, CFR_ACTIONS, NUM_CFR_ACTIONS, STACK_NORMALIZE, ) from card_model.config import NUM_BINS, PAD_TOKEN, BOARD_SIZE from card_model.data_generator import extract_cards_from_state from card_model.model import CardModel from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS from cfr_buffer import CFRBuffer # ───────────────────── 终极版超参数 (128G RAM + 24G VRAM, 24核) ───────────────────── # ── 大规模并行与迭代 ── NUM_ITERATIONS = 30000 # 总训练迭代次数(Deep CFR: 高频迭代,少局数) GAMES_PER_ITER = 1000 # 每次迭代搜索的局数(快速进入更新阶段) # NUM_WORKERS: 留 2 核给主进程(GPU 训练 + 日志 + OS),避免上下文切换惩罚 # 24 核 - 2 = 22 个 worker,每个 worker 纯 CPU 密集型遍历 NUM_WORKERS = 22 # ── 巨型经验池 (吞噬你的 128G 内存) ── # 1条样本大约 1KB 多。500 万条样本约占 6~8 GB 内存。 # 设为 1000 万,充分利用你的系统内存,避免灾难性遗忘 (Catastrophic Forgetting) BUFFER_MAX_SIZE = 10_000_000 MIN_BUFFER_SIZE_FOR_TRAIN = 200_000 # 初期攒够 2 万条就开始训练 # ── 显存利用 (24G VRAM) ── # CFR 网络极其小 (3层 MLP, ~100K 参数)。BS=32768 时: # Forward 中间激活: ~33 MB, Backward: ~33 MB, Optimizer: ~0.8 MB # 总计 < 100 MB,在 24G VRAM 上毫无压力 # 超大 Batch Size 能让策略收敛平滑,减少策略震荡 TRAIN_BATCH_SIZE = 16384 TRAIN_STEPS_PER_ITER = 64 # 每次迭代后,网络更新 64 步 # ── 优化器细节 ── LEARNING_RATE = 3e-4 # 适度降低 LR,降低 NaN 风险 WEIGHT_DECAY = 1e-4 # 防止过拟合 # 训练 #NUM_ITERATIONS = 100 # 总训练迭代次数 #GAMES_PER_ITER = 50 # 每个 iteration 跑多少局自对弈 #TRAIN_BATCH_SIZE = 256 # 训练 mini-batch 大小 #TRAIN_STEPS_PER_ITER = 20 # 每 iteration 训练多少步 #MIN_BUFFER_SIZE_FOR_TRAIN = 1000 # Buffer 至少有多少条数据才开始训练 #LEARNING_RATE = 1e-3 # AdamW 学习率 #WEIGHT_DECAY = 1e-4 # AdamW 权重衰减 # 归一化常量(与 env_features 的 5 维对应) # STACK_NORMALIZE = 20000.0 已从 env_adapter 导入(与 Botzone 初始筹码对齐) STREET_NORMALIZE = 3.0 # 街道归一化因子 (0/3, 1/3, 2/3, 3/3) CLIP_GRAD_NORM = 1.0 # 梯度裁剪阈值,防止训练不稳定 # Card Model 权重路径(设为 None 则使用随机初始化) CARD_MODEL_CHECKPOINT = "card_model/data/best_card_model.pt" # 设备 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ── 检查点配置 ── CHECKPOINT_DIR = os.path.join(_POKER_DIR, "checkpoints") CHECKPOINT_INTERVAL = 50 # 每 50 个 iteration 保存一次 KEEP_LAST_N_CHECKPOINTS = 3 # 本地最多保留最近的 3 个检查点文件,防止磁盘撑爆 # ───────────────────── 辅助函数 ───────────────────── def build_env_features(env_info: dict) -> torch.Tensor: """ 将 extract_env_state 返回的字典归一化为 5 维 env_features Tensor。 5 维内容: [pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position] Args: env_info: extract_env_state() 返回的字典 Returns: [5] 的 FloatTensor """ features = [ env_info["pot"] / STACK_NORMALIZE, env_info["p0_stack"] / STACK_NORMALIZE, env_info["p1_stack"] / STACK_NORMALIZE, env_info["street"] / STREET_NORMALIZE, float(env_info["position"]), ] return torch.tensor(features, dtype=torch.float32) # ── CardModel 结果缓存 ── # 在扑克博弈树中,只要没有发新牌,牌面概率绝对不变。 # 以 (hole_cards_tuple, board_cards_tuple) 为 key 缓存 CardModel 输出, # 避免在同一个发牌阶段(如 Flop)反复推理。 CARD_CACHE: Dict[str, torch.Tensor] = {} def build_card_features( card_model: CardModel, state, ) -> torch.Tensor: """ 使用 CardModel 从当前 state 提取 50 维胜率直方图 (card_features)。 带字典缓存:只要 hole_cards 和 board_cards 不变,直接返回缓存结果。 在每次 iteration 开始前由主循环调用 CARD_CACHE.clear() 防止内存泄漏。 流程: 1. 调用 extract_cards_from_state(state, player_id=current_player) 获取当前行动玩家的手牌 2. 构造 cache_key,命中则直接返回 3. 未命中则构造 Tensor 进行推理(在模型所在设备上进行) 4. 返回 [50] 的 FloatTensor(CPU) Args: card_model: 已加载权重的 CardModel(eval 模式) state: OpenSpiel State Returns: [50] 的 FloatTensor(脱离计算图,在 CPU 上) """ global CARD_CACHE current_player = state.current_player() hole_cards, board_cards = extract_cards_from_state(state, player_id=current_player) # 将卡牌列表转为可哈希的 tuple,大幅降低字符串分配与 hash 运算的 CPU 耗时 cache_key = (tuple(hole_cards), tuple(board_cards)) if cache_key in CARD_CACHE: return CARD_CACHE[cache_key] # 构造输入 Tensor(使用 card_model 当前所在设备) model_device = next(card_model.parameters()).device x_hole = torch.tensor([hole_cards], dtype=torch.int64, device=model_device) # [1, 2] # board_cards 不足 5 张时用 PAD_TOKEN 填充 padded_board = board_cards + [PAD_TOKEN] * (BOARD_SIZE - len(board_cards)) x_board = torch.tensor([padded_board], dtype=torch.int64, device=model_device) # [1, 5] with torch.no_grad(): _, pred_histogram = card_model(x_hole, x_board) # [1, 50] # 脱离计算图,转到 CPU,去掉 batch 维度 result = pred_histogram.squeeze(0).cpu() CARD_CACHE[cache_key] = result return result # ───────────────────── 核心:外部采样遍历(单进程版本,保留兼容) ───────────────────── def traverse( state, traversing_player: int, card_model: CardModel, cfr_net: CFRNetwork, buffer: CFRBuffer, translator: BetTranslator, depth: int = 0, ) -> float: """ 外部采样 MCCFR 的博弈树遍历函数(单进程版本)。 递归地遍历博弈树,在遍历者回合计算每个动作的反事实遗憾值 (Regret), 并将经验 [info_state, legal_mask, regrets, strategy] 存入 Buffer。 === 三种节点类型的处理 === 1. Terminal Node (终局): 直接返回 traversing_player 的收益值 state.returns()[traversing_player] 2. Chance Node (发牌): 按概率随机采样一个发牌动作,递归继续 3. Player Node (玩家行动): 提取信息集特征,通过 CFRNetwork 获取当前策略 σ(I,·) 然后分两种情况: a) 对方回合 (current_player != traversing_player): 根据 σ(I,·) 的概率分布采样 1 个动作,执行后递归 → 这是"外部采样"的含义:只采样 1 条路径,大幅减少计算量 b) 遍历者回合 (current_player == traversing_player): 遍历所有合法动作 a: - clone state,执行动作 a,递归得到 v(I, a) 计算加权期望收益: v(I) = Σ_a σ(I,a) * v(I,a) 计算遗憾值: regret(I, a) = v(I,a) - v(I) 将经验存入 Buffer Args: state: OpenSpiel 的 State 对象 traversing_player: 当前遍历者 (0 或 1) card_model: Card Model(eval 模式),用于提取牌面特征 cfr_net: CFR 策略网络(eval 模式),用于获取当前策略 buffer: CFR 经验回放池 translator: BetTranslator,CFR 动作 ↔ 引擎动作转换 Returns: float: 当前节点对 traversing_player 的期望收益 """ # ── 深度熔断保护:超过极限深度,强制截断博弈树 ── if depth >= 40: return 0.0 # ── 1. Terminal Node: 终局,直接返回收益 ── if state.is_terminal(): return float(state.returns()[traversing_player]) # ── 2. Chance Node: 随机发牌 ── if state.is_chance_node(): outcomes = state.chance_outcomes() action_list, prob_list = zip(*outcomes) chance_action = random.choices(action_list, weights=prob_list, k=1)[0] child_state = state.clone() child_state.apply_action(chance_action) utility = traverse(child_state, traversing_player, card_model, cfr_net, buffer, translator, depth + 1) del child_state return utility # ── 3. Player Node: 玩家行动 ── current_player = state.current_player() env_info = extract_env_state(state, translator) env_features = build_env_features(env_info) card_features = build_card_features(card_model, state) legal_mask = env_info["legal_mask"] info_state = torch.cat([card_features, env_features], dim=0) # 使用 cfr_net 当前所在设备(数据生成阶段为 CPU) net_device = next(cfr_net.parameters()).device legal_mask_tensor = torch.tensor( [legal_mask], dtype=torch.float32, device=net_device ) card_input = card_features.unsqueeze(0).to(net_device) env_input = env_features.unsqueeze(0).to(net_device) with torch.no_grad(): current_strategy, _ = cfr_net.get_strategy( card_input, env_input, legal_mask_tensor ) strategy_list = current_strategy.squeeze(0).cpu().tolist() # ── 3a. 对方回合 ── if current_player != traversing_player: legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: legal_indices = [1] probs = [strategy_list[i] for i in legal_indices] prob_sum = sum(probs) if prob_sum > 0: probs = [p / prob_sum for p in probs] else: probs = [1.0 / len(legal_indices)] * len(legal_indices) sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0] engine_action = translator.cfr_to_engine_action(state, sampled_idx) child_state = state.clone() child_state.apply_action(engine_action) utility = traverse(child_state, traversing_player, card_model, cfr_net, buffer, translator, depth + 1) del child_state return utility # ── 3b. 遍历者回合 ── legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: return 0.0 action_utilities = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: child_state = state.clone() engine_action = translator.cfr_to_engine_action(child_state, a) child_state.apply_action(engine_action) action_utilities[a] = traverse( child_state, traversing_player, card_model, cfr_net, buffer, translator, depth + 1 ) del child_state node_utility = 0.0 for a in legal_indices: node_utility += strategy_list[a] * action_utilities[a] regrets = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE # 裁剪极端 regret 值,防止训练不稳定 MAX_REGRET = 10.0 regrets = [max(-MAX_REGRET, min(MAX_REGRET, r)) for r in regrets] buffer.add( info_state=info_state, legal_mask=legal_mask, regrets=regrets, strategy=strategy_list, ) return node_utility # ───────────────────── 多进程 Worker 基础设施 ───────────────────── # # Worker 进程通过 _init_worker 初始化,在其中自行创建: # - pyspiel 游戏实例(不可跨进程序列化) # - CPU 上的 CFRNetwork(从主进程传入的 state_dict 加载权重) # - CPU 上的 CardModel(从主进程传入的 state_dict 加载权重) # - BetTranslator(轻量无状态对象) # # 权重热更新机制: # 模型参数极小(< 500KB),直接通过 worker_traverse_batch 函数参数传递, # 无需临时文件中转,彻底消除并发读写冲突和 BrokenProcessPool 风险。 # # IPC 安全机制: # - 使用 spawn 启动方式,避免 fork 继承 CUDA Context 导致死锁 # - Worker 返回经验全部使用 Python 基本类型(list[float]/list[int]) # - 每个 Worker 批量处理多局游戏,减少 IPC 消息数量,防止管道阻塞 # Worker 进程的全局状态 _WORKER_STATE: Dict[str, Any] = {} def _init_worker( cfr_state_dict: dict, card_state_dict: dict, ) -> None: """ ProcessPoolExecutor 的 worker 初始化函数。 在每个 worker 进程启动时调用一次,创建该进程专属的: - OpenSpiel 游戏实例 - CPU 上的 CFRNetwork(eval 模式) - CPU 上的 CardModel(eval 模式) - BetTranslator 传入的 state_dict 中所有 Tensor 已通过 .cpu() 确保在 CPU 上, 避免 CUDA 跨进程问题。 权重热更新: 每次 iteration 主进程通过 worker_traverse_batch 的 cfr_state_dict 参数 直接传递最新权重,Worker 在 batch 开始时 load_state_dict 加载, 彻底消除临时文件并发读写冲突。 Args: cfr_state_dict: 主进程 CFRNetwork 的 state_dict(CPU Tensor) card_state_dict: 主进程 CardModel 的 state_dict(CPU Tensor) """ global _WORKER_STATE # 每个 worker 自行创建 OpenSpiel 游戏实例(不可跨进程传递) _WORKER_STATE["games"] = [ pyspiel.load_game(HUNL_SB_BB_STRING), pyspiel.load_game(HUNL_BB_SB_STRING), ] # 在 CPU 上创建 CFRNetwork 并加载主进程的最新权重 cfr_net = CFRNetwork( card_dim=CARD_DIM, env_dim=ENV_DIM, num_actions=NUM_ACTIONS, ) cfr_net.load_state_dict(cfr_state_dict) cfr_net.eval() _WORKER_STATE["cfr_net"] = cfr_net # 在 CPU 上创建 CardModel 并加载权重 # CardModel 权重在训练期间不会变化,只需初始化时加载一次 card_model = CardModel() card_model.load_state_dict(card_state_dict) card_model.eval() _WORKER_STATE["card_model"] = card_model # BetTranslator 是无状态对象,可以直接创建 _WORKER_STATE["translator"] = BetTranslator() def worker_traverse_batch( game_indices: List[int], cfr_state_dict: dict, ) -> Tuple[List[Tuple[int, float]], str]: """ 多进程版本的批量 MCCFR 遍历函数,在 Worker 进程中执行。 一个 Worker 一次处理多局游戏(由 game_indices 指定), 大幅减少 IPC 消息数量,防止管道被海量小消息撑爆。 权重通过函数参数直接传入(模型 < 500KB),无需临时文件中转, 彻底消除并发读写冲突和 BrokenProcessPool 风险。 经验数据通过临时文件回传,避免 multiprocessing.Pipe 序列化爆炸: Worker 将经验列表 torch.save 到临时 .pt 文件,仅返回文件路径。 主进程 torch.load 读取后立即删除临时文件。 Args: game_indices: 要处理的游戏索引列表(长度约 GAMES_PER_ITER/NUM_WORKERS) cfr_state_dict: 主进程当前 iteration 的 CFRNetwork 权重(CPU Tensor) Returns: game_results: [(game_idx, utility), ...] 每局的索引和收益 exp_file_path: 经验临时文件路径,主进程需 torch.load 读取后删除 """ global _WORKER_STATE global CARD_CACHE # ── 清除本 Worker 上个批次的卡牌缓存,彻底封杀内存泄露! ── CARD_CACHE.clear() # ── 直接通过参数加载最新权重,无需文件中转 ── _WORKER_STATE["cfr_net"].load_state_dict(cfr_state_dict) cfr_net = _WORKER_STATE["cfr_net"] card_model = _WORKER_STATE["card_model"] translator = _WORKER_STATE["translator"] game_results: List[Tuple[int, float]] = [] all_experiences: List[Tuple] = [] for game_idx in game_indices: traversing_player = game_idx % 2 # 每局随机决定谁做小盲注 (Button),消除位置过拟合 game = random.choice(_WORKER_STATE["games"]) state = game.new_initial_state() experiences: List[Tuple] = [] utility = _traverse_worker( state, traversing_player, card_model, cfr_net, experiences, translator, ) game_results.append((game_idx, utility)) all_experiences.extend(experiences) #print(f"Worker 跑完了一局! (Game {game_idx})", flush=True) # ── 将经验落盘,避免 IPC 管道序列化爆炸 ── # 直接 return all_experiences 会导致 multiprocessing.Pipe 被撑爆 # (单次可达数百 MB 的 Python list 序列化数据),引发 OS OOM 和 BrokenProcessPool。 # 改用临时文件:torch.save 落盘 → 仅传输文件路径字符串 → 主进程读取后删除。 # 进一步优化:Worker 端直接转为 Tensor dict 再落盘,主进程 zero-copy 加载。 exp_file_path = os.path.join( LOCAL_TEMP_DIR, f"cfr_exp_{uuid.uuid4().hex}.pt", ) if all_experiences: info_states_batch, legal_masks_batch, regrets_batch, strategies_batch = zip(*all_experiences) tensor_dict = { "info": torch.tensor(info_states_batch, dtype=torch.float32), "legal": torch.tensor(legal_masks_batch, dtype=torch.float32), "regrets": torch.tensor(regrets_batch, dtype=torch.float32), "strats": torch.tensor(strategies_batch, dtype=torch.float32), } torch.save(tensor_dict, exp_file_path) else: # 如果空经验,存一个空字典 torch.save({}, exp_file_path) del all_experiences # 立即释放 Worker 内存,避免文件缓存双倍占用 return game_results, exp_file_path def _traverse_worker( state, traversing_player: int, card_model: CardModel, cfr_net: CFRNetwork, experiences: List[Tuple], translator: BetTranslator, depth: int = 0, ) -> float: """ Worker 进程内部的递归遍历函数。 与 traverse() 逻辑完全一致,唯一区别是将经验追加到 本地 experiences 列表(而非全局 CFRBuffer),且存入的 info_state 使用 list[float] 而非 Tensor,以便跨进程传输。 内存安全保证: - 所有 Tensor 操作在 torch.no_grad() 下进行 - info_state_tensor.tolist() 后,原始 Tensor 即可被 GC 回收 - 递归过程中 state.clone() 产生的对象在递归返回后也被 GC - 最终返回的 experiences 只包含 Python 基本类型, 主进程 torch.tensor() 新建时无 grad_fn,buffer.add() 安全 设备说明: Worker 中所有模型始终在 CPU 上。Tensor 设备由模型参数所在 设备自动推导(next(cfr_net.parameters()).device),无需显式传递。 Args: state: OpenSpiel State traversing_player: 当前遍历者 card_model: CardModel(eval,CPU) cfr_net: CFRNetwork(eval,CPU) experiences: 本地经验暂存列表 translator: BetTranslator Returns: float: 当前节点对 traversing_player 的期望收益 """ # ── 深度熔断保护:超过极限深度,强制截断博弈树 ── if depth >= 40: return 0.0 # ── 1. Terminal Node ── if state.is_terminal(): return float(state.returns()[traversing_player]) # ── 2. Chance Node ── if state.is_chance_node(): outcomes = state.chance_outcomes() action_list, prob_list = zip(*outcomes) chance_action = random.choices(action_list, weights=prob_list, k=1)[0] child_state = state.clone() child_state.apply_action(chance_action) utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net, experiences, translator, depth + 1) del child_state return utility # ── 3. Player Node ── current_player = state.current_player() env_info = extract_env_state(state, translator) env_features = build_env_features(env_info) card_features = build_card_features(card_model, state) legal_mask = env_info["legal_mask"] info_state_tensor = torch.cat([card_features, env_features], dim=0) # 使用 cfr_net 当前所在设备(Worker 中始终为 CPU) net_device = next(cfr_net.parameters()).device legal_mask_tensor = torch.tensor( [legal_mask], dtype=torch.float32, device=net_device ) card_input = card_features.unsqueeze(0).to(net_device) env_input = env_features.unsqueeze(0).to(net_device) with torch.no_grad(): current_strategy, _ = cfr_net.get_strategy( card_input, env_input, legal_mask_tensor ) strategy_list = current_strategy.squeeze(0).cpu().tolist() # ── 3a. 对方回合 ── if current_player != traversing_player: legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: legal_indices = [1] probs = [strategy_list[i] for i in legal_indices] prob_sum = sum(probs) if prob_sum > 0: probs = [p / prob_sum for p in probs] else: probs = [1.0 / len(legal_indices)] * len(legal_indices) sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0] engine_action = translator.cfr_to_engine_action(state, sampled_idx) child_state = state.clone() child_state.apply_action(engine_action) utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net, experiences, translator, depth + 1) del child_state return utility # ── 3b. 遍历者回合 ── legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: return 0.0 action_utilities = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: child_state = state.clone() engine_action = translator.cfr_to_engine_action(child_state, a) child_state.apply_action(engine_action) action_utilities[a] = _traverse_worker( child_state, traversing_player, card_model, cfr_net, experiences, translator, depth + 1, ) del child_state node_utility = 0.0 for a in legal_indices: node_utility += strategy_list[a] * action_utilities[a] regrets = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE # 裁剪极端 regret 值,防止训练不稳定 MAX_REGRET = 10.0 regrets = [max(-MAX_REGRET, min(MAX_REGRET, r)) for r in regrets] # 存入本地经验列表(info_state 转为 list[float] 以便跨进程传输) # .tolist() 后 info_state_tensor 可被 GC 回收,不会造成内存累积 info_state_list = info_state_tensor.tolist() experiences.append((info_state_list, legal_mask, regrets, strategy_list)) return node_utility # ───────────────────── 检查点保存与恢复 ───────────────────── def save_checkpoint( iteration: int, cfr_net: CFRNetwork, buffer: CFRBuffer, optimizer, checkpoint_dir: str, keep_last_n: int, ) -> None: """ 原子保存检查点,并清理旧文件。 原子操作流程: 1. 先写入临时文件 ckpt_tmp.pt 2. 写入成功后 os.replace 重命名为 ckpt_iter_{iteration}.pt 3. 复制一份 latest_ckpt.pt 作为默认恢复入口 4. 清理超过 keep_last_n 的旧检查点 Args: iteration: 当前迭代次数 cfr_net: CFR 策略网络 buffer: CFR 经验回放池 optimizer: AdamW 优化器(可能为 None) checkpoint_dir: 检查点保存目录 keep_last_n: 保留最近 N 个检查点 """ os.makedirs(checkpoint_dir, exist_ok=True) # 构造保存字典 save_dict = { "iteration": iteration, "model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()}, "buffer_state_dict": buffer.state_dict(), } if optimizer is not None: # 优化器状态中的 Tensor 可能在 GPU 上,需转 CPU opt_state = optimizer.state_dict() opt_state_cpu = {} for k, v in opt_state.items(): if isinstance(v, torch.Tensor): opt_state_cpu[k] = v.cpu() elif isinstance(v, dict): opt_state_cpu[k] = _recursive_tensor_to_cpu(v) else: opt_state_cpu[k] = v save_dict["optimizer_state_dict"] = opt_state_cpu # 原子保存:先写临时文件,再 rename ckpt_path = os.path.join(checkpoint_dir, f"ckpt_iter_{iteration}.pt") tmp_path = os.path.join(checkpoint_dir, "ckpt_tmp.pt") torch.save(save_dict, tmp_path) os.replace(tmp_path, ckpt_path) # 更新 latest_ckpt.pt(用 copy 而非 symlink,更跨平台) latest_path = os.path.join(checkpoint_dir, "latest_ckpt.pt") # os.replace 是原子的,先用临时文件再替换 tmp_latest = os.path.join(checkpoint_dir, "latest_ckpt_tmp.pt") torch.save(save_dict, tmp_latest) os.replace(tmp_latest, latest_path) print(f"[Checkpoint] 已保存: iter={iteration}, buffer={len(buffer)}") # 清理旧检查点(保留最近 keep_last_n 个) _cleanup_old_checkpoints(checkpoint_dir, keep_last_n) def _recursive_tensor_to_cpu(d: dict) -> dict: """递归地将字典中所有 Tensor 转到 CPU。""" out = {} for k, v in d.items(): if isinstance(v, torch.Tensor): out[k] = v.cpu() elif isinstance(v, dict): out[k] = _recursive_tensor_to_cpu(v) elif isinstance(v, list): out[k] = [ item.cpu() if isinstance(item, torch.Tensor) else item for item in v ] else: out[k] = v return out def _cleanup_old_checkpoints(checkpoint_dir: str, keep_last_n: int) -> None: """删除旧的检查点文件,只保留最近 keep_last_n 个。""" pattern = os.path.join(checkpoint_dir, "ckpt_iter_*.pt") files = glob.glob(pattern) if len(files) <= keep_last_n: return # 按文件名中的 iteration 数字排序 def _extract_iter(path: str) -> int: basename = os.path.basename(path) # ckpt_iter_123.pt -> 123 num = basename.replace("ckpt_iter_", "").replace(".pt", "") try: return int(num) except ValueError: return -1 files_with_iter = [(f, _extract_iter(f)) for f in files if _extract_iter(f) >= 0] files_with_iter.sort(key=lambda x: x[1]) # 删除多余的旧文件 to_delete = files_with_iter[:-keep_last_n] for path, it in to_delete: os.remove(path) print(f"[Checkpoint] 已删除旧检查点: iter={it}") def _optimizer_state_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None: """将优化器状态中所有 Tensor 移到指定设备。 PyTorch 的 optimizer.load_state_dict() 不会自动将 Tensor 移到 当前参数所在的设备,需要手动处理。这对于从 CPU 检查点恢复到 GPU 上的场景至关重要。 Args: optimizer: PyTorch 优化器 device: 目标设备 """ for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(device) def load_checkpoint( checkpoint_path: str, cfr_net: CFRNetwork, buffer: CFRBuffer, ) -> dict: """ 从检查点文件恢复训练状态。 Args: checkpoint_path: 检查点文件路径 cfr_net: CFR 策略网络(将被加载权重) buffer: CFR 经验回放池(将被加载数据) Returns: dict: 包含 iteration 和 optimizer_state_dict(如果有)的字典 """ print(f"[Checkpoint] 正在加载: {checkpoint_path}") ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) # 恢复模型权重 cfr_net.load_state_dict(ckpt["model_state_dict"]) # 恢复 Buffer buffer.load_state_dict(ckpt["buffer_state_dict"]) result = { "iteration": ckpt.get("iteration", 0), } if "optimizer_state_dict" in ckpt: result["optimizer_state_dict"] = ckpt["optimizer_state_dict"] print(f"[Checkpoint] 恢复成功: iter={result['iteration']}, buffer={len(buffer)}") return result # ───────────────────── 训练一步 ───────────────────── def train_step( cfr_net: CFRNetwork, buffer: CFRBuffer, optimizer: torch.optim.Optimizer, device: torch.device, ) -> Tuple[float, float, float]: """ 从 Buffer 采样一个 mini-batch,训练 CFRNetwork 一步。 Loss 组成: 1. Regret Loss: MSE(预测 regret, 目标 regret),只在 legal_mask=1 的位置计算 2. Policy Loss: MSE(预测策略, 目标策略),拟合 Buffer 中存储的策略分布 总 Loss = Regret Loss + Policy Loss Args: cfr_net: CFR 策略网络 buffer: CFR 经验回放池 optimizer: AdamW 优化器 device: 运行设备 Returns: (total_loss, regret_loss, policy_loss): 三个 float 标量 """ cfr_net.train() info_states, legal_masks, target_regrets, target_strategies = buffer.sample( TRAIN_BATCH_SIZE ) info_states = info_states.to(device) legal_masks = legal_masks.to(device) target_regrets = target_regrets.to(device) target_strategies = target_strategies.to(device) # ── 防御性检查:跳过含 NaN/Inf 的 target 数据 ── if torch.isnan(target_regrets).any() or torch.isnan(target_strategies).any(): return 0.0, 0.0, 0.0 if torch.isinf(target_regrets).any() or torch.isinf(target_strategies).any(): return 0.0, 0.0, 0.0 # ── 裁剪极端目标 regret 值,防止离群标签导致训练不稳定 ── target_regrets = torch.clamp(target_regrets, -10.0, 10.0) card_features = info_states[:, :CARD_DIM] env_features = info_states[:, CARD_DIM:] pred_regrets, pred_policy_logits = cfr_net(card_features, env_features) # ── Loss 1: Regret Loss ── # 注意: 不对 pred_regrets 裁剪,保留完整梯度信号让网络纠正极端输出; # target_regrets 已裁剪防爆,且 CLIP_GRAD_NORM 防止梯度爆炸 regret_diff = F.smooth_l1_loss(pred_regrets, target_regrets, reduction='none') regret_loss = (regret_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0) # ── Loss 2: Policy Loss ── all_illegal = (legal_masks.sum(dim=-1, keepdim=True) == 0) # 创建安全 mask,全非法的行强制给一个合法位,防止 softmax 崩溃 safe_mask = torch.where(all_illegal, torch.ones_like(legal_masks), legal_masks) masked_logits = pred_policy_logits.masked_fill(safe_mask == 0, -1e9) pred_strategy = F.softmax(masked_logits, dim=-1) num_legal = legal_masks.sum(dim=-1, keepdim=True).clamp(min=1.0) uniform = legal_masks / num_legal pred_strategy = torch.where(all_illegal, uniform, pred_strategy) policy_diff = F.smooth_l1_loss(pred_strategy, target_strategies, reduction='none') policy_loss = (policy_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0) total_loss = regret_loss + policy_loss # ── Loss NaN/Inf 检查:若 loss 异常则跳过本步更新 ── if torch.isnan(total_loss) or torch.isinf(total_loss): return 0.0, 0.0, 0.0 optimizer.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(cfr_net.parameters(), max_norm=CLIP_GRAD_NORM) # ── 梯度 NaN 检查:若梯度含 NaN 则清除并跳过 ── has_nan_grad = any( p.grad is not None and torch.isnan(p.grad).any() for p in cfr_net.parameters() ) if has_nan_grad: optimizer.zero_grad() return 0.0, 0.0, 0.0 optimizer.step() # ── 参数 NaN 急救:若更新后参数含 NaN 则就地修复 ── has_nan_param = any(torch.isnan(p).any() for p in cfr_net.parameters()) if has_nan_param: with torch.no_grad(): for p in cfr_net.parameters(): p.nan_to_num_(nan=0.0, posinf=1e6, neginf=-1e6) return 0.0, 0.0, 0.0 return ( total_loss.item(), regret_loss.item(), policy_loss.item(), ) # ───────────────────── 主训练循环 ───────────────────── def main(): """ MCCFR 训练主循环(多进程版本)。 每个 iteration 分两个阶段: 阶段 A (Data Generation) — 多进程并行: - 使用 ProcessPoolExecutor(mp_context=spawn) 并发执行 - 进程池跨 iteration 复用,通过共享内存 dict 热更新权重 - 每个 Worker 批量处理多局,减少 IPC 消息数量 - 主进程汇总所有 Worker 的经验到全局 CFRBuffer 阶段 B (Network Training) — 单进程 GPU: - 如果 Buffer 中数据量足够,进行 mini-batch 训练 - 更新 Regret Head 和 Policy Head 日志输出: - 每 iteration 打印 Buffer 大小、Loss 变化 """ # ── 强制使用 spawn 启动方式,防止 fork 继承 CUDA Context 死锁 ── # Linux 默认 fork,fork 后子进程继承父进程的 CUDA Context, # 但 CUDA 运行时内部锁状态不一致,任何后续 CUDA 操作都会永久死锁。 # spawn 从零启动 Python 解释器,完全隔离 CUDA 状态。 # 必须在 if __name__ == "__main__" 保护内、任何 CUDA 操作之前调用。 mp.set_start_method('spawn', force=True) # ── 启动前彻底粉碎上次遗留的僵尸临时文件 ── cleanup_temp_files() os.makedirs(LOCAL_TEMP_DIR, exist_ok=True) print("=" * 70) print(" MCCFR Self-Play Training Pipeline (Multi-Process)") print(" External Sampling + Deep CFR") print(f" NUM_WORKERS = {NUM_WORKERS}") print(f" Start method = spawn") print("=" * 70) # ── 1. 初始化 OpenSpiel 环境 ── game = pyspiel.load_game(HUNL_FULLGAME_STRING) translator = BetTranslator() print(f"\n[环境] 游戏加载成功: {HUNL_FULLGAME_STRING}") # ── 2. 初始化 Card Model(先放 CPU,数据生成阶段用 CPU 推理) ── card_model = CardModel() # 初始在 CPU if CARD_MODEL_CHECKPOINT and os.path.isfile(CARD_MODEL_CHECKPOINT): ckpt = torch.load(CARD_MODEL_CHECKPOINT, map_location="cpu", weights_only=False) card_model.load_state_dict(ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt) print(f"[CardModel] 已加载权重: {CARD_MODEL_CHECKPOINT}") else: print(f"[CardModel] 使用随机初始化(未提供 checkpoint)") card_model.eval() # ── 3. 初始化 CFR Network(先放 CPU,数据生成阶段用 CPU 推理) ── cfr_net = CFRNetwork( card_dim=CARD_DIM, env_dim=ENV_DIM, num_actions=NUM_ACTIONS, ) # 初始在 CPU total_params = sum(p.numel() for p in cfr_net.parameters()) print(f"[CFRNetwork] 参数量: {total_params:,},初始设备: CPU") # ── 4. 初始化 Buffer ── buffer = CFRBuffer(max_size=BUFFER_MAX_SIZE) print(f"[Buffer] 容量: {BUFFER_MAX_SIZE:,}") # Optimizer 延迟到首次阶段 B 时创建(此时 cfr_net 已移至 GPU) # 但如果从检查点恢复且存档中有 optimizer_state_dict,则提前初始化 optimizer = None # ── NaN 恢复计数器 ── consecutive_nan_count = 0 MAX_CONSECUTIVE_NAN = 10 # 连续 NaN 步数阈值,超过则从 checkpoint 恢复 # ── 波谷捕获器 ── ema_loss = None # EMA 平滑 Loss,消除批次噪音 best_valley_loss = float('inf') # 当前区间内最深波谷的 EMA Loss best_valley_state_dict = None # 最深波谷对应的模型权重(CPU) VALLEY_SAVE_INTERVAL = 500 # 每 500 轮落盘一次最深波谷 # ── 检查点恢复 ── # 检查 latest_ckpt.pt 是否存在,存在则自动恢复 start_iteration = 1 latest_ckpt_path = os.path.join(CHECKPOINT_DIR, "latest_ckpt.pt") if os.path.isfile(latest_ckpt_path): ckpt_info = load_checkpoint(latest_ckpt_path, cfr_net, buffer) start_iteration = ckpt_info["iteration"] + 1 # 如果存档中有 optimizer 状态,必须现在就初始化并恢复 # 原因:optimizer 的动量缓冲区绑定在参数对象上, # 必须在 cfr_net 移到 GPU 后、任何训练之前加载 if "optimizer_state_dict" in ckpt_info: # 先将网络移至 GPU(optimizer 的参数必须在正确设备上) cfr_net.to(DEVICE) optimizer = torch.optim.AdamW( cfr_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, ) # 加载优化器状态(可能包含 GPU 上的 Tensor,需映射到正确设备) optimizer.load_state_dict(ckpt_info["optimizer_state_dict"]) # 将优化器状态中的所有 Tensor 移到与参数相同的设备 _optimizer_state_to_device(optimizer, DEVICE) # 网络移回 CPU,等阶段 A 再用 cfr_net.to("cpu") print(f"[Optimizer] 已从检查点恢复优化器状态, device={DEVICE}") print(f"[Checkpoint] 将从 Iteration {start_iteration} 继续训练") else: print("[Checkpoint] 未找到检查点,从头开始训练") # ── 5. 初始化 CPU state_dict(给 Worker 首次初始化用) ── # 必须在创建进程池之前转为 CPU,避免 CUDA Tensor 跨进程传输 cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()} card_state_dict_cpu = {k: v.cpu() for k, v in card_model.state_dict().items()} # ── 6. 创建进程池(跨 iteration 复用,避免重复初始化开销) ── # mp_context=spawn 确保子进程不继承 CUDA Context # 使用 spawn context 创建 ProcessPoolExecutor,而非全局 set_start_method spawn_ctx = mp.get_context('spawn') print(f"\n{'='*70}") print(f" 开始训练: {NUM_ITERATIONS} iterations, " f"每 iter {GAMES_PER_ITER} 局自对弈, " f"{NUM_WORKERS} workers 并行") print(f"{'='*70}") # 将 GAMES_PER_ITER 局均匀分配给 NUM_WORKERS 个 Worker # 每个 Worker 一次处理一批游戏,减少 IPC 消息数量 game_indices_all = list(range(GAMES_PER_ITER)) # 计算每个 Worker 分配的游戏索引 chunks: List[List[int]] = [[] for _ in range(NUM_WORKERS)] for i, gi in enumerate(game_indices_all): chunks[i % NUM_WORKERS].append(gi) # 外层循环:每 CHECKPOINT_INTERVAL (50) 轮重启一次进程池,彻底释放 OS 内存碎片 for chunk_start in range(start_iteration, NUM_ITERATIONS + 1, CHECKPOINT_INTERVAL): chunk_end = min(chunk_start + CHECKPOINT_INTERVAL, NUM_ITERATIONS + 1) with ProcessPoolExecutor( max_workers=NUM_WORKERS, initializer=_init_worker, initargs=(cfr_state_dict_cpu, card_state_dict_cpu), mp_context=spawn_ctx, ) as executor: # 内层循环:执行这 50 轮的逻辑 for iteration in range(chunk_start, chunk_end): # ──────────── 阶段 A: 数据生成 — 多进程 ──────────── # 动态设备切换:数据生成阶段,模型放 CPU,避免 Batch=1 的 GPU PCIe 延迟 card_model.to("cpu") cfr_net.to("cpu") cfr_net.eval() # 清除 CardModel 缓存,防止内存泄漏 CARD_CACHE.clear() # 获取当前网络权重,直接通过函数参数传递给 Worker(无需临时文件) cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()} # 提交批量任务(每个 Worker 处理一批游戏,权重直接参数传递) futures = {} for worker_id, chunk in enumerate(chunks): if not chunk: continue future = executor.submit(worker_traverse_batch, chunk, cfr_state_dict_cpu) futures[future] = worker_id # 收集结果并汇总经验 game_results = [] total_experiences = 0 for future in as_completed(futures): worker_id = futures[future] try: worker_game_results, exp_file_path = future.result() except Exception as e: print(f"[警告] Worker {worker_id} 执行失败: {e}") continue # 记录对局收益 for game_idx, utility in worker_game_results: traversing_player = game_idx % 2 game_results.append((traversing_player, utility)) # 从临时文件加载经验(避免 IPC 管道爆炸) try: tensor_dict = torch.load(exp_file_path, map_location="cpu", weights_only=False) except Exception as e: print(f"[警告] Worker {worker_id} 经验文件加载失败: {e}") continue finally: # 无论加载成功与否,立即删除临时文件,防止 /tmp 堆积 if os.path.isfile(exp_file_path): os.remove(exp_file_path) # 将 worker 返回的经验 Tensor 直接写入全局 Buffer(zero-copy) if tensor_dict: buffer.add_batch_tensors( tensor_dict["info"], tensor_dict["legal"], tensor_dict["regrets"], tensor_dict["strats"] ) total_experiences += tensor_dict["info"].shape[0] del tensor_dict # 立即释放,避免主进程内存峰值 # 统计本 iteration 数据生成的概况 p0_utils = [u for p, u in game_results if p == 0] p1_utils = [u for p, u in game_results if p == 1] avg_p0 = sum(p0_utils) / len(p0_utils) if p0_utils else 0.0 avg_p1 = sum(p1_utils) / len(p1_utils) if p1_utils else 0.0 # ──────────── 阶段 B: 网络训练 — GPU ──────────── # 动态设备切换:训练阶段,cfr_net 移至 GPU 利用大 Batch 高效训练 cfr_net.to(DEVICE) # 首次进入训练阶段时创建 optimizer(此时参数已在 GPU 上) if optimizer is None: optimizer = torch.optim.AdamW( cfr_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, ) print(f"[Optimizer] AdamW lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}, device={DEVICE}") avg_total_loss = 0.0 avg_regret_loss = 0.0 avg_policy_loss = 0.0 trained_steps = 0 if len(buffer) >= MIN_BUFFER_SIZE_FOR_TRAIN: for step in range(TRAIN_STEPS_PER_ITER): total_loss, regret_loss, policy_loss = train_step( cfr_net, buffer, optimizer, DEVICE ) # ── NaN 恢复:统计连续 NaN 步数 ── if total_loss == 0.0 and regret_loss == 0.0 and policy_loss == 0.0: consecutive_nan_count += 1 else: consecutive_nan_count = 0 # 连续 NaN 超过阈值,从最近 checkpoint 恢复 if consecutive_nan_count >= MAX_CONSECUTIVE_NAN: print(f"[NaN 恢复] 连续 {consecutive_nan_count} 步 NaN,从检查点恢复...") latest_ckpt = os.path.join(CHECKPOINT_DIR, "latest_ckpt.pt") if os.path.isfile(latest_ckpt): ckpt_info = load_checkpoint(latest_ckpt, cfr_net, buffer) cfr_net.to(DEVICE) # 重新创建 optimizer(旧 optimizer 动量已被 NaN 污染) optimizer = torch.optim.AdamW( cfr_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, ) # 不恢复 optimizer 状态,让它重新积累动量 consecutive_nan_count = 0 print(f"[NaN 恢复] 已从检查点恢复 (iter={ckpt_info['iteration']}),optimizer 已重置") else: print("[NaN 恢复] 未找到检查点,重置网络参数和 optimizer") for p in cfr_net.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) else: nn.init.zeros_(p) optimizer = torch.optim.AdamW( cfr_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, ) consecutive_nan_count = 0 # 恢复后跳过本 iteration 剩余训练步 break avg_total_loss += total_loss avg_regret_loss += regret_loss avg_policy_loss += policy_loss trained_steps += 1 avg_total_loss /= trained_steps avg_regret_loss /= trained_steps avg_policy_loss /= trained_steps # ── 波谷捕获:EMA 平滑 + 记录最深波谷 ── if trained_steps > 0: if ema_loss is None: ema_loss = avg_total_loss else: ema_loss = 0.9 * ema_loss + 0.1 * avg_total_loss if ema_loss < best_valley_loss: best_valley_loss = ema_loss best_valley_state_dict = copy.deepcopy( {k: v.cpu() for k, v in cfr_net.state_dict().items()} ) # ──────────── 打印日志 ──────────── buffer_size = len(buffer) loss_str = ( f"Total={avg_total_loss:.6f} " f"Regret={avg_regret_loss:.6f} " f"Policy={avg_policy_loss:.6f}" if trained_steps > 0 else "训练未开始(Buffer 不足)" ) print( f"[Iter {iteration:3d}/{NUM_ITERATIONS}] " f"Buffer={buffer_size:6d} " f"Exp={total_experiences:6d} " f"P0_util={avg_p0:+.1f} P1_util={avg_p1:+.1f} " f"Loss: {loss_str}" ) # ──────────── 定时保存检查点 ──────────── if iteration % CHECKPOINT_INTERVAL == 0: # 保存时网络在 GPU 上,save_checkpoint 会将权重转到 CPU save_checkpoint( iteration=iteration, cfr_net=cfr_net, buffer=buffer, optimizer=optimizer, checkpoint_dir=CHECKPOINT_DIR, keep_last_n=KEEP_LAST_N_CHECKPOINTS, ) # ──────────── 波谷落盘 ──────────── if iteration % VALLEY_SAVE_INTERVAL == 0 and best_valley_state_dict is not None: valley_path = os.path.join(CHECKPOINT_DIR, f"ckpt_valley_iter_{iteration}.pt") try: torch.save(best_valley_state_dict, valley_path) print(f"[波谷捕获] Iter {iteration}: EMA Loss={best_valley_loss:.6f}, " f"已保存到 {valley_path}") except Exception as e: print(f"[波谷捕获] Iter {iteration}: 保存失败: {e}") # 重置,在下一个区间内寻找新波谷 best_valley_loss = float('inf') best_valley_state_dict = None # ── 8. 训练结束,保存最终检查点 ── save_checkpoint( iteration=NUM_ITERATIONS, cfr_net=cfr_net, buffer=buffer, optimizer=optimizer, checkpoint_dir=CHECKPOINT_DIR, keep_last_n=KEEP_LAST_N_CHECKPOINTS, ) # 同时保存一份到根目录(向后兼容) save_path = os.path.join(_POKER_DIR, "cfr_net_checkpoint.pt") torch.save({ "model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()}, "optimizer_state_dict": optimizer.state_dict() if optimizer else {}, "iteration": NUM_ITERATIONS, "buffer_size": len(buffer), }, save_path) print(f"[保存] 模型已保存到: {save_path}") print(f"\n{'='*70}") print(f" 训练完成!共 {NUM_ITERATIONS} iterations") print(f" 最终 Buffer 大小: {len(buffer)}") print(f"{'='*70}") # ───────────────────── 入口 ───────────────────── if __name__ == "__main__": main()