""" mccfr_trainer.py — MCCFR 自我博弈训练流水线 基于外部采样 (External Sampling MCCFR) 的 Deep CFR 训练器。 核心流程: 1. traverse(state, traversing_player) 递归遍历博弈树, 计算每个信息集的反事实遗憾值 (Counterfactual Regret) 2. 将 [info_state, legal_mask, regrets, strategy] 存入 CFRBuffer 3. 从 Buffer 采样 mini-batch,训练 CFRNetwork 的 Regret Head 和 Policy Head === 多进程并行数据生成 === 阶段 A (Data Generation) 已改造为多进程版本: - 使用 ProcessPoolExecutor(mp_context=spawn) 并发执行 - spawn 启动方式避免 fork 继承 CUDA Context 导致死锁 - 进程池跨 iteration 复用,通过函数参数直接传递权重(内存安全) - 每个 Worker 批量处理多局,大幅减少 IPC 消息数量 - Worker 返回 (utility_list, experience_list),经验以 Python 基本类型传输 === 外部采样 MCCFR 核心逻辑 === MCCFR 是一种近似求解大规模扩展式博弈纳什均衡的算法。 "外部采样" 指的是:在对方回合,根据当前策略采样 1 个动作继续; 在遍历者回合,遍历所有合法动作以精确计算 regret。 设 v(I, a) 为在信息集 I 采取动作 a 后的期望收益, v(I) = sum_a(σ(I,a) * v(I,a)) 为加权期望收益, 则反事实遗憾值: regret(I, a) = v(I, a) - v(I) 当遍历者反复经历相同信息集时,累计 regret 越大的动作 越应该被优先选择——这就是 Regret Matching 的核心直觉。 """ # ── 必须在 import torch 之前,锁死所有 C++ 线性代数库多线程,防止 spawn 模式下 OpenMP 死锁 ── import os os.environ["OMP_NUM_THREADS"] = "1" os.environ["MKL_NUM_THREADS"] = "1" os.environ["OPENBLAS_NUM_THREADS"] = "1" import random import sys import os import glob import tempfile import uuid import multiprocessing as mp from typing import List, Tuple, Dict, Any from concurrent.futures import ProcessPoolExecutor, as_completed import torch import torch.nn as nn import torch.nn.functional as F import pyspiel torch.set_num_threads(1) # ── 将 poker/ 根目录加入 sys.path,方便导入同目录模块 ── _POKER_DIR = os.path.dirname(os.path.abspath(__file__)) if _POKER_DIR not in sys.path: sys.path.insert(0, _POKER_DIR) # ── 将项目根目录加入 sys.path,方便导入 card_model 包 ── _PROJECT_DIR = os.path.dirname(_POKER_DIR) if _PROJECT_DIR not in sys.path: sys.path.insert(0, _PROJECT_DIR) from env_adapter import ( HUNL_FULLGAME_STRING, BetTranslator, extract_env_state, CFR_ACTIONS, NUM_CFR_ACTIONS, STACK_NORMALIZE, ) from card_model.config import NUM_BINS, PAD_TOKEN, BOARD_SIZE from card_model.data_generator import extract_cards_from_state from card_model.model import CardModel from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS from cfr_buffer import CFRBuffer # ───────────────────── 终极版超参数 (128G RAM + 24G VRAM, 24核) ───────────────────── # ── 大规模并行与迭代 ── NUM_ITERATIONS = 30000 # 总训练迭代次数(Deep CFR: 高频迭代,少局数) GAMES_PER_ITER = 1000 # 每次迭代搜索的局数(快速进入更新阶段) # NUM_WORKERS: 留 2 核给主进程(GPU 训练 + 日志 + OS),避免上下文切换惩罚 # 24 核 - 2 = 22 个 worker,每个 worker 纯 CPU 密集型遍历 NUM_WORKERS = 22 # ── 巨型经验池 (吞噬你的 128G 内存) ── # 1条样本大约 1KB 多。500 万条样本约占 6~8 GB 内存。 # 设为 1000 万,充分利用你的系统内存,避免灾难性遗忘 (Catastrophic Forgetting) BUFFER_MAX_SIZE = 10_000_000 MIN_BUFFER_SIZE_FOR_TRAIN = 200_000 # 初期攒够 2 万条就开始训练 # ── 显存利用 (24G VRAM) ── # CFR 网络极其小 (3层 MLP, ~100K 参数)。BS=32768 时: # Forward 中间激活: ~33 MB, Backward: ~33 MB, Optimizer: ~0.8 MB # 总计 < 100 MB,在 24G VRAM 上毫无压力 # 超大 Batch Size 能让策略收敛平滑,减少策略震荡 TRAIN_BATCH_SIZE = 16384 TRAIN_STEPS_PER_ITER = 64 # 每次迭代后,网络更新 64 步 # ── 优化器细节 ── LEARNING_RATE = 5e-4 # BS 变大,LR 稍微降一点,求稳 WEIGHT_DECAY = 1e-4 # 防止过拟合 # 训练 #NUM_ITERATIONS = 100 # 总训练迭代次数 #GAMES_PER_ITER = 50 # 每个 iteration 跑多少局自对弈 #TRAIN_BATCH_SIZE = 256 # 训练 mini-batch 大小 #TRAIN_STEPS_PER_ITER = 20 # 每 iteration 训练多少步 #MIN_BUFFER_SIZE_FOR_TRAIN = 1000 # Buffer 至少有多少条数据才开始训练 #LEARNING_RATE = 1e-3 # AdamW 学习率 #WEIGHT_DECAY = 1e-4 # AdamW 权重衰减 # 归一化常量(与 env_features 的 5 维对应) # STACK_NORMALIZE = 20000.0 已从 env_adapter 导入(与 Botzone 初始筹码对齐) STREET_NORMALIZE = 3.0 # 街道归一化因子 (0/3, 1/3, 2/3, 3/3) CLIP_GRAD_NORM = 1.0 # 梯度裁剪阈值,防止训练不稳定 # Card Model 权重路径(设为 None 则使用随机初始化) CARD_MODEL_CHECKPOINT = "card_model/data/best_card_model.pt" # 设备 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ── 检查点配置 ── CHECKPOINT_DIR = os.path.join(_POKER_DIR, "checkpoints") CHECKPOINT_INTERVAL = 50 # 每 50 个 iteration 保存一次 KEEP_LAST_N_CHECKPOINTS = 3 # 本地最多保留最近的 3 个检查点文件,防止磁盘撑爆 # ───────────────────── 辅助函数 ───────────────────── def build_env_features(env_info: dict) -> torch.Tensor: """ 将 extract_env_state 返回的字典归一化为 5 维 env_features Tensor。 5 维内容: [pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position] Args: env_info: extract_env_state() 返回的字典 Returns: [5] 的 FloatTensor """ features = [ env_info["pot"] / STACK_NORMALIZE, env_info["p0_stack"] / STACK_NORMALIZE, env_info["p1_stack"] / STACK_NORMALIZE, env_info["street"] / STREET_NORMALIZE, float(env_info["position"]), ] return torch.tensor(features, dtype=torch.float32) # ── CardModel 结果缓存 ── # 在扑克博弈树中,只要没有发新牌,牌面概率绝对不变。 # 以 (hole_cards_tuple, board_cards_tuple) 为 key 缓存 CardModel 输出, # 避免在同一个发牌阶段(如 Flop)反复推理。 CARD_CACHE: Dict[str, torch.Tensor] = {} def build_card_features( card_model: CardModel, state, ) -> torch.Tensor: """ 使用 CardModel 从当前 state 提取 50 维胜率直方图 (card_features)。 带字典缓存:只要 hole_cards 和 board_cards 不变,直接返回缓存结果。 在每次 iteration 开始前由主循环调用 CARD_CACHE.clear() 防止内存泄漏。 流程: 1. 调用 extract_cards_from_state(state) 获取 hole_cards, board_cards 2. 构造 cache_key,命中则直接返回 3. 未命中则构造 Tensor 进行推理(在模型所在设备上进行) 4. 返回 [50] 的 FloatTensor(CPU) Args: card_model: 已加载权重的 CardModel(eval 模式) state: OpenSpiel State Returns: [50] 的 FloatTensor(脱离计算图,在 CPU 上) """ global CARD_CACHE hole_cards, board_cards = extract_cards_from_state(state) # 将卡牌列表转为可哈希的 key cache_key = f"{tuple(hole_cards)}_{tuple(board_cards)}" if cache_key in CARD_CACHE: return CARD_CACHE[cache_key] # 构造输入 Tensor(使用 card_model 当前所在设备) model_device = next(card_model.parameters()).device x_hole = torch.tensor([hole_cards], dtype=torch.int64, device=model_device) # [1, 2] # board_cards 不足 5 张时用 PAD_TOKEN 填充 padded_board = board_cards + [PAD_TOKEN] * (BOARD_SIZE - len(board_cards)) x_board = torch.tensor([padded_board], dtype=torch.int64, device=model_device) # [1, 5] with torch.no_grad(): _, pred_histogram = card_model(x_hole, x_board) # [1, 50] # 脱离计算图,转到 CPU,去掉 batch 维度 result = pred_histogram.squeeze(0).cpu() CARD_CACHE[cache_key] = result return result # ───────────────────── 核心:外部采样遍历(单进程版本,保留兼容) ───────────────────── def traverse( state, traversing_player: int, card_model: CardModel, cfr_net: CFRNetwork, buffer: CFRBuffer, translator: BetTranslator, depth: int = 0, ) -> float: """ 外部采样 MCCFR 的博弈树遍历函数(单进程版本)。 递归地遍历博弈树,在遍历者回合计算每个动作的反事实遗憾值 (Regret), 并将经验 [info_state, legal_mask, regrets, strategy] 存入 Buffer。 === 三种节点类型的处理 === 1. Terminal Node (终局): 直接返回 traversing_player 的收益值 state.returns()[traversing_player] 2. Chance Node (发牌): 按概率随机采样一个发牌动作,递归继续 3. Player Node (玩家行动): 提取信息集特征,通过 CFRNetwork 获取当前策略 σ(I,·) 然后分两种情况: a) 对方回合 (current_player != traversing_player): 根据 σ(I,·) 的概率分布采样 1 个动作,执行后递归 → 这是"外部采样"的含义:只采样 1 条路径,大幅减少计算量 b) 遍历者回合 (current_player == traversing_player): 遍历所有合法动作 a: - clone state,执行动作 a,递归得到 v(I, a) 计算加权期望收益: v(I) = Σ_a σ(I,a) * v(I,a) 计算遗憾值: regret(I, a) = v(I,a) - v(I) 将经验存入 Buffer Args: state: OpenSpiel 的 State 对象 traversing_player: 当前遍历者 (0 或 1) card_model: Card Model(eval 模式),用于提取牌面特征 cfr_net: CFR 策略网络(eval 模式),用于获取当前策略 buffer: CFR 经验回放池 translator: BetTranslator,CFR 动作 ↔ 引擎动作转换 Returns: float: 当前节点对 traversing_player 的期望收益 """ # ── 深度熔断保护:超过极限深度,强制截断博弈树 ── if depth >= 40: return 0.0 # ── 1. Terminal Node: 终局,直接返回收益 ── if state.is_terminal(): return float(state.returns()[traversing_player]) # ── 2. Chance Node: 随机发牌 ── if state.is_chance_node(): outcomes = state.chance_outcomes() action_list, prob_list = zip(*outcomes) chance_action = random.choices(action_list, weights=prob_list, k=1)[0] child_state = state.clone() child_state.apply_action(chance_action) utility = traverse(child_state, traversing_player, card_model, cfr_net, buffer, translator, depth + 1) del child_state return utility # ── 3. Player Node: 玩家行动 ── current_player = state.current_player() env_info = extract_env_state(state) env_features = build_env_features(env_info) card_features = build_card_features(card_model, state) legal_mask = env_info["legal_mask"] info_state = torch.cat([card_features, env_features], dim=0) # 使用 cfr_net 当前所在设备(数据生成阶段为 CPU) net_device = next(cfr_net.parameters()).device legal_mask_tensor = torch.tensor( [legal_mask], dtype=torch.float32, device=net_device ) card_input = card_features.unsqueeze(0).to(net_device) env_input = env_features.unsqueeze(0).to(net_device) with torch.no_grad(): current_strategy, _ = cfr_net.get_strategy( card_input, env_input, legal_mask_tensor ) strategy_list = current_strategy.squeeze(0).cpu().tolist() # ── 3a. 对方回合 ── if current_player != traversing_player: legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: legal_indices = [1] probs = [strategy_list[i] for i in legal_indices] prob_sum = sum(probs) if prob_sum > 0: probs = [p / prob_sum for p in probs] else: probs = [1.0 / len(legal_indices)] * len(legal_indices) sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0] engine_action = translator.cfr_to_engine_action(state, sampled_idx) child_state = state.clone() child_state.apply_action(engine_action) utility = traverse(child_state, traversing_player, card_model, cfr_net, buffer, translator, depth + 1) del child_state return utility # ── 3b. 遍历者回合 ── legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: return 0.0 action_utilities = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: child_state = state.clone() engine_action = translator.cfr_to_engine_action(child_state, a) child_state.apply_action(engine_action) action_utilities[a] = traverse( child_state, traversing_player, card_model, cfr_net, buffer, translator, depth + 1 ) del child_state node_utility = 0.0 for a in legal_indices: node_utility += strategy_list[a] * action_utilities[a] regrets = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE buffer.add( info_state=info_state, legal_mask=legal_mask, regrets=regrets, strategy=strategy_list, ) return node_utility # ───────────────────── 多进程 Worker 基础设施 ───────────────────── # # Worker 进程通过 _init_worker 初始化,在其中自行创建: # - pyspiel 游戏实例(不可跨进程序列化) # - CPU 上的 CFRNetwork(从主进程传入的 state_dict 加载权重) # - CPU 上的 CardModel(从主进程传入的 state_dict 加载权重) # - BetTranslator(轻量无状态对象) # # 权重热更新机制: # 模型参数极小(< 500KB),直接通过 worker_traverse_batch 函数参数传递, # 无需临时文件中转,彻底消除并发读写冲突和 BrokenProcessPool 风险。 # # IPC 安全机制: # - 使用 spawn 启动方式,避免 fork 继承 CUDA Context 导致死锁 # - Worker 返回经验全部使用 Python 基本类型(list[float]/list[int]) # - 每个 Worker 批量处理多局游戏,减少 IPC 消息数量,防止管道阻塞 # Worker 进程的全局状态 _WORKER_STATE: Dict[str, Any] = {} def _init_worker( cfr_state_dict: dict, card_state_dict: dict, ) -> None: """ ProcessPoolExecutor 的 worker 初始化函数。 在每个 worker 进程启动时调用一次,创建该进程专属的: - OpenSpiel 游戏实例 - CPU 上的 CFRNetwork(eval 模式) - CPU 上的 CardModel(eval 模式) - BetTranslator 传入的 state_dict 中所有 Tensor 已通过 .cpu() 确保在 CPU 上, 避免 CUDA 跨进程问题。 权重热更新: 每次 iteration 主进程通过 worker_traverse_batch 的 cfr_state_dict 参数 直接传递最新权重,Worker 在 batch 开始时 load_state_dict 加载, 彻底消除临时文件并发读写冲突。 Args: cfr_state_dict: 主进程 CFRNetwork 的 state_dict(CPU Tensor) card_state_dict: 主进程 CardModel 的 state_dict(CPU Tensor) """ global _WORKER_STATE # 每个 worker 自行创建 OpenSpiel 游戏实例(不可跨进程传递) _WORKER_STATE["game"] = pyspiel.load_game(HUNL_FULLGAME_STRING) # 在 CPU 上创建 CFRNetwork 并加载主进程的最新权重 cfr_net = CFRNetwork( card_dim=CARD_DIM, env_dim=ENV_DIM, num_actions=NUM_ACTIONS, ) cfr_net.load_state_dict(cfr_state_dict) cfr_net.eval() _WORKER_STATE["cfr_net"] = cfr_net # 在 CPU 上创建 CardModel 并加载权重 # CardModel 权重在训练期间不会变化,只需初始化时加载一次 card_model = CardModel() card_model.load_state_dict(card_state_dict) card_model.eval() _WORKER_STATE["card_model"] = card_model # BetTranslator 是无状态对象,可以直接创建 _WORKER_STATE["translator"] = BetTranslator() def worker_traverse_batch( game_indices: List[int], cfr_state_dict: dict, ) -> Tuple[List[Tuple[int, float]], str]: """ 多进程版本的批量 MCCFR 遍历函数,在 Worker 进程中执行。 一个 Worker 一次处理多局游戏(由 game_indices 指定), 大幅减少 IPC 消息数量,防止管道被海量小消息撑爆。 权重通过函数参数直接传入(模型 < 500KB),无需临时文件中转, 彻底消除并发读写冲突和 BrokenProcessPool 风险。 经验数据通过临时文件回传,避免 multiprocessing.Pipe 序列化爆炸: Worker 将经验列表 torch.save 到临时 .pt 文件,仅返回文件路径。 主进程 torch.load 读取后立即删除临时文件。 Args: game_indices: 要处理的游戏索引列表(长度约 GAMES_PER_ITER/NUM_WORKERS) cfr_state_dict: 主进程当前 iteration 的 CFRNetwork 权重(CPU Tensor) Returns: game_results: [(game_idx, utility), ...] 每局的索引和收益 exp_file_path: 经验临时文件路径,主进程需 torch.load 读取后删除 """ global _WORKER_STATE global CARD_CACHE # ── 清除本 Worker 上个批次的卡牌缓存,彻底封杀内存泄露! ── CARD_CACHE.clear() # ── 直接通过参数加载最新权重,无需文件中转 ── _WORKER_STATE["cfr_net"].load_state_dict(cfr_state_dict) game = _WORKER_STATE["game"] cfr_net = _WORKER_STATE["cfr_net"] card_model = _WORKER_STATE["card_model"] translator = _WORKER_STATE["translator"] game_results: List[Tuple[int, float]] = [] all_experiences: List[Tuple] = [] for game_idx in game_indices: traversing_player = game_idx % 2 state = game.new_initial_state() experiences: List[Tuple] = [] utility = _traverse_worker( state, traversing_player, card_model, cfr_net, experiences, translator, ) game_results.append((game_idx, utility)) all_experiences.extend(experiences) #print(f"Worker 跑完了一局! (Game {game_idx})", flush=True) # ── 将经验落盘,避免 IPC 管道序列化爆炸 ── # 直接 return all_experiences 会导致 multiprocessing.Pipe 被撑爆 # (单次可达数百 MB 的 Python list 序列化数据),引发 OS OOM 和 BrokenProcessPool。 # 改用临时文件:torch.save 落盘 → 仅传输文件路径字符串 → 主进程读取后删除。 exp_file_path = os.path.join( tempfile.gettempdir(), f"cfr_exp_{uuid.uuid4().hex}.pt", ) torch.save(all_experiences, exp_file_path) del all_experiences # 立即释放 Worker 内存,避免文件缓存双倍占用 return game_results, exp_file_path def _traverse_worker( state, traversing_player: int, card_model: CardModel, cfr_net: CFRNetwork, experiences: List[Tuple], translator: BetTranslator, depth: int = 0, ) -> float: """ Worker 进程内部的递归遍历函数。 与 traverse() 逻辑完全一致,唯一区别是将经验追加到 本地 experiences 列表(而非全局 CFRBuffer),且存入的 info_state 使用 list[float] 而非 Tensor,以便跨进程传输。 内存安全保证: - 所有 Tensor 操作在 torch.no_grad() 下进行 - info_state_tensor.tolist() 后,原始 Tensor 即可被 GC 回收 - 递归过程中 state.clone() 产生的对象在递归返回后也被 GC - 最终返回的 experiences 只包含 Python 基本类型, 主进程 torch.tensor() 新建时无 grad_fn,buffer.add() 安全 设备说明: Worker 中所有模型始终在 CPU 上。Tensor 设备由模型参数所在 设备自动推导(next(cfr_net.parameters()).device),无需显式传递。 Args: state: OpenSpiel State traversing_player: 当前遍历者 card_model: CardModel(eval,CPU) cfr_net: CFRNetwork(eval,CPU) experiences: 本地经验暂存列表 translator: BetTranslator Returns: float: 当前节点对 traversing_player 的期望收益 """ # ── 深度熔断保护:超过极限深度,强制截断博弈树 ── if depth >= 40: return 0.0 # ── 1. Terminal Node ── if state.is_terminal(): return float(state.returns()[traversing_player]) # ── 2. Chance Node ── if state.is_chance_node(): outcomes = state.chance_outcomes() action_list, prob_list = zip(*outcomes) chance_action = random.choices(action_list, weights=prob_list, k=1)[0] child_state = state.clone() child_state.apply_action(chance_action) utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net, experiences, translator, depth + 1) del child_state return utility # ── 3. Player Node ── current_player = state.current_player() env_info = extract_env_state(state) env_features = build_env_features(env_info) card_features = build_card_features(card_model, state) legal_mask = env_info["legal_mask"] info_state_tensor = torch.cat([card_features, env_features], dim=0) # 使用 cfr_net 当前所在设备(Worker 中始终为 CPU) net_device = next(cfr_net.parameters()).device legal_mask_tensor = torch.tensor( [legal_mask], dtype=torch.float32, device=net_device ) card_input = card_features.unsqueeze(0).to(net_device) env_input = env_features.unsqueeze(0).to(net_device) with torch.no_grad(): current_strategy, _ = cfr_net.get_strategy( card_input, env_input, legal_mask_tensor ) strategy_list = current_strategy.squeeze(0).cpu().tolist() # ── 3a. 对方回合 ── if current_player != traversing_player: legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: legal_indices = [1] probs = [strategy_list[i] for i in legal_indices] prob_sum = sum(probs) if prob_sum > 0: probs = [p / prob_sum for p in probs] else: probs = [1.0 / len(legal_indices)] * len(legal_indices) sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0] engine_action = translator.cfr_to_engine_action(state, sampled_idx) child_state = state.clone() child_state.apply_action(engine_action) utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net, experiences, translator, depth + 1) del child_state return utility # ── 3b. 遍历者回合 ── legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] if not legal_indices: return 0.0 action_utilities = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: child_state = state.clone() engine_action = translator.cfr_to_engine_action(child_state, a) child_state.apply_action(engine_action) action_utilities[a] = _traverse_worker( child_state, traversing_player, card_model, cfr_net, experiences, translator, depth + 1, ) del child_state node_utility = 0.0 for a in legal_indices: node_utility += strategy_list[a] * action_utilities[a] regrets = [0.0] * NUM_CFR_ACTIONS for a in legal_indices: regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE # 存入本地经验列表(info_state 转为 list[float] 以便跨进程传输) # .tolist() 后 info_state_tensor 可被 GC 回收,不会造成内存累积 info_state_list = info_state_tensor.tolist() experiences.append((info_state_list, legal_mask, regrets, strategy_list)) return node_utility # ───────────────────── 检查点保存与恢复 ───────────────────── def save_checkpoint( iteration: int, cfr_net: CFRNetwork, buffer: CFRBuffer, optimizer, checkpoint_dir: str, keep_last_n: int, ) -> None: """ 原子保存检查点,并清理旧文件。 原子操作流程: 1. 先写入临时文件 ckpt_tmp.pt 2. 写入成功后 os.replace 重命名为 ckpt_iter_{iteration}.pt 3. 复制一份 latest_ckpt.pt 作为默认恢复入口 4. 清理超过 keep_last_n 的旧检查点 Args: iteration: 当前迭代次数 cfr_net: CFR 策略网络 buffer: CFR 经验回放池 optimizer: AdamW 优化器(可能为 None) checkpoint_dir: 检查点保存目录 keep_last_n: 保留最近 N 个检查点 """ os.makedirs(checkpoint_dir, exist_ok=True) # 构造保存字典 save_dict = { "iteration": iteration, "model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()}, "buffer_state_dict": buffer.state_dict(), } if optimizer is not None: # 优化器状态中的 Tensor 可能在 GPU 上,需转 CPU opt_state = optimizer.state_dict() opt_state_cpu = {} for k, v in opt_state.items(): if isinstance(v, torch.Tensor): opt_state_cpu[k] = v.cpu() elif isinstance(v, dict): opt_state_cpu[k] = _recursive_tensor_to_cpu(v) else: opt_state_cpu[k] = v save_dict["optimizer_state_dict"] = opt_state_cpu # 原子保存:先写临时文件,再 rename ckpt_path = os.path.join(checkpoint_dir, f"ckpt_iter_{iteration}.pt") tmp_path = os.path.join(checkpoint_dir, "ckpt_tmp.pt") torch.save(save_dict, tmp_path) os.replace(tmp_path, ckpt_path) # 更新 latest_ckpt.pt(用 copy 而非 symlink,更跨平台) latest_path = os.path.join(checkpoint_dir, "latest_ckpt.pt") # os.replace 是原子的,先用临时文件再替换 tmp_latest = os.path.join(checkpoint_dir, "latest_ckpt_tmp.pt") torch.save(save_dict, tmp_latest) os.replace(tmp_latest, latest_path) print(f"[Checkpoint] 已保存: iter={iteration}, buffer={len(buffer)}") # 清理旧检查点(保留最近 keep_last_n 个) _cleanup_old_checkpoints(checkpoint_dir, keep_last_n) def _recursive_tensor_to_cpu(d: dict) -> dict: """递归地将字典中所有 Tensor 转到 CPU。""" out = {} for k, v in d.items(): if isinstance(v, torch.Tensor): out[k] = v.cpu() elif isinstance(v, dict): out[k] = _recursive_tensor_to_cpu(v) elif isinstance(v, list): out[k] = [ item.cpu() if isinstance(item, torch.Tensor) else item for item in v ] else: out[k] = v return out def _cleanup_old_checkpoints(checkpoint_dir: str, keep_last_n: int) -> None: """删除旧的检查点文件,只保留最近 keep_last_n 个。""" pattern = os.path.join(checkpoint_dir, "ckpt_iter_*.pt") files = glob.glob(pattern) if len(files) <= keep_last_n: return # 按文件名中的 iteration 数字排序 def _extract_iter(path: str) -> int: basename = os.path.basename(path) # ckpt_iter_123.pt -> 123 num = basename.replace("ckpt_iter_", "").replace(".pt", "") try: return int(num) except ValueError: return -1 files_with_iter = [(f, _extract_iter(f)) for f in files if _extract_iter(f) >= 0] files_with_iter.sort(key=lambda x: x[1]) # 删除多余的旧文件 to_delete = files_with_iter[:-keep_last_n] for path, it in to_delete: os.remove(path) print(f"[Checkpoint] 已删除旧检查点: iter={it}") def _optimizer_state_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None: """将优化器状态中所有 Tensor 移到指定设备。 PyTorch 的 optimizer.load_state_dict() 不会自动将 Tensor 移到 当前参数所在的设备,需要手动处理。这对于从 CPU 检查点恢复到 GPU 上的场景至关重要。 Args: optimizer: PyTorch 优化器 device: 目标设备 """ for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.to(device) def load_checkpoint( checkpoint_path: str, cfr_net: CFRNetwork, buffer: CFRBuffer, ) -> dict: """ 从检查点文件恢复训练状态。 Args: checkpoint_path: 检查点文件路径 cfr_net: CFR 策略网络(将被加载权重) buffer: CFR 经验回放池(将被加载数据) Returns: dict: 包含 iteration 和 optimizer_state_dict(如果有)的字典 """ print(f"[Checkpoint] 正在加载: {checkpoint_path}") ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False) # 恢复模型权重 cfr_net.load_state_dict(ckpt["model_state_dict"]) # 恢复 Buffer buffer.load_state_dict(ckpt["buffer_state_dict"]) result = { "iteration": ckpt.get("iteration", 0), } if "optimizer_state_dict" in ckpt: result["optimizer_state_dict"] = ckpt["optimizer_state_dict"] print(f"[Checkpoint] 恢复成功: iter={result['iteration']}, buffer={len(buffer)}") return result # ───────────────────── 训练一步 ───────────────────── def train_step( cfr_net: CFRNetwork, buffer: CFRBuffer, optimizer: torch.optim.Optimizer, device: torch.device, ) -> Tuple[float, float, float]: """ 从 Buffer 采样一个 mini-batch,训练 CFRNetwork 一步。 Loss 组成: 1. Regret Loss: MSE(预测 regret, 目标 regret),只在 legal_mask=1 的位置计算 2. Policy Loss: MSE(预测策略, 目标策略),拟合 Buffer 中存储的策略分布 总 Loss = Regret Loss + Policy Loss Args: cfr_net: CFR 策略网络 buffer: CFR 经验回放池 optimizer: AdamW 优化器 device: 运行设备 Returns: (total_loss, regret_loss, policy_loss): 三个 float 标量 """ cfr_net.train() info_states, legal_masks, target_regrets, target_strategies = buffer.sample( TRAIN_BATCH_SIZE ) info_states = info_states.to(device) legal_masks = legal_masks.to(device) target_regrets = target_regrets.to(device) target_strategies = target_strategies.to(device) card_features = info_states[:, :CARD_DIM] env_features = info_states[:, CARD_DIM:] pred_regrets, pred_policy_logits = cfr_net(card_features, env_features) # ── Loss 1: Regret Loss ── regret_diff = (pred_regrets - target_regrets) ** 2 regret_loss = (regret_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0) # ── Loss 2: Policy Loss ── masked_logits = pred_policy_logits.masked_fill(legal_masks == 0, float("-inf")) pred_strategy = F.softmax(masked_logits, dim=-1) all_illegal = (legal_masks.sum(dim=-1, keepdim=True) == 0) num_legal = legal_masks.sum(dim=-1, keepdim=True).clamp(min=1.0) uniform = legal_masks / num_legal pred_strategy = torch.where(all_illegal, uniform, pred_strategy) policy_diff = (pred_strategy - target_strategies) ** 2 policy_loss = (policy_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0) total_loss = regret_loss + policy_loss optimizer.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(cfr_net.parameters(), max_norm=CLIP_GRAD_NORM) optimizer.step() return ( total_loss.item(), regret_loss.item(), policy_loss.item(), ) # ───────────────────── 主训练循环 ───────────────────── def main(): """ MCCFR 训练主循环(多进程版本)。 每个 iteration 分两个阶段: 阶段 A (Data Generation) — 多进程并行: - 使用 ProcessPoolExecutor(mp_context=spawn) 并发执行 - 进程池跨 iteration 复用,通过共享内存 dict 热更新权重 - 每个 Worker 批量处理多局,减少 IPC 消息数量 - 主进程汇总所有 Worker 的经验到全局 CFRBuffer 阶段 B (Network Training) — 单进程 GPU: - 如果 Buffer 中数据量足够,进行 mini-batch 训练 - 更新 Regret Head 和 Policy Head 日志输出: - 每 iteration 打印 Buffer 大小、Loss 变化 """ # ── 强制使用 spawn 启动方式,防止 fork 继承 CUDA Context 死锁 ── # Linux 默认 fork,fork 后子进程继承父进程的 CUDA Context, # 但 CUDA 运行时内部锁状态不一致,任何后续 CUDA 操作都会永久死锁。 # spawn 从零启动 Python 解释器,完全隔离 CUDA 状态。 # 必须在 if __name__ == "__main__" 保护内、任何 CUDA 操作之前调用。 mp.set_start_method('spawn', force=True) print("=" * 70) print(" MCCFR Self-Play Training Pipeline (Multi-Process)") print(" External Sampling + Deep CFR") print(f" NUM_WORKERS = {NUM_WORKERS}") print(f" Start method = spawn") print("=" * 70) # ── 1. 初始化 OpenSpiel 环境 ── game = pyspiel.load_game(HUNL_FULLGAME_STRING) translator = BetTranslator() print(f"\n[环境] 游戏加载成功: {HUNL_FULLGAME_STRING}") # ── 2. 初始化 Card Model(先放 CPU,数据生成阶段用 CPU 推理) ── card_model = CardModel() # 初始在 CPU if CARD_MODEL_CHECKPOINT and os.path.isfile(CARD_MODEL_CHECKPOINT): ckpt = torch.load(CARD_MODEL_CHECKPOINT, map_location="cpu", weights_only=False) card_model.load_state_dict(ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt) print(f"[CardModel] 已加载权重: {CARD_MODEL_CHECKPOINT}") else: print(f"[CardModel] 使用随机初始化(未提供 checkpoint)") card_model.eval() # ── 3. 初始化 CFR Network(先放 CPU,数据生成阶段用 CPU 推理) ── cfr_net = CFRNetwork( card_dim=CARD_DIM, env_dim=ENV_DIM, num_actions=NUM_ACTIONS, ) # 初始在 CPU total_params = sum(p.numel() for p in cfr_net.parameters()) print(f"[CFRNetwork] 参数量: {total_params:,},初始设备: CPU") # ── 4. 初始化 Buffer ── buffer = CFRBuffer(max_size=BUFFER_MAX_SIZE) print(f"[Buffer] 容量: {BUFFER_MAX_SIZE:,}") # Optimizer 延迟到首次阶段 B 时创建(此时 cfr_net 已移至 GPU) # 但如果从检查点恢复且存档中有 optimizer_state_dict,则提前初始化 optimizer = None # ── 检查点恢复 ── # 检查 latest_ckpt.pt 是否存在,存在则自动恢复 start_iteration = 1 latest_ckpt_path = os.path.join(CHECKPOINT_DIR, "latest_ckpt.pt") if os.path.isfile(latest_ckpt_path): ckpt_info = load_checkpoint(latest_ckpt_path, cfr_net, buffer) start_iteration = ckpt_info["iteration"] + 1 # 如果存档中有 optimizer 状态,必须现在就初始化并恢复 # 原因:optimizer 的动量缓冲区绑定在参数对象上, # 必须在 cfr_net 移到 GPU 后、任何训练之前加载 if "optimizer_state_dict" in ckpt_info: # 先将网络移至 GPU(optimizer 的参数必须在正确设备上) cfr_net.to(DEVICE) optimizer = torch.optim.AdamW( cfr_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, ) # 加载优化器状态(可能包含 GPU 上的 Tensor,需映射到正确设备) optimizer.load_state_dict(ckpt_info["optimizer_state_dict"]) # 将优化器状态中的所有 Tensor 移到与参数相同的设备 _optimizer_state_to_device(optimizer, DEVICE) # 网络移回 CPU,等阶段 A 再用 cfr_net.to("cpu") print(f"[Optimizer] 已从检查点恢复优化器状态, device={DEVICE}") print(f"[Checkpoint] 将从 Iteration {start_iteration} 继续训练") else: print("[Checkpoint] 未找到检查点,从头开始训练") # ── 5. 初始化 CPU state_dict(给 Worker 首次初始化用) ── # 必须在创建进程池之前转为 CPU,避免 CUDA Tensor 跨进程传输 cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()} card_state_dict_cpu = {k: v.cpu() for k, v in card_model.state_dict().items()} # ── 6. 创建进程池(跨 iteration 复用,避免重复初始化开销) ── # mp_context=spawn 确保子进程不继承 CUDA Context # 使用 spawn context 创建 ProcessPoolExecutor,而非全局 set_start_method spawn_ctx = mp.get_context('spawn') print(f"\n{'='*70}") print(f" 开始训练: {NUM_ITERATIONS} iterations, " f"每 iter {GAMES_PER_ITER} 局自对弈, " f"{NUM_WORKERS} workers 并行") print(f"{'='*70}") # 将 GAMES_PER_ITER 局均匀分配给 NUM_WORKERS 个 Worker # 每个 Worker 一次处理一批游戏,减少 IPC 消息数量 game_indices_all = list(range(GAMES_PER_ITER)) # 计算每个 Worker 分配的游戏索引 chunks: List[List[int]] = [[] for _ in range(NUM_WORKERS)] for i, gi in enumerate(game_indices_all): chunks[i % NUM_WORKERS].append(gi) with ProcessPoolExecutor( max_workers=NUM_WORKERS, initializer=_init_worker, initargs=(cfr_state_dict_cpu, card_state_dict_cpu), mp_context=spawn_ctx, ) as executor: for iteration in range(start_iteration, NUM_ITERATIONS + 1): # ──────────── 阶段 A: 数据生成 — 多进程 ──────────── # 动态设备切换:数据生成阶段,模型放 CPU,避免 Batch=1 的 GPU PCIe 延迟 card_model.to("cpu") cfr_net.to("cpu") cfr_net.eval() # 清除 CardModel 缓存,防止内存泄漏 CARD_CACHE.clear() # 获取当前网络权重,直接通过函数参数传递给 Worker(无需临时文件) cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()} # 提交批量任务(每个 Worker 处理一批游戏,权重直接参数传递) futures = {} for worker_id, chunk in enumerate(chunks): if not chunk: continue future = executor.submit(worker_traverse_batch, chunk, cfr_state_dict_cpu) futures[future] = worker_id # 收集结果并汇总经验 game_results = [] total_experiences = 0 for future in as_completed(futures): worker_id = futures[future] try: worker_game_results, exp_file_path = future.result() except Exception as e: print(f"[警告] Worker {worker_id} 执行失败: {e}") continue # 记录对局收益 for game_idx, utility in worker_game_results: traversing_player = game_idx % 2 game_results.append((traversing_player, utility)) # 从临时文件加载经验(避免 IPC 管道爆炸) try: experiences = torch.load(exp_file_path, map_location="cpu", weights_only=False) except Exception as e: print(f"[警告] Worker {worker_id} 经验文件加载失败: {e}") continue finally: # 无论加载成功与否,立即删除临时文件,防止 /tmp 堆积 if os.path.isfile(exp_file_path): os.remove(exp_file_path) # 将 worker 返回的经验写入全局 Buffer for info_state_list, legal_mask, regrets, strategy in experiences: buffer.add( info_state=info_state_list, legal_mask=legal_mask, regrets=regrets, strategy=strategy, ) total_experiences += len(experiences) del experiences # 立即释放,避免主进程内存峰值 # 统计本 iteration 数据生成的概况 p0_utils = [u for p, u in game_results if p == 0] p1_utils = [u for p, u in game_results if p == 1] avg_p0 = sum(p0_utils) / len(p0_utils) if p0_utils else 0.0 avg_p1 = sum(p1_utils) / len(p1_utils) if p1_utils else 0.0 # ──────────── 阶段 B: 网络训练 — GPU ──────────── # 动态设备切换:训练阶段,cfr_net 移至 GPU 利用大 Batch 高效训练 cfr_net.to(DEVICE) # 首次进入训练阶段时创建 optimizer(此时参数已在 GPU 上) if optimizer is None: optimizer = torch.optim.AdamW( cfr_net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, ) print(f"[Optimizer] AdamW lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}, device={DEVICE}") avg_total_loss = 0.0 avg_regret_loss = 0.0 avg_policy_loss = 0.0 trained_steps = 0 if len(buffer) >= MIN_BUFFER_SIZE_FOR_TRAIN: for step in range(TRAIN_STEPS_PER_ITER): total_loss, regret_loss, policy_loss = train_step( cfr_net, buffer, optimizer, DEVICE ) avg_total_loss += total_loss avg_regret_loss += regret_loss avg_policy_loss += policy_loss trained_steps += 1 avg_total_loss /= trained_steps avg_regret_loss /= trained_steps avg_policy_loss /= trained_steps # ──────────── 打印日志 ──────────── buffer_size = len(buffer) loss_str = ( f"Total={avg_total_loss:.6f} " f"Regret={avg_regret_loss:.6f} " f"Policy={avg_policy_loss:.6f}" if trained_steps > 0 else "训练未开始(Buffer 不足)" ) print( f"[Iter {iteration:3d}/{NUM_ITERATIONS}] " f"Buffer={buffer_size:6d} " f"Exp={total_experiences:6d} " f"P0_util={avg_p0:+.1f} P1_util={avg_p1:+.1f} " f"Loss: {loss_str}" ) # ──────────── 定时保存检查点 ──────────── if iteration % CHECKPOINT_INTERVAL == 0: # 保存时网络在 GPU 上,save_checkpoint 会将权重转到 CPU save_checkpoint( iteration=iteration, cfr_net=cfr_net, buffer=buffer, optimizer=optimizer, checkpoint_dir=CHECKPOINT_DIR, keep_last_n=KEEP_LAST_N_CHECKPOINTS, ) # ── 8. 训练结束,保存最终检查点 ── save_checkpoint( iteration=NUM_ITERATIONS, cfr_net=cfr_net, buffer=buffer, optimizer=optimizer, checkpoint_dir=CHECKPOINT_DIR, keep_last_n=KEEP_LAST_N_CHECKPOINTS, ) # 同时保存一份到根目录(向后兼容) save_path = os.path.join(_POKER_DIR, "cfr_net_checkpoint.pt") torch.save({ "model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()}, "optimizer_state_dict": optimizer.state_dict() if optimizer else {}, "iteration": NUM_ITERATIONS, "buffer_size": len(buffer), }, save_path) print(f"[保存] 模型已保存到: {save_path}") print(f"\n{'='*70}") print(f" 训练完成!共 {NUM_ITERATIONS} iterations") print(f" 最终 Buffer 大小: {len(buffer)}") print(f"{'='*70}") # ───────────────────── 入口 ───────────────────── if __name__ == "__main__": main()