"""
mccfr_trainer.py: MCCFR self-play training pipeline
Deep CFR trainer based on External Sampling MCCFR.
Core flow:
1. traverse(state, traversing_player) recursively walks the game tree and computes
   the counterfactual regret for every information set it visits
2. Each sample [info_state, legal_mask, regrets, strategy] is stored in CFRBuffer
3. Mini-batches sampled from the buffer train the Regret Head and Policy Head of CFRNetwork
=== Multi-process parallel data generation ===
Phase A (Data Generation) has been reworked to run across processes:
- Traversals run concurrently in a ProcessPoolExecutor(mp_context=spawn)
- The spawn start method avoids deadlocks caused by forking an inherited CUDA context
- The process pool is reused across iterations; weights are passed directly as function
  arguments (memory safe)
- Each worker handles a batch of games, greatly reducing the number of IPC messages
- Workers return (utility_list, experience_list); experiences travel as plain Python types
=== External Sampling MCCFR core logic ===
MCCFR approximates a Nash equilibrium of large extensive-form games.
"External sampling" means: on the opponent's turn, sample a single action from the
current strategy and continue; on the traverser's turn, expand every legal action so
the regret can be computed exactly.
Let v(I, a) be the expected payoff after taking action a at information set I, and
v(I) = sum_a σ(I,a) * v(I,a) the strategy-weighted expectation. The counterfactual regret is:
    regret(I, a) = v(I, a) - v(I)
When the traverser visits the same information set repeatedly, actions with larger
cumulative regret should be chosen more often; this is the core intuition behind
Regret Matching.
"""
# ── Must run before `import torch`: pin every C++ linear-algebra library to a single thread to prevent OpenMP deadlocks under spawn ──
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
import random
import sys
import glob
import tempfile
import uuid
import multiprocessing as mp
from typing import List, Tuple, Dict, Any
from concurrent.futures import ProcessPoolExecutor, as_completed
import torch
import torch.nn as nn
import torch.nn.functional as F
import pyspiel
torch.set_num_threads(1)
# ── Add the poker/ root directory to sys.path so sibling modules can be imported ──
_POKER_DIR = os.path.dirname(os.path.abspath(__file__))
if _POKER_DIR not in sys.path:
sys.path.insert(0, _POKER_DIR)
# ── Add the project root to sys.path so the card_model package can be imported ──
_PROJECT_DIR = os.path.dirname(_POKER_DIR)
if _PROJECT_DIR not in sys.path:
sys.path.insert(0, _PROJECT_DIR)
from env_adapter import (
HUNL_FULLGAME_STRING,
BetTranslator,
extract_env_state,
CFR_ACTIONS,
NUM_CFR_ACTIONS,
STACK_NORMALIZE,
)
from card_model.config import NUM_BINS, PAD_TOKEN, BOARD_SIZE
from card_model.data_generator import extract_cards_from_state
from card_model.model import CardModel
from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS
from cfr_buffer import CFRBuffer
# ───────────────────── Hyperparameters (128 GB RAM + 24 GB VRAM, 24 cores) ─────────────────────
# ── Large-scale parallelism and iterations ──
NUM_ITERATIONS = 30000 # Total training iterations (Deep CFR: many iterations, few games each)
GAMES_PER_ITER = 1000 # Games traversed per iteration (move quickly into the update phase)
# NUM_WORKERS: leave 2 cores for the main process (GPU training + logging + OS) to avoid context-switch penalties
# 24 cores - 2 = 22 workers, each running pure CPU-bound traversals
NUM_WORKERS = 22
# ── Large experience pool (uses the 128 GB of system RAM) ──
# One sample is a bit over 1 KB; 5 million samples take roughly 6-8 GB of memory.
# Set to 10 million to make full use of system memory and avoid catastrophic forgetting.
BUFFER_MAX_SIZE = 10_000_000
MIN_BUFFER_SIZE_FOR_TRAIN = 200_000 # Start training once 200k samples have accumulated
# ── VRAM usage (24 GB) ──
# The CFR network is tiny (3-layer MLP, ~100K parameters). Even at BS=32768:
# forward activations ~33 MB, backward ~33 MB, optimizer state ~0.8 MB,
# under 100 MB in total, trivial on 24 GB of VRAM.
# A very large batch size smooths convergence and reduces policy oscillation.
TRAIN_BATCH_SIZE = 16384
TRAIN_STEPS_PER_ITER = 64 # Network update steps after each iteration
# ── Optimizer details ──
LEARNING_RATE = 5e-4 # With the larger batch size, lower the LR slightly for stability
WEIGHT_DECAY = 1e-4 # Guard against overfitting
# Previous small-scale training settings, kept for reference:
#NUM_ITERATIONS = 100 # Total training iterations
#GAMES_PER_ITER = 50 # Self-play games per iteration
#TRAIN_BATCH_SIZE = 256 # Training mini-batch size
#TRAIN_STEPS_PER_ITER = 20 # Training steps per iteration
#MIN_BUFFER_SIZE_FOR_TRAIN = 1000 # Minimum buffer size before training starts
#LEARNING_RATE = 1e-3 # AdamW learning rate
#WEIGHT_DECAY = 1e-4 # AdamW weight decay
# Normalization constants (matching the 5-dim env_features)
# STACK_NORMALIZE = 20000.0 is imported from env_adapter (aligned with the Botzone starting stack)
STREET_NORMALIZE = 3.0 # Street normalization factor (0/3, 1/3, 2/3, 3/3)
CLIP_GRAD_NORM = 1.0 # Gradient clipping threshold, guards against unstable training
# Card Model checkpoint path (set to None to use random initialization)
CARD_MODEL_CHECKPOINT = "card_model/data/best_card_model.pt"
# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ── Checkpoint configuration ──
CHECKPOINT_DIR = os.path.join(_POKER_DIR, "checkpoints")
CHECKPOINT_INTERVAL = 50 # Save every 50 iterations
KEEP_LAST_N_CHECKPOINTS = 3 # Keep at most the 3 most recent checkpoint files locally to avoid filling the disk
# ───────────────────── Helper functions ─────────────────────
def build_env_features(env_info: dict) -> torch.Tensor:
"""
Normalize the dict returned by extract_env_state into a 5-dim env_features tensor.
The 5 dimensions are:
[pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position]
Args:
    env_info: dict returned by extract_env_state()
Returns:
    FloatTensor of shape [5]
"""
features = [
env_info["pot"] / STACK_NORMALIZE,
env_info["p0_stack"] / STACK_NORMALIZE,
env_info["p1_stack"] / STACK_NORMALIZE,
env_info["street"] / STREET_NORMALIZE,
float(env_info["position"]),
]
return torch.tensor(features, dtype=torch.float32)
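# Illustrative only (hypothetical values): for an env_info such as
# {"pot": 300, "p0_stack": 19850, "p1_stack": 19850, "street": 1, "position": 0, ...}
# build_env_features returns approximately
# tensor([0.0150, 0.9925, 0.9925, 0.3333, 0.0000])
# i.e. [pot/20000, p0_stack/20000, p1_stack/20000, street/3, position], matching the docstring above.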
# ── CardModel result cache ──
# Within the poker game tree, the board probabilities cannot change until new cards are dealt.
# Cache CardModel outputs keyed by (hole_cards_tuple, board_cards_tuple) so the model is not
# re-run repeatedly within the same dealing stage (e.g. the flop).
CARD_CACHE: Dict[str, torch.Tensor] = {}
def build_card_features(
card_model: CardModel,
state,
) -> torch.Tensor:
"""
Use CardModel to extract the 50-dim win-rate histogram (card_features) from the current state.
Dict-cached: as long as hole_cards and board_cards are unchanged, the cached result is returned.
The main loop calls CARD_CACHE.clear() before each iteration to keep memory bounded.
Flow:
1. Call extract_cards_from_state(state) to get hole_cards and board_cards
2. Build the cache_key; on a hit, return the cached tensor directly
3. On a miss, build the input tensors and run inference (on the model's device)
4. Return a [50] FloatTensor on the CPU
Args:
    card_model: CardModel with loaded weights (eval mode)
    state: OpenSpiel State
Returns:
    FloatTensor of shape [50] (detached from the graph, on the CPU)
"""
global CARD_CACHE
hole_cards, board_cards = extract_cards_from_state(state)
# Convert the card lists into a hashable key
cache_key = f"{tuple(hole_cards)}_{tuple(board_cards)}"
if cache_key in CARD_CACHE:
return CARD_CACHE[cache_key]
# Build the input tensors (on whatever device card_model currently lives)
model_device = next(card_model.parameters()).device
x_hole = torch.tensor([hole_cards], dtype=torch.int64, device=model_device) # [1, 2]
# Pad board_cards with PAD_TOKEN when fewer than 5 cards are dealt
padded_board = board_cards + [PAD_TOKEN] * (BOARD_SIZE - len(board_cards))
x_board = torch.tensor([padded_board], dtype=torch.int64, device=model_device) # [1, 5]
with torch.no_grad():
_, pred_histogram = card_model(x_hole, x_board) # [1, 50]
# Detach from the graph, move to CPU, drop the batch dimension
result = pred_histogram.squeeze(0).cpu()
CARD_CACHE[cache_key] = result
return result
# ───────────────────── Core: external-sampling traversal (single-process version, kept for compatibility) ─────────────────────
def traverse(
state,
traversing_player: int,
card_model: CardModel,
cfr_net: CFRNetwork,
buffer: CFRBuffer,
translator: BetTranslator,
depth: int = 0,
) -> float:
"""
Game-tree traversal for external-sampling MCCFR (single-process version).
Recursively walks the game tree; on the traverser's turn it computes the counterfactual
regret of every action and stores the sample [info_state, legal_mask, regrets, strategy]
in the buffer.
=== Handling the three node types ===
1. Terminal node (hand over):
    Return the traverser's payoff directly: state.returns()[traversing_player]
2. Chance node (dealing):
    Sample one dealing action according to its probability and recurse
3. Player node (player action):
    Extract the information-set features and query CFRNetwork for the current strategy σ(I,·),
    then split into two cases:
    a) Opponent's turn (current_player != traversing_player):
        Sample 1 action from σ(I,·), apply it, and recurse.
        This is what "external sampling" means: only one path is sampled, which greatly
        reduces the amount of computation.
    b) Traverser's turn (current_player == traversing_player):
        For every legal action a: clone the state, apply a, recurse to obtain v(I, a).
        Weighted expected payoff: v(I) = Σ_a σ(I,a) * v(I,a)
        Regret: regret(I, a) = v(I, a) - v(I)
        Store the sample in the buffer
Args:
    state: OpenSpiel State object
    traversing_player: current traverser (0 or 1)
    card_model: Card Model (eval mode), used to extract card features
    cfr_net: CFR policy network (eval mode), used to query the current strategy
    buffer: CFR experience replay buffer
    translator: BetTranslator (CFR action ↔ engine action conversion)
Returns:
    float: expected payoff of this node for traversing_player
"""
# ── Depth circuit breaker: past the hard depth limit, truncate the game tree ──
if depth >= 40:
return 0.0
# ── 1. Terminal node: hand over, return the payoff directly ──
if state.is_terminal():
return float(state.returns()[traversing_player])
# ── 2. Chance node: random dealing ──
if state.is_chance_node():
outcomes = state.chance_outcomes()
action_list, prob_list = zip(*outcomes)
chance_action = random.choices(action_list, weights=prob_list, k=1)[0]
child_state = state.clone()
child_state.apply_action(chance_action)
utility = traverse(child_state, traversing_player, card_model, cfr_net,
buffer, translator, depth + 1)
del child_state
return utility
# ── 3. Player node: player action ──
current_player = state.current_player()
env_info = extract_env_state(state)
env_features = build_env_features(env_info)
card_features = build_card_features(card_model, state)
legal_mask = env_info["legal_mask"]
info_state = torch.cat([card_features, env_features], dim=0)
# Use whatever device cfr_net is currently on (CPU during data generation)
net_device = next(cfr_net.parameters()).device
legal_mask_tensor = torch.tensor(
[legal_mask], dtype=torch.float32, device=net_device
)
card_input = card_features.unsqueeze(0).to(net_device)
env_input = env_features.unsqueeze(0).to(net_device)
with torch.no_grad():
current_strategy, _ = cfr_net.get_strategy(
card_input, env_input, legal_mask_tensor
)
strategy_list = current_strategy.squeeze(0).cpu().tolist()
# ── 3a. Opponent's turn ──
if current_player != traversing_player:
legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
if not legal_indices:
legal_indices = [1]
probs = [strategy_list[i] for i in legal_indices]
prob_sum = sum(probs)
if prob_sum > 0:
probs = [p / prob_sum for p in probs]
else:
probs = [1.0 / len(legal_indices)] * len(legal_indices)
sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0]
engine_action = translator.cfr_to_engine_action(state, sampled_idx)
child_state = state.clone()
child_state.apply_action(engine_action)
utility = traverse(child_state, traversing_player, card_model, cfr_net,
buffer, translator, depth + 1)
del child_state
return utility
# ── 3b. Traverser's turn ──
legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
if not legal_indices:
return 0.0
action_utilities = [0.0] * NUM_CFR_ACTIONS
for a in legal_indices:
child_state = state.clone()
engine_action = translator.cfr_to_engine_action(child_state, a)
child_state.apply_action(engine_action)
action_utilities[a] = traverse(
child_state, traversing_player, card_model, cfr_net,
buffer, translator, depth + 1
)
del child_state
node_utility = 0.0
for a in legal_indices:
node_utility += strategy_list[a] * action_utilities[a]
regrets = [0.0] * NUM_CFR_ACTIONS
for a in legal_indices:
regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE
buffer.add(
info_state=info_state,
legal_mask=legal_mask,
regrets=regrets,
strategy=strategy_list,
)
return node_utility
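# Worked example of the regret computation above (hypothetical numbers, for illustration only):
# with two legal actions, strategy σ = [0.25, 0.75] and recursive values
# v(I, a0) = -100, v(I, a1) = +60, we get
#   v(I)       = 0.25 * (-100) + 0.75 * 60 = 20.0
#   regret(a0) = (-100 - 20) / STACK_NORMALIZE = -120 / 20000 = -0.006
#   regret(a1) = (  60 - 20) / STACK_NORMALIZE =   40 / 20000 =  0.002
# These normalized regrets are exactly what buffer.add() stores as the training target.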
# ───────────────────── Multi-process worker infrastructure ─────────────────────
#
# Worker processes are initialized via _init_worker, where each one creates its own:
# - pyspiel game instance (cannot be serialized across processes)
# - CFRNetwork on the CPU (weights loaded from the state_dict passed by the main process)
# - CardModel on the CPU (weights loaded from the state_dict passed by the main process)
# - BetTranslator (lightweight, stateless object)
#
# Weight hot-update mechanism:
# The model is tiny (< 500 KB), so weights are passed directly as arguments to
# worker_traverse_batch. No temp-file relay is needed, which eliminates concurrent
# read/write conflicts and the risk of BrokenProcessPool.
#
# IPC safety mechanisms:
# - The spawn start method avoids deadlocks from forking an inherited CUDA context
# - Workers return experiences as plain Python types (list[float]/list[int])
# - Each worker handles a batch of games, reducing IPC messages and pipe congestion
# Global per-worker state
_WORKER_STATE: Dict[str, Any] = {}
def _init_worker(
cfr_state_dict: dict,
card_state_dict: dict,
) -> None:
"""
Worker initializer for the ProcessPoolExecutor.
Called once when each worker process starts; creates that process's own:
- OpenSpiel game instance
- CFRNetwork on the CPU (eval mode)
- CardModel on the CPU (eval mode)
- BetTranslator
All tensors in the supplied state_dicts have already been moved to the CPU via .cpu(),
avoiding cross-process CUDA issues.
Weight hot-update:
Every iteration the main process passes the latest weights directly through the
cfr_state_dict argument of worker_traverse_batch; the worker calls load_state_dict at
the start of the batch, eliminating temp-file read/write races entirely.
Args:
    cfr_state_dict: state_dict of the main process's CFRNetwork (CPU tensors)
    card_state_dict: state_dict of the main process's CardModel (CPU tensors)
"""
global _WORKER_STATE
# Each worker creates its own OpenSpiel game instance (cannot be passed across processes)
_WORKER_STATE["game"] = pyspiel.load_game(HUNL_FULLGAME_STRING)
# Create CFRNetwork on the CPU and load the main process's latest weights
cfr_net = CFRNetwork(
card_dim=CARD_DIM,
env_dim=ENV_DIM,
num_actions=NUM_ACTIONS,
)
cfr_net.load_state_dict(cfr_state_dict)
cfr_net.eval()
_WORKER_STATE["cfr_net"] = cfr_net
# Create CardModel on the CPU and load its weights
# CardModel weights never change during training, so loading them once at init is enough
card_model = CardModel()
card_model.load_state_dict(card_state_dict)
card_model.eval()
_WORKER_STATE["card_model"] = card_model
# BetTranslator is stateless and can simply be constructed here
_WORKER_STATE["translator"] = BetTranslator()
def worker_traverse_batch(
game_indices: List[int],
cfr_state_dict: dict,
) -> Tuple[List[Tuple[int, float]], str]:
"""
Batched MCCFR traversal for the multi-process pipeline; runs inside a worker process.
One worker handles several games per call (given by game_indices), which greatly reduces
the number of IPC messages and keeps the pipe from being flooded with tiny ones.
Weights are passed in directly as an argument (the model is < 500 KB), so no temp-file
relay is needed; concurrent read/write conflicts and BrokenProcessPool risks disappear.
Experiences are returned via a temporary file to avoid multiprocessing.Pipe serialization blow-ups:
the worker saves the experience list to a temporary .pt file with torch.save and returns only the path.
The main process loads it with torch.load and deletes the file immediately afterwards.
Args:
    game_indices: list of game indices to process (roughly GAMES_PER_ITER/NUM_WORKERS long)
    cfr_state_dict: CFRNetwork weights for the main process's current iteration (CPU tensors)
Returns:
    game_results: [(game_idx, utility), ...] index and payoff of each game
    exp_file_path: path to the experience temp file; the main process must torch.load it and then delete it
"""
global _WORKER_STATE
global CARD_CACHE
# ── Clear this worker's card cache from the previous batch to keep memory bounded ──
CARD_CACHE.clear()
# ── Load the latest weights directly from the argument, no file relay needed ──
_WORKER_STATE["cfr_net"].load_state_dict(cfr_state_dict)
game = _WORKER_STATE["game"]
cfr_net = _WORKER_STATE["cfr_net"]
card_model = _WORKER_STATE["card_model"]
translator = _WORKER_STATE["translator"]
game_results: List[Tuple[int, float]] = []
all_experiences: List[Tuple] = []
for game_idx in game_indices:
traversing_player = game_idx % 2
state = game.new_initial_state()
experiences: List[Tuple] = []
utility = _traverse_worker(
state, traversing_player, card_model, cfr_net,
experiences, translator,
)
game_results.append((game_idx, utility))
all_experiences.extend(experiences)
#print(f"Worker finished a game! (Game {game_idx})", flush=True)
# ── Spill experiences to disk to avoid IPC pipe serialization blow-up ──
# Returning all_experiences directly would overwhelm the multiprocessing.Pipe
# (hundreds of MB of serialized Python lists per call), causing OS OOM and BrokenProcessPool.
# Instead: torch.save to a temp file → return only the path string → main process reads, then deletes it.
exp_file_path = os.path.join(
tempfile.gettempdir(),
f"cfr_exp_{uuid.uuid4().hex}.pt",
)
torch.save(all_experiences, exp_file_path)
del all_experiences # Free worker memory immediately; avoid doubling up with the file cache
return game_results, exp_file_path
def _traverse_worker(
state,
traversing_player: int,
card_model: CardModel,
cfr_net: CFRNetwork,
experiences: List[Tuple],
translator: BetTranslator,
depth: int = 0,
) -> float:
"""
Recursive traversal used inside worker processes.
Identical in logic to traverse(); the only difference is that samples are appended to a
local experiences list (rather than the global CFRBuffer), and the stored info_state is a
list[float] instead of a Tensor so it can be shipped across processes.
Memory-safety guarantees:
- All tensor operations run under torch.no_grad()
- After info_state_tensor.tolist(), the original tensor can be garbage-collected
- Objects produced by state.clone() during recursion are GC'd once the recursion returns
- The returned experiences contain only plain Python types, so when the main process
  rebuilds tensors with torch.tensor() they carry no grad_fn and buffer.add() is safe
Device notes:
All models inside a worker stay on the CPU. The tensor device is derived from the model
parameters (next(cfr_net.parameters()).device), so it does not need to be passed explicitly.
Args:
    state: OpenSpiel State
    traversing_player: current traverser
    card_model: CardModel (eval, CPU)
    cfr_net: CFRNetwork (eval, CPU)
    experiences: local experience staging list
    translator: BetTranslator
Returns:
    float: expected payoff of this node for traversing_player
"""
# ── Depth circuit breaker: past the hard depth limit, truncate the game tree ──
if depth >= 40:
return 0.0
# ── 1. Terminal Node ──
if state.is_terminal():
return float(state.returns()[traversing_player])
# ── 2. Chance Node ──
if state.is_chance_node():
outcomes = state.chance_outcomes()
action_list, prob_list = zip(*outcomes)
chance_action = random.choices(action_list, weights=prob_list, k=1)[0]
child_state = state.clone()
child_state.apply_action(chance_action)
utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net,
experiences, translator, depth + 1)
del child_state
return utility
# ── 3. Player Node ──
current_player = state.current_player()
env_info = extract_env_state(state)
env_features = build_env_features(env_info)
card_features = build_card_features(card_model, state)
legal_mask = env_info["legal_mask"]
info_state_tensor = torch.cat([card_features, env_features], dim=0)
# Use whatever device cfr_net is currently on (always the CPU inside a worker)
net_device = next(cfr_net.parameters()).device
legal_mask_tensor = torch.tensor(
[legal_mask], dtype=torch.float32, device=net_device
)
card_input = card_features.unsqueeze(0).to(net_device)
env_input = env_features.unsqueeze(0).to(net_device)
with torch.no_grad():
current_strategy, _ = cfr_net.get_strategy(
card_input, env_input, legal_mask_tensor
)
strategy_list = current_strategy.squeeze(0).cpu().tolist()
# ── 3a. Opponent's turn ──
if current_player != traversing_player:
legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
if not legal_indices:
legal_indices = [1]
probs = [strategy_list[i] for i in legal_indices]
prob_sum = sum(probs)
if prob_sum > 0:
probs = [p / prob_sum for p in probs]
else:
probs = [1.0 / len(legal_indices)] * len(legal_indices)
sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0]
engine_action = translator.cfr_to_engine_action(state, sampled_idx)
child_state = state.clone()
child_state.apply_action(engine_action)
utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net,
experiences, translator, depth + 1)
del child_state
return utility
# ── 3b. Traverser's turn ──
legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
if not legal_indices:
return 0.0
action_utilities = [0.0] * NUM_CFR_ACTIONS
for a in legal_indices:
child_state = state.clone()
engine_action = translator.cfr_to_engine_action(child_state, a)
child_state.apply_action(engine_action)
action_utilities[a] = _traverse_worker(
child_state, traversing_player, card_model, cfr_net,
experiences, translator, depth + 1,
)
del child_state
node_utility = 0.0
for a in legal_indices:
node_utility += strategy_list[a] * action_utilities[a]
regrets = [0.0] * NUM_CFR_ACTIONS
for a in legal_indices:
regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE
# Append to the local experience list (info_state converted to list[float] for cross-process transfer)
# After .tolist(), info_state_tensor can be GC'd, so no memory accumulates
info_state_list = info_state_tensor.tolist()
experiences.append((info_state_list, legal_mask, regrets, strategy_list))
return node_utility
# ───────────────────── Checkpoint saving and restoring ─────────────────────
def save_checkpoint(
iteration: int,
cfr_net: CFRNetwork,
buffer: CFRBuffer,
optimizer,
checkpoint_dir: str,
keep_last_n: int,
) -> None:
"""
Atomically save a checkpoint and prune old files.
Atomic procedure:
1. Write to a temporary file ckpt_tmp.pt first
2. On success, os.replace it to ckpt_iter_{iteration}.pt
3. Write a copy as latest_ckpt.pt, the default recovery entry point
4. Prune checkpoints beyond keep_last_n
Args:
    iteration: current iteration number
    cfr_net: CFR policy network
    buffer: CFR experience replay buffer
    optimizer: AdamW optimizer (may be None)
    checkpoint_dir: directory for checkpoints
    keep_last_n: number of most recent checkpoints to keep
"""
os.makedirs(checkpoint_dir, exist_ok=True)
# Build the save dict
save_dict = {
"iteration": iteration,
"model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()},
"buffer_state_dict": buffer.state_dict(),
}
if optimizer is not None:
# Tensors inside the optimizer state may be on the GPU; move them to the CPU
opt_state = optimizer.state_dict()
opt_state_cpu = {}
for k, v in opt_state.items():
if isinstance(v, torch.Tensor):
opt_state_cpu[k] = v.cpu()
elif isinstance(v, dict):
opt_state_cpu[k] = _recursive_tensor_to_cpu(v)
else:
opt_state_cpu[k] = v
save_dict["optimizer_state_dict"] = opt_state_cpu
# Atomic save: write a temp file first, then rename
ckpt_path = os.path.join(checkpoint_dir, f"ckpt_iter_{iteration}.pt")
tmp_path = os.path.join(checkpoint_dir, "ckpt_tmp.pt")
torch.save(save_dict, tmp_path)
os.replace(tmp_path, ckpt_path)
# Update latest_ckpt.pt (a copy rather than a symlink, which is more portable)
latest_path = os.path.join(checkpoint_dir, "latest_ckpt.pt")
# os.replace is atomic: write to a temp file first, then swap it in
tmp_latest = os.path.join(checkpoint_dir, "latest_ckpt_tmp.pt")
torch.save(save_dict, tmp_latest)
os.replace(tmp_latest, latest_path)
print(f"[Checkpoint] Saved: iter={iteration}, buffer={len(buffer)}")
# Prune old checkpoints (keep the most recent keep_last_n)
_cleanup_old_checkpoints(checkpoint_dir, keep_last_n)
def _recursive_tensor_to_cpu(d: dict) -> dict:
"""递归地将字典中所有 Tensor 转到 CPU。"""
out = {}
for k, v in d.items():
if isinstance(v, torch.Tensor):
out[k] = v.cpu()
elif isinstance(v, dict):
out[k] = _recursive_tensor_to_cpu(v)
elif isinstance(v, list):
out[k] = [
item.cpu() if isinstance(item, torch.Tensor) else item
for item in v
]
else:
out[k] = v
return out
def _cleanup_old_checkpoints(checkpoint_dir: str, keep_last_n: int) -> None:
"""删除旧的检查点文件,只保留最近 keep_last_n 个。"""
pattern = os.path.join(checkpoint_dir, "ckpt_iter_*.pt")
files = glob.glob(pattern)
if len(files) <= keep_last_n:
return
# Sort by the iteration number embedded in the filename
def _extract_iter(path: str) -> int:
basename = os.path.basename(path)
# ckpt_iter_123.pt -> 123
num = basename.replace("ckpt_iter_", "").replace(".pt", "")
try:
return int(num)
except ValueError:
return -1
files_with_iter = [(f, _extract_iter(f)) for f in files if _extract_iter(f) >= 0]
files_with_iter.sort(key=lambda x: x[1])
# Delete the surplus old files
to_delete = files_with_iter[:-keep_last_n]
for path, it in to_delete:
os.remove(path)
print(f"[Checkpoint] Removed old checkpoint: iter={it}")
def _optimizer_state_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
"""将优化器状态中所有 Tensor 移到指定设备。
PyTorch 的 optimizer.load_state_dict() 不会自动将 Tensor 移到
当前参数所在的设备,需要手动处理。这对于从 CPU 检查点恢复到
GPU 上的场景至关重要。
Args:
optimizer: PyTorch 优化器
device: 目标设备
"""
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device)
def load_checkpoint(
checkpoint_path: str,
cfr_net: CFRNetwork,
buffer: CFRBuffer,
) -> dict:
"""
Restore training state from a checkpoint file.
Args:
    checkpoint_path: path to the checkpoint file
    cfr_net: CFR policy network (weights will be loaded into it)
    buffer: CFR experience replay buffer (data will be loaded into it)
Returns:
    dict: contains iteration and, if present, optimizer_state_dict
"""
print(f"[Checkpoint] Loading: {checkpoint_path}")
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
# Restore model weights
cfr_net.load_state_dict(ckpt["model_state_dict"])
# Restore the buffer
buffer.load_state_dict(ckpt["buffer_state_dict"])
result = {
"iteration": ckpt.get("iteration", 0),
}
if "optimizer_state_dict" in ckpt:
result["optimizer_state_dict"] = ckpt["optimizer_state_dict"]
print(f"[Checkpoint] Restore succeeded: iter={result['iteration']}, buffer={len(buffer)}")
return result
# ───────────────────── One training step ─────────────────────
def train_step(
cfr_net: CFRNetwork,
buffer: CFRBuffer,
optimizer: torch.optim.Optimizer,
device: torch.device,
) -> Tuple[float, float, float]:
"""
Sample one mini-batch from the buffer and train CFRNetwork for one step.
Loss composition:
1. Regret loss: MSE(predicted regret, target regret), computed only where legal_mask=1
2. Policy loss: MSE(predicted strategy, target strategy), fitting the strategy distribution stored in the buffer
Total loss = regret loss + policy loss
Args:
    cfr_net: CFR policy network
    buffer: CFR experience replay buffer
    optimizer: AdamW optimizer
    device: execution device
Returns:
    (total_loss, regret_loss, policy_loss): three float scalars
"""
cfr_net.train()
info_states, legal_masks, target_regrets, target_strategies = buffer.sample(
TRAIN_BATCH_SIZE
)
info_states = info_states.to(device)
legal_masks = legal_masks.to(device)
target_regrets = target_regrets.to(device)
target_strategies = target_strategies.to(device)
card_features = info_states[:, :CARD_DIM]
env_features = info_states[:, CARD_DIM:]
pred_regrets, pred_policy_logits = cfr_net(card_features, env_features)
# ── Loss 1: Regret Loss ──
regret_diff = (pred_regrets - target_regrets) ** 2
regret_loss = (regret_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0)
# ── Loss 2: Policy Loss ──
masked_logits = pred_policy_logits.masked_fill(legal_masks == 0, float("-inf"))
pred_strategy = F.softmax(masked_logits, dim=-1)
all_illegal = (legal_masks.sum(dim=-1, keepdim=True) == 0)
num_legal = legal_masks.sum(dim=-1, keepdim=True).clamp(min=1.0)
uniform = legal_masks / num_legal
pred_strategy = torch.where(all_illegal, uniform, pred_strategy)
policy_diff = (pred_strategy - target_strategies) ** 2
policy_loss = (policy_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0)
total_loss = regret_loss + policy_loss
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(cfr_net.parameters(), max_norm=CLIP_GRAD_NORM)
optimizer.step()
return (
total_loss.item(),
regret_loss.item(),
policy_loss.item(),
)
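# Masking example (hypothetical numbers): with a single sample in the batch where
# legal_masks = [1, 1, 0, ...], pred_regrets = [0.1, -0.2, 0.5, ...] and target_regrets = 0,
# only the two legal positions contribute: regret_loss = (0.1**2 + 0.2**2) / 2 = 0.025.
# The illegal position (0.5 vs 0.0) is ignored, so the network is never penalized for
# actions it cannot take; the policy loss is masked and averaged the same way.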
# ───────────────────── Main training loop ─────────────────────
def main():
"""
MCCFR training main loop (multi-process version).
Each iteration has two phases:
Phase A (Data Generation), multi-process:
- Traversals run concurrently in a ProcessPoolExecutor(mp_context=spawn)
- The process pool is reused across iterations; weights are hot-updated by passing
  the latest state_dict directly as a task argument
- Each worker handles a batch of games, reducing the number of IPC messages
- The main process merges all workers' experiences into the global CFRBuffer
Phase B (Network Training), single-process on the GPU:
- If the buffer holds enough data, run mini-batch training
- Update the Regret Head and Policy Head
Logging:
- Each iteration prints the buffer size and loss changes
"""
# ── Force the spawn start method to prevent fork-inherited CUDA context deadlocks ──
# Linux defaults to fork; after fork, the child inherits the parent's CUDA context,
# but the CUDA runtime's internal lock state is inconsistent, so any later CUDA call
# deadlocks permanently. spawn starts a fresh Python interpreter, fully isolating CUDA state.
# Must be called inside the if __name__ == "__main__" guard and before any CUDA operation.
mp.set_start_method('spawn', force=True)
print("=" * 70)
print(" MCCFR Self-Play Training Pipeline (Multi-Process)")
print(" External Sampling + Deep CFR")
print(f" NUM_WORKERS = {NUM_WORKERS}")
print(f" Start method = spawn")
print("=" * 70)
# ── 1. Initialize the OpenSpiel environment ──
game = pyspiel.load_game(HUNL_FULLGAME_STRING)
translator = BetTranslator()
print(f"\n[Env] Game loaded: {HUNL_FULLGAME_STRING}")
# ── 2. Initialize the Card Model (on the CPU first; data generation uses CPU inference) ──
card_model = CardModel() # starts on the CPU
if CARD_MODEL_CHECKPOINT and os.path.isfile(CARD_MODEL_CHECKPOINT):
ckpt = torch.load(CARD_MODEL_CHECKPOINT, map_location="cpu", weights_only=False)
card_model.load_state_dict(ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt)
print(f"[CardModel] Loaded weights: {CARD_MODEL_CHECKPOINT}")
else:
print(f"[CardModel] Using random initialization (no checkpoint provided)")
card_model.eval()
# ── 3. Initialize the CFR Network (on the CPU first; data generation uses CPU inference) ──
cfr_net = CFRNetwork(
card_dim=CARD_DIM,
env_dim=ENV_DIM,
num_actions=NUM_ACTIONS,
) # starts on the CPU
total_params = sum(p.numel() for p in cfr_net.parameters())
print(f"[CFRNetwork] Parameter count: {total_params:,}, initial device: CPU")
# ── 4. Initialize the buffer ──
buffer = CFRBuffer(max_size=BUFFER_MAX_SIZE)
print(f"[Buffer] Capacity: {BUFFER_MAX_SIZE:,}")
# The optimizer is created lazily at the first Phase B (once cfr_net has moved to the GPU),
# but if we resume from a checkpoint that contains optimizer_state_dict, it is initialized earlier
optimizer = None
# ── Checkpoint recovery ──
# If latest_ckpt.pt exists, resume from it automatically
start_iteration = 1
latest_ckpt_path = os.path.join(CHECKPOINT_DIR, "latest_ckpt.pt")
if os.path.isfile(latest_ckpt_path):
ckpt_info = load_checkpoint(latest_ckpt_path, cfr_net, buffer)
start_iteration = ckpt_info["iteration"] + 1
# If the checkpoint contains optimizer state, it must be initialized and restored now.
# Reason: the optimizer's momentum buffers are bound to the parameter objects and must
# be loaded after cfr_net has moved to the GPU but before any training happens
if "optimizer_state_dict" in ckpt_info:
# Move the network to the GPU first (the optimizer's parameters must be on the right device)
cfr_net.to(DEVICE)
optimizer = torch.optim.AdamW(
cfr_net.parameters(),
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY,
)
# Load the optimizer state (may contain GPU tensors that need mapping to the right device)
optimizer.load_state_dict(ckpt_info["optimizer_state_dict"])
# Move every tensor in the optimizer state to the same device as the parameters
_optimizer_state_to_device(optimizer, DEVICE)
# Move the network back to the CPU until Phase A needs it
cfr_net.to("cpu")
print(f"[Optimizer] Optimizer state restored from checkpoint, device={DEVICE}")
print(f"[Checkpoint] Resuming training from iteration {start_iteration}")
else:
print("[Checkpoint] No checkpoint found, training from scratch")
# ── 5. Build CPU state_dicts (for the workers' first initialization) ──
# Must be moved to the CPU before the process pool is created, to avoid passing CUDA tensors across processes
cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()}
card_state_dict_cpu = {k: v.cpu() for k, v in card_model.state_dict().items()}
# ── 6. Create the process pool (reused across iterations to avoid repeated init cost) ──
# mp_context=spawn makes sure child processes do not inherit the CUDA context
# A spawn context is passed to the ProcessPoolExecutor explicitly, in addition to the global set_start_method call above
spawn_ctx = mp.get_context('spawn')
print(f"\n{'='*70}")
print(f" Starting training: {NUM_ITERATIONS} iterations, "
f"{GAMES_PER_ITER} self-play games per iter, "
f"{NUM_WORKERS} workers in parallel")
print(f"{'='*70}")
# Spread GAMES_PER_ITER games evenly across NUM_WORKERS workers
# Each worker handles a batch of games at a time, reducing the number of IPC messages
game_indices_all = list(range(GAMES_PER_ITER))
# Compute the game indices assigned to each worker
chunks: List[List[int]] = [[] for _ in range(NUM_WORKERS)]
for i, gi in enumerate(game_indices_all):
chunks[i % NUM_WORKERS].append(gi)
with ProcessPoolExecutor(
max_workers=NUM_WORKERS,
initializer=_init_worker,
initargs=(cfr_state_dict_cpu, card_state_dict_cpu),
mp_context=spawn_ctx,
) as executor:
for iteration in range(start_iteration, NUM_ITERATIONS + 1):
# ──────────── Phase A: data generation (multi-process) ────────────
# Dynamic device switching: during data generation the models live on the CPU, avoiding GPU PCIe latency at batch size 1
card_model.to("cpu")
cfr_net.to("cpu")
cfr_net.eval()
# Clear the CardModel cache to keep memory bounded
CARD_CACHE.clear()
# Snapshot the current network weights; they are passed to workers directly as arguments (no temp files)
cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()}
# Submit the batched tasks (each worker gets a chunk of games; weights travel as arguments)
futures = {}
for worker_id, chunk in enumerate(chunks):
if not chunk:
continue
future = executor.submit(worker_traverse_batch, chunk, cfr_state_dict_cpu)
futures[future] = worker_id
# Collect results and merge experiences
game_results = []
total_experiences = 0
for future in as_completed(futures):
worker_id = futures[future]
try:
worker_game_results, exp_file_path = future.result()
except Exception as e:
print(f"[Warning] Worker {worker_id} failed: {e}")
continue
# Record per-game payoffs
for game_idx, utility in worker_game_results:
traversing_player = game_idx % 2
game_results.append((traversing_player, utility))
# Load experiences from the temp file (avoids blowing up the IPC pipe)
try:
experiences = torch.load(exp_file_path, map_location="cpu", weights_only=False)
except Exception as e:
print(f"[Warning] Failed to load experience file from worker {worker_id}: {e}")
continue
finally:
# Whether or not loading succeeded, delete the temp file immediately so /tmp does not fill up
if os.path.isfile(exp_file_path):
os.remove(exp_file_path)
# Write the worker's experiences into the global buffer
for info_state_list, legal_mask, regrets, strategy in experiences:
buffer.add(
info_state=info_state_list,
legal_mask=legal_mask,
regrets=regrets,
strategy=strategy,
)
total_experiences += len(experiences)
del experiences # Release immediately to avoid a memory spike in the main process
# Summarize this iteration's data generation
p0_utils = [u for p, u in game_results if p == 0]
p1_utils = [u for p, u in game_results if p == 1]
avg_p0 = sum(p0_utils) / len(p0_utils) if p0_utils else 0.0
avg_p1 = sum(p1_utils) / len(p1_utils) if p1_utils else 0.0
# ──────────── Phase B: network training (GPU) ────────────
# Dynamic device switching: for training, cfr_net moves to the GPU to exploit large-batch throughput
cfr_net.to(DEVICE)
# Create the optimizer on first entering the training phase (parameters are already on the GPU)
if optimizer is None:
optimizer = torch.optim.AdamW(
cfr_net.parameters(),
lr=LEARNING_RATE,
weight_decay=WEIGHT_DECAY,
)
print(f"[Optimizer] AdamW lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}, device={DEVICE}")
avg_total_loss = 0.0
avg_regret_loss = 0.0
avg_policy_loss = 0.0
trained_steps = 0
if len(buffer) >= MIN_BUFFER_SIZE_FOR_TRAIN:
for step in range(TRAIN_STEPS_PER_ITER):
total_loss, regret_loss, policy_loss = train_step(
cfr_net, buffer, optimizer, DEVICE
)
avg_total_loss += total_loss
avg_regret_loss += regret_loss
avg_policy_loss += policy_loss
trained_steps += 1
avg_total_loss /= trained_steps
avg_regret_loss /= trained_steps
avg_policy_loss /= trained_steps
# ──────────── Logging ────────────
buffer_size = len(buffer)
loss_str = (
f"Total={avg_total_loss:.6f} "
f"Regret={avg_regret_loss:.6f} "
f"Policy={avg_policy_loss:.6f}"
if trained_steps > 0
else "training not started (buffer below minimum)"
)
print(
f"[Iter {iteration:3d}/{NUM_ITERATIONS}] "
f"Buffer={buffer_size:6d} "
f"Exp={total_experiences:6d} "
f"P0_util={avg_p0:+.1f} P1_util={avg_p1:+.1f} "
f"Loss: {loss_str}"
)
# ──────────── Periodic checkpointing ────────────
if iteration % CHECKPOINT_INTERVAL == 0:
# The network is on the GPU at save time; save_checkpoint moves the weights to the CPU
save_checkpoint(
iteration=iteration,
cfr_net=cfr_net,
buffer=buffer,
optimizer=optimizer,
checkpoint_dir=CHECKPOINT_DIR,
keep_last_n=KEEP_LAST_N_CHECKPOINTS,
)
# ── 8. Training finished, save the final checkpoint ──
save_checkpoint(
iteration=NUM_ITERATIONS,
cfr_net=cfr_net,
buffer=buffer,
optimizer=optimizer,
checkpoint_dir=CHECKPOINT_DIR,
keep_last_n=KEEP_LAST_N_CHECKPOINTS,
)
# Also save a copy to the package root (backward compatible)
save_path = os.path.join(_POKER_DIR, "cfr_net_checkpoint.pt")
torch.save({
"model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()},
"optimizer_state_dict": optimizer.state_dict() if optimizer else {},
"iteration": NUM_ITERATIONS,
"buffer_size": len(buffer),
}, save_path)
print(f"[Save] Model saved to: {save_path}")
print(f"\n{'='*70}")
print(f" Training complete! {NUM_ITERATIONS} iterations in total")
print(f" Final buffer size: {len(buffer)}")
print(f"{'='*70}")
# ───────────────────── Entry point ─────────────────────
if __name__ == "__main__":
main()