"""
|
||
mccfr_trainer.py — MCCFR 自我博弈训练流水线
|
||
|
||
基于外部采样 (External Sampling MCCFR) 的 Deep CFR 训练器。
|
||
核心流程:
|
||
1. traverse(state, traversing_player) 递归遍历博弈树,
|
||
计算每个信息集的反事实遗憾值 (Counterfactual Regret)
|
||
2. 将 [info_state, legal_mask, regrets, strategy] 存入 CFRBuffer
|
||
3. 从 Buffer 采样 mini-batch,训练 CFRNetwork 的 Regret Head 和 Policy Head
|
||
|
||
=== 多进程并行数据生成 ===
|
||
|
||
阶段 A (Data Generation) 已改造为多进程版本:
|
||
- 使用 ProcessPoolExecutor(mp_context=spawn) 并发执行
|
||
- spawn 启动方式避免 fork 继承 CUDA Context 导致死锁
|
||
- 进程池跨 iteration 复用,通过函数参数直接传递权重(内存安全)
|
||
- 每个 Worker 批量处理多局,大幅减少 IPC 消息数量
|
||
- Worker 返回 (utility_list, experience_list),经验以 Python 基本类型传输
|
||
|
||
=== 外部采样 MCCFR 核心逻辑 ===
|
||
|
||
MCCFR 是一种近似求解大规模扩展式博弈纳什均衡的算法。
|
||
"外部采样" 指的是:在对方回合,根据当前策略采样 1 个动作继续;
|
||
在遍历者回合,遍历所有合法动作以精确计算 regret。
|
||
|
||
设 v(I, a) 为在信息集 I 采取动作 a 后的期望收益,
|
||
v(I) = sum_a(σ(I,a) * v(I,a)) 为加权期望收益,
|
||
则反事实遗憾值:
|
||
regret(I, a) = v(I, a) - v(I)
|
||
|
||
当遍历者反复经历相同信息集时,累计 regret 越大的动作
|
||
越应该被优先选择——这就是 Regret Matching 的核心直觉。
|
||
"""
|
||
|
||
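# Illustrative sketch (not used by the trainer): Regret Matching converts
# accumulated positive regrets into a strategy. Kept dependency-free on
# purpose so it can live above the import block. For example, with regrets
# [4.0, -2.0, 1.0], the positive parts are [4.0, 0.0, 1.0] and the strategy
# is [0.8, 0.0, 0.2].
def _regret_matching_sketch(cumulative_regrets):
    """Toy regret matching: normalize positive regrets into a distribution."""
    positives = [max(r, 0.0) for r in cumulative_regrets]
    total = sum(positives)
    if total <= 0.0:
        # No positive regret anywhere: fall back to the uniform strategy.
        return [1.0 / len(cumulative_regrets)] * len(cumulative_regrets)
    return [p / total for p in positives]
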
# ── Must run before `import torch`: pin every C++ linear-algebra library to a
# single thread, preventing OpenMP deadlocks under the spawn start method ──
import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"

import random
import sys
import glob
import tempfile
import uuid
import multiprocessing as mp
from typing import List, Tuple, Dict, Any
from concurrent.futures import ProcessPoolExecutor, as_completed

import torch
import torch.nn.functional as F
import pyspiel

torch.set_num_threads(1)

# ── Add poker/ to sys.path so sibling modules can be imported ──
_POKER_DIR = os.path.dirname(os.path.abspath(__file__))
if _POKER_DIR not in sys.path:
    sys.path.insert(0, _POKER_DIR)

# ── Add the project root to sys.path so the card_model package can be imported ──
_PROJECT_DIR = os.path.dirname(_POKER_DIR)
if _PROJECT_DIR not in sys.path:
    sys.path.insert(0, _PROJECT_DIR)

from env_adapter import (
    HUNL_FULLGAME_STRING,
    BetTranslator,
    extract_env_state,
    CFR_ACTIONS,
    NUM_CFR_ACTIONS,
    STACK_NORMALIZE,
)
from card_model.config import NUM_BINS, PAD_TOKEN, BOARD_SIZE
from card_model.data_generator import extract_cards_from_state
from card_model.model import CardModel
from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS
from cfr_buffer import CFRBuffer

# ───────────────────── Full-scale hyperparameters (128 GB RAM + 24 GB VRAM, 24 cores) ─────────────────────

# ── Large-scale parallelism and iteration counts ──
NUM_ITERATIONS = 30000  # total training iterations (Deep CFR: many iterations, few games each)
GAMES_PER_ITER = 1000   # games traversed per iteration (reach the update phase quickly)
# NUM_WORKERS: leave 2 cores for the main process (GPU training + logging + OS)
# to avoid context-switch penalties. 24 cores - 2 = 22 workers, each running a
# pure CPU-bound tree traversal.
NUM_WORKERS = 22

# ── Giant replay buffer (put the 128 GB of RAM to work) ──
# One sample is a bit over 1 KB; 5 million samples take roughly 6-8 GB.
# 10 million makes full use of system memory and guards against catastrophic
# forgetting.
BUFFER_MAX_SIZE = 10_000_000
MIN_BUFFER_SIZE_FOR_TRAIN = 200_000  # start training once 200,000 samples have accumulated

# ── VRAM utilization (24 GB) ──
# The CFR network is tiny (3-layer MLP, ~100K parameters). Even at a batch
# size of 32768: forward activations ~33 MB, backward ~33 MB, optimizer state
# ~0.8 MB, under 100 MB in total, so the 16384 configured here puts no
# pressure on 24 GB of VRAM. Very large batches smooth strategy convergence
# and damp strategy oscillation.
TRAIN_BATCH_SIZE = 16384
TRAIN_STEPS_PER_ITER = 64  # network update steps after each iteration

# ── Optimizer details ──
LEARNING_RATE = 5e-4  # slightly lower LR for the larger batch, for stability
WEIGHT_DECAY = 1e-4   # guards against overfitting


# Small-scale debugging values (kept for reference):
# NUM_ITERATIONS = 100               # total training iterations
# GAMES_PER_ITER = 50                # self-play games per iteration
# TRAIN_BATCH_SIZE = 256             # training mini-batch size
# TRAIN_STEPS_PER_ITER = 20          # training steps per iteration
# MIN_BUFFER_SIZE_FOR_TRAIN = 1000   # minimum buffer size before training starts
# LEARNING_RATE = 1e-3               # AdamW learning rate
# WEIGHT_DECAY = 1e-4                # AdamW weight decay

# Normalization constants (matching the 5-dim env_features)
# STACK_NORMALIZE = 20000.0 is imported from env_adapter (aligned with the Botzone starting stack)
STREET_NORMALIZE = 3.0  # street normalization factor (0/3, 1/3, 2/3, 3/3)

CLIP_GRAD_NORM = 1.0  # gradient-clipping threshold, guards against training instability

# Card Model checkpoint path (set to None to use random initialization)
CARD_MODEL_CHECKPOINT = "card_model/data/best_card_model.pt"

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ── Checkpoint configuration ──
CHECKPOINT_DIR = os.path.join(_POKER_DIR, "checkpoints")
CHECKPOINT_INTERVAL = 50     # save every 50 iterations
KEEP_LAST_N_CHECKPOINTS = 3  # keep at most the 3 most recent local checkpoint files to avoid filling the disk

# ───────────────────── Helper functions ─────────────────────

def build_env_features(env_info: dict) -> torch.Tensor:
    """
    Normalize the dict returned by extract_env_state into a 5-dim env_features tensor.

    The 5 dimensions:
        [pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position]

    Args:
        env_info: dict returned by extract_env_state()

    Returns:
        FloatTensor of shape [5]
    """
    features = [
        env_info["pot"] / STACK_NORMALIZE,
        env_info["p0_stack"] / STACK_NORMALIZE,
        env_info["p1_stack"] / STACK_NORMALIZE,
        env_info["street"] / STREET_NORMALIZE,
        float(env_info["position"]),
    ]
    return torch.tensor(features, dtype=torch.float32)

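# Worked example (hypothetical numbers, illustration only): with
#   env_info = {"pot": 300, "p0_stack": 19850, "p1_stack": 19850,
#               "street": 1, "position": 0}
# build_env_features returns
#   [300/20000, 19850/20000, 19850/20000, 1/3.0, 0.0]
#   = [0.0150, 0.9925, 0.9925, 0.3333, 0.0]
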
# ── CardModel result cache ──
# Within the poker game tree, the board probabilities cannot change until a
# new card is dealt. CardModel outputs are therefore cached keyed by
# (hole_cards_tuple, board_cards_tuple), avoiding repeated inference within
# the same dealing stage (e.g. the flop).
CARD_CACHE: Dict[str, torch.Tensor] = {}


def build_card_features(
    card_model: CardModel,
    state,
) -> torch.Tensor:
    """
    Use the CardModel to extract the 50-dim win-rate histogram (card_features)
    from the current state.

    Dict-cached: as long as hole_cards and board_cards are unchanged, the
    cached result is returned directly. The main loop calls CARD_CACHE.clear()
    before each iteration to prevent a memory leak.

    Flow:
        1. Call extract_cards_from_state(state) to get hole_cards, board_cards
        2. Build the cache_key; on a hit, return immediately
        3. On a miss, build the input tensors and run inference (on the
           model's device)
        4. Return a FloatTensor of shape [50] (on CPU)

    Args:
        card_model: CardModel with loaded weights (eval mode)
        state: OpenSpiel State

    Returns:
        FloatTensor of shape [50] (detached from the graph, on CPU)
    """
    global CARD_CACHE

    hole_cards, board_cards = extract_cards_from_state(state)

    # Turn the card lists into a hashable key
    cache_key = f"{tuple(hole_cards)}_{tuple(board_cards)}"

    if cache_key in CARD_CACHE:
        return CARD_CACHE[cache_key]

    # Build the input tensors (on the card_model's current device)
    model_device = next(card_model.parameters()).device
    x_hole = torch.tensor([hole_cards], dtype=torch.int64, device=model_device)  # [1, 2]
    # Pad board_cards with PAD_TOKEN when fewer than 5 cards are out
    padded_board = board_cards + [PAD_TOKEN] * (BOARD_SIZE - len(board_cards))
    x_board = torch.tensor([padded_board], dtype=torch.int64, device=model_device)  # [1, 5]

    with torch.no_grad():
        _, pred_histogram = card_model(x_hole, x_board)  # [1, 50]

    # Detach from the graph, move to CPU, drop the batch dimension
    result = pred_histogram.squeeze(0).cpu()
    CARD_CACHE[cache_key] = result
    return result

# ───────────────────── Core: external-sampling traversal (single-process version, kept for compatibility) ─────────────────────

def traverse(
    state,
    traversing_player: int,
    card_model: CardModel,
    cfr_net: CFRNetwork,
    buffer: CFRBuffer,
    translator: BetTranslator,
    depth: int = 0,
) -> float:
    """
    Game-tree traversal for external-sampling MCCFR (single-process version).

    Recursively walks the game tree, computes the counterfactual regret of
    each action on the traverser's turns, and stores the experience
    [info_state, legal_mask, regrets, strategy] in the buffer.

    === Handling of the three node types ===

    1. Terminal node:
       return the traverser's payoff, state.returns()[traversing_player]

    2. Chance node (dealing):
       sample one dealing action by its probability and recurse

    3. Player node:
       extract the information-set features and query the CFRNetwork for the
       current strategy σ(I,·), then:

       a) Opponent's turn (current_player != traversing_player):
          sample 1 action from the σ(I,·) distribution, apply it, recurse
          → this is what "external sampling" means: following a single
            sampled path greatly reduces computation

       b) Traverser's turn (current_player == traversing_player):
          for every legal action a:
              clone the state, apply a, recurse to get v(I, a)
          weighted expected payoff: v(I) = Σ_a σ(I,a) · v(I,a)
          regrets:                  regret(I, a) = v(I,a) - v(I)
          store the experience in the buffer

    Args:
        state: OpenSpiel State object
        traversing_player: current traverser (0 or 1)
        card_model: Card Model (eval mode), extracts board features
        cfr_net: CFR policy network (eval mode), provides the current strategy
        buffer: CFR experience replay buffer
        translator: BetTranslator, converts CFR actions ↔ engine actions

    Returns:
        float: expected payoff of this node for traversing_player
    """

    # ── Depth circuit breaker: past the hard limit, truncate the game tree ──
    if depth >= 40:
        return 0.0

    # ── 1. Terminal node: return the payoff directly ──
    if state.is_terminal():
        return float(state.returns()[traversing_player])

    # ── 2. Chance node: random dealing ──
    if state.is_chance_node():
        outcomes = state.chance_outcomes()
        action_list, prob_list = zip(*outcomes)
        chance_action = random.choices(action_list, weights=prob_list, k=1)[0]
        child_state = state.clone()
        child_state.apply_action(chance_action)
        utility = traverse(child_state, traversing_player, card_model, cfr_net,
                           buffer, translator, depth + 1)
        del child_state
        return utility

    # ── 3. Player node ──
    current_player = state.current_player()

    env_info = extract_env_state(state)
    env_features = build_env_features(env_info)
    card_features = build_card_features(card_model, state)
    legal_mask = env_info["legal_mask"]

    info_state = torch.cat([card_features, env_features], dim=0)

    # Use whatever device cfr_net currently lives on (CPU during data generation)
    net_device = next(cfr_net.parameters()).device
    legal_mask_tensor = torch.tensor(
        [legal_mask], dtype=torch.float32, device=net_device
    )
    card_input = card_features.unsqueeze(0).to(net_device)
    env_input = env_features.unsqueeze(0).to(net_device)

    with torch.no_grad():
        current_strategy, _ = cfr_net.get_strategy(
            card_input, env_input, legal_mask_tensor
        )

    strategy_list = current_strategy.squeeze(0).cpu().tolist()

    # ── 3a. Opponent's turn ──
    if current_player != traversing_player:
        legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
        if not legal_indices:
            # Defensive fallback: treat action index 1 as legal
            legal_indices = [1]

        probs = [strategy_list[i] for i in legal_indices]
        prob_sum = sum(probs)
        if prob_sum > 0:
            probs = [p / prob_sum for p in probs]
        else:
            probs = [1.0 / len(legal_indices)] * len(legal_indices)

        sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0]
        engine_action = translator.cfr_to_engine_action(state, sampled_idx)
        child_state = state.clone()
        child_state.apply_action(engine_action)

        utility = traverse(child_state, traversing_player, card_model, cfr_net,
                           buffer, translator, depth + 1)
        del child_state
        return utility

    # ── 3b. Traverser's turn ──
    legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
    if not legal_indices:
        return 0.0

    action_utilities = [0.0] * NUM_CFR_ACTIONS

    for a in legal_indices:
        child_state = state.clone()
        engine_action = translator.cfr_to_engine_action(child_state, a)
        child_state.apply_action(engine_action)
        action_utilities[a] = traverse(
            child_state, traversing_player, card_model, cfr_net,
            buffer, translator, depth + 1
        )
        del child_state

    node_utility = 0.0
    for a in legal_indices:
        node_utility += strategy_list[a] * action_utilities[a]

    regrets = [0.0] * NUM_CFR_ACTIONS
    for a in legal_indices:
        regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE

    buffer.add(
        info_state=info_state,
        legal_mask=legal_mask,
        regrets=regrets,
        strategy=strategy_list,
    )

    return node_utility

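# Worked example of step 3b (illustrative numbers): suppose three legal
# actions with σ(I,·) = [0.5, 0.3, 0.2] and recursive payoffs
# v(I,a) = [+10.0, -4.0, +2.0]. Then
#     v(I) = 0.5*10.0 + 0.3*(-4.0) + 0.2*2.0 = 4.2
# and the unnormalized regrets are
#     v(I,a) - v(I) = [+5.8, -8.2, -2.2]
# which are divided by STACK_NORMALIZE before being stored in the buffer.
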
# ───────────────────── Multi-process worker infrastructure ─────────────────────
#
# Worker processes are initialized via _init_worker, where each creates its own:
#   - pyspiel game instance (cannot be serialized across processes)
#   - CFRNetwork on CPU (weights loaded from a state_dict sent by the main process)
#   - CardModel on CPU (weights loaded from a state_dict sent by the main process)
#   - BetTranslator (lightweight, stateless)
#
# Weight hot-update mechanism:
#   The model is tiny (< 500 KB), so weights are passed directly as arguments
#   to worker_traverse_batch. No temp-file handoff is needed, which eliminates
#   concurrent read/write conflicts and BrokenProcessPool risks entirely.
#
# IPC safety mechanisms:
#   - The spawn start method avoids deadlocks from fork inheriting the CUDA context
#   - Workers return experiences exclusively as plain Python types (list[float]/list[int])
#   - Each worker processes a batch of games, reducing IPC message counts and
#     preventing pipe congestion

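# Minimal sketch of the pool pattern used below; cfr_sd_cpu / card_sd_cpu are
# hypothetical names for CPU weight snapshots. The initializer seeds each
# worker once, while per-iteration weights travel as submit() arguments:
#
#     ctx = mp.get_context("spawn")
#     with ProcessPoolExecutor(max_workers=NUM_WORKERS,
#                              initializer=_init_worker,
#                              initargs=(cfr_sd_cpu, card_sd_cpu),
#                              mp_context=ctx) as pool:
#         fut = pool.submit(worker_traverse_batch, [0, 1, 2], cfr_sd_cpu)
#         game_results, exp_path = fut.result()
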
# Worker-process global state
_WORKER_STATE: Dict[str, Any] = {}


def _init_worker(
    cfr_state_dict: dict,
    card_state_dict: dict,
) -> None:
    """
    Worker initializer for the ProcessPoolExecutor.

    Called once when each worker process starts; creates the process-local:
        - OpenSpiel game instance
        - CFRNetwork on CPU (eval mode)
        - CardModel on CPU (eval mode)
        - BetTranslator

    Every tensor in the incoming state_dicts has already been moved with
    .cpu(), avoiding cross-process CUDA issues.

    Weight hot-update:
        Each iteration the main process passes the latest weights through the
        cfr_state_dict argument of worker_traverse_batch; the worker loads
        them via load_state_dict at the start of the batch, eliminating
        concurrent temp-file read/write conflicts entirely.

    Args:
        cfr_state_dict: state_dict of the main process's CFRNetwork (CPU tensors)
        card_state_dict: state_dict of the main process's CardModel (CPU tensors)
    """
    global _WORKER_STATE

    # Each worker creates its own OpenSpiel game instance (not transferable across processes)
    _WORKER_STATE["game"] = pyspiel.load_game(HUNL_FULLGAME_STRING)

    # Create the CFRNetwork on CPU and load the main process's latest weights
    cfr_net = CFRNetwork(
        card_dim=CARD_DIM,
        env_dim=ENV_DIM,
        num_actions=NUM_ACTIONS,
    )
    cfr_net.load_state_dict(cfr_state_dict)
    cfr_net.eval()
    _WORKER_STATE["cfr_net"] = cfr_net

    # Create the CardModel on CPU and load its weights. CardModel weights
    # never change during training, so loading once at initialization is enough.
    card_model = CardModel()
    card_model.load_state_dict(card_state_dict)
    card_model.eval()
    _WORKER_STATE["card_model"] = card_model

    # BetTranslator is stateless and can be created directly
    _WORKER_STATE["translator"] = BetTranslator()


def worker_traverse_batch(
    game_indices: List[int],
    cfr_state_dict: dict,
) -> Tuple[List[Tuple[int, float]], str]:
    """
    Batched MCCFR traversal, executed inside a worker process.

    One worker handles several games at a time (given by game_indices),
    drastically reducing IPC message counts and keeping the pipe from being
    flooded with tiny messages.

    Weights arrive directly as a function argument (the model is < 500 KB),
    so no temp-file handoff is needed, eliminating concurrent read/write
    conflicts and BrokenProcessPool risks.

    Experiences travel back through a temp file to avoid blowing up
    multiprocessing.Pipe serialization: the worker torch.save's the experience
    list to a temporary .pt file and returns only its path; the main process
    torch.load's it and deletes the file immediately.

    Args:
        game_indices: indices of the games to process (roughly GAMES_PER_ITER/NUM_WORKERS of them)
        cfr_state_dict: CFRNetwork weights for the current iteration (CPU tensors)

    Returns:
        game_results: [(game_idx, utility), ...] index and payoff per game
        exp_file_path: path of the experience temp file; the main process must
            torch.load it and then delete it
    """
    global _WORKER_STATE
    global CARD_CACHE

    # ── Clear this worker's card cache from the previous batch to kill the memory leak ──
    CARD_CACHE.clear()

    # ── Load the latest weights straight from the argument; no file handoff ──
    _WORKER_STATE["cfr_net"].load_state_dict(cfr_state_dict)

    game = _WORKER_STATE["game"]
    cfr_net = _WORKER_STATE["cfr_net"]
    card_model = _WORKER_STATE["card_model"]
    translator = _WORKER_STATE["translator"]

    game_results: List[Tuple[int, float]] = []
    all_experiences: List[Tuple] = []

    for game_idx in game_indices:
        traversing_player = game_idx % 2
        state = game.new_initial_state()

        experiences: List[Tuple] = []
        utility = _traverse_worker(
            state, traversing_player, card_model, cfr_net,
            experiences, translator,
        )

        game_results.append((game_idx, utility))
        all_experiences.extend(experiences)

        # print(f"Worker finished a game! (Game {game_idx})", flush=True)

    # ── Spill the experiences to disk to avoid IPC pipe serialization blow-up ──
    # Returning all_experiences directly would overwhelm multiprocessing.Pipe
    # (a single batch can serialize to hundreds of MB of Python lists),
    # triggering OS OOM and BrokenProcessPool. Instead: torch.save to a temp
    # file → return only the path string → the main process loads then deletes it.
    exp_file_path = os.path.join(
        tempfile.gettempdir(),
        f"cfr_exp_{uuid.uuid4().hex}.pt",
    )
    torch.save(all_experiences, exp_file_path)
    del all_experiences  # free worker memory immediately; avoid double occupancy with the file cache

    return game_results, exp_file_path


def _traverse_worker(
    state,
    traversing_player: int,
    card_model: CardModel,
    cfr_net: CFRNetwork,
    experiences: List[Tuple],
    translator: BetTranslator,
    depth: int = 0,
) -> float:
    """
    Recursive traversal inside a worker process.

    Identical in logic to traverse(); the only differences are that
    experiences are appended to the local experiences list (instead of the
    global CFRBuffer) and the stored info_state is a list[float] rather than
    a Tensor, so it can cross process boundaries.

    Memory-safety guarantees:
        - All tensor ops run under torch.no_grad()
        - After info_state_tensor.tolist(), the original tensor is GC-eligible
        - state.clone() objects created during recursion are GC'd once the
          recursion returns
        - The returned experiences contain only plain Python types, so the
          main process's torch.tensor() rebuilds carry no grad_fn and
          buffer.add() is safe

    Device note:
        All models stay on CPU inside workers. Tensor devices are derived
        from the model parameters (next(cfr_net.parameters()).device), so
        nothing needs to be passed explicitly.

    Args:
        state: OpenSpiel State
        traversing_player: current traverser
        card_model: CardModel (eval, CPU)
        cfr_net: CFRNetwork (eval, CPU)
        experiences: local experience staging list
        translator: BetTranslator

    Returns:
        float: expected payoff of this node for traversing_player
    """
    # ── Depth circuit breaker: past the hard limit, truncate the game tree ──
    if depth >= 40:
        return 0.0

    # ── 1. Terminal node ──
    if state.is_terminal():
        return float(state.returns()[traversing_player])

    # ── 2. Chance node ──
    if state.is_chance_node():
        outcomes = state.chance_outcomes()
        action_list, prob_list = zip(*outcomes)
        chance_action = random.choices(action_list, weights=prob_list, k=1)[0]
        child_state = state.clone()
        child_state.apply_action(chance_action)
        utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net,
                                   experiences, translator, depth + 1)
        del child_state
        return utility

    # ── 3. Player node ──
    current_player = state.current_player()

    env_info = extract_env_state(state)
    env_features = build_env_features(env_info)
    card_features = build_card_features(card_model, state)
    legal_mask = env_info["legal_mask"]

    info_state_tensor = torch.cat([card_features, env_features], dim=0)

    # Use cfr_net's current device (always CPU inside workers)
    net_device = next(cfr_net.parameters()).device
    legal_mask_tensor = torch.tensor(
        [legal_mask], dtype=torch.float32, device=net_device
    )
    card_input = card_features.unsqueeze(0).to(net_device)
    env_input = env_features.unsqueeze(0).to(net_device)

    with torch.no_grad():
        current_strategy, _ = cfr_net.get_strategy(
            card_input, env_input, legal_mask_tensor
        )

    strategy_list = current_strategy.squeeze(0).cpu().tolist()

    # ── 3a. Opponent's turn ──
    if current_player != traversing_player:
        legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
        if not legal_indices:
            # Defensive fallback: treat action index 1 as legal
            legal_indices = [1]

        probs = [strategy_list[i] for i in legal_indices]
        prob_sum = sum(probs)
        if prob_sum > 0:
            probs = [p / prob_sum for p in probs]
        else:
            probs = [1.0 / len(legal_indices)] * len(legal_indices)

        sampled_idx = random.choices(legal_indices, weights=probs, k=1)[0]
        engine_action = translator.cfr_to_engine_action(state, sampled_idx)
        child_state = state.clone()
        child_state.apply_action(engine_action)

        utility = _traverse_worker(child_state, traversing_player, card_model, cfr_net,
                                   experiences, translator, depth + 1)
        del child_state
        return utility

    # ── 3b. Traverser's turn ──
    legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
    if not legal_indices:
        return 0.0

    action_utilities = [0.0] * NUM_CFR_ACTIONS

    for a in legal_indices:
        child_state = state.clone()
        engine_action = translator.cfr_to_engine_action(child_state, a)
        child_state.apply_action(engine_action)
        action_utilities[a] = _traverse_worker(
            child_state, traversing_player, card_model, cfr_net,
            experiences, translator, depth + 1,
        )
        del child_state

    node_utility = 0.0
    for a in legal_indices:
        node_utility += strategy_list[a] * action_utilities[a]

    regrets = [0.0] * NUM_CFR_ACTIONS
    for a in legal_indices:
        regrets[a] = (action_utilities[a] - node_utility) / STACK_NORMALIZE

    # Append to the local experience list (info_state as list[float] for IPC).
    # After .tolist(), info_state_tensor is GC-eligible; no memory builds up.
    info_state_list = info_state_tensor.tolist()
    experiences.append((info_state_list, legal_mask, regrets, strategy_list))

    return node_utility


# ───────────────────── Checkpoint save & restore ─────────────────────

def save_checkpoint(
    iteration: int,
    cfr_net: CFRNetwork,
    buffer: CFRBuffer,
    optimizer,
    checkpoint_dir: str,
    keep_last_n: int,
) -> None:
    """
    Atomically save a checkpoint and prune old files.

    Atomic write flow:
        1. Write to the temporary file ckpt_tmp.pt first
        2. On success, os.replace it to ckpt_iter_{iteration}.pt
        3. Write a latest_ckpt.pt copy as the default restore entry point
        4. Prune checkpoints beyond keep_last_n

    Args:
        iteration: current iteration number
        cfr_net: CFR policy network
        buffer: CFR experience replay buffer
        optimizer: AdamW optimizer (may be None)
        checkpoint_dir: checkpoint directory
        keep_last_n: number of recent checkpoints to keep
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Assemble the save dict
    save_dict = {
        "iteration": iteration,
        "model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()},
        "buffer_state_dict": buffer.state_dict(),
    }
    if optimizer is not None:
        # Optimizer-state tensors may live on the GPU; move them to CPU
        opt_state = optimizer.state_dict()
        opt_state_cpu = {}
        for k, v in opt_state.items():
            if isinstance(v, torch.Tensor):
                opt_state_cpu[k] = v.cpu()
            elif isinstance(v, dict):
                opt_state_cpu[k] = _recursive_tensor_to_cpu(v)
            else:
                opt_state_cpu[k] = v
        save_dict["optimizer_state_dict"] = opt_state_cpu

    # Atomic save: write a temp file first, then rename
    ckpt_path = os.path.join(checkpoint_dir, f"ckpt_iter_{iteration}.pt")
    tmp_path = os.path.join(checkpoint_dir, "ckpt_tmp.pt")

    torch.save(save_dict, tmp_path)
    os.replace(tmp_path, ckpt_path)

    # Update latest_ckpt.pt (a copy rather than a symlink, for portability)
    latest_path = os.path.join(checkpoint_dir, "latest_ckpt.pt")
    # os.replace is atomic: write a temp file first, then swap it in
    tmp_latest = os.path.join(checkpoint_dir, "latest_ckpt_tmp.pt")
    torch.save(save_dict, tmp_latest)
    os.replace(tmp_latest, latest_path)

    print(f"[Checkpoint] saved: iter={iteration}, buffer={len(buffer)}")

    # Prune old checkpoints (keep the most recent keep_last_n)
    _cleanup_old_checkpoints(checkpoint_dir, keep_last_n)

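# The write-then-rename idiom above, factored into a standalone sketch
# (illustrative helper, not called by the trainer): os.replace is atomic
# within one filesystem, so readers only ever see the old file or the
# complete new one, never a torn write.
def _atomic_torch_save_sketch(obj, final_path: str) -> None:
    """Save obj to final_path via a same-directory temp file + os.replace."""
    tmp_path = final_path + ".tmp"
    torch.save(obj, tmp_path)  # a crash mid-write leaves final_path untouched
    os.replace(tmp_path, final_path)  # atomic swap on the same filesystem
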
def _recursive_tensor_to_cpu(d: dict) -> dict:
    """Recursively move every Tensor in a dict to CPU."""
    out = {}
    for k, v in d.items():
        if isinstance(v, torch.Tensor):
            out[k] = v.cpu()
        elif isinstance(v, dict):
            out[k] = _recursive_tensor_to_cpu(v)
        elif isinstance(v, list):
            out[k] = [
                item.cpu() if isinstance(item, torch.Tensor) else item
                for item in v
            ]
        else:
            out[k] = v
    return out


def _cleanup_old_checkpoints(checkpoint_dir: str, keep_last_n: int) -> None:
    """Delete old checkpoint files, keeping only the most recent keep_last_n."""
    pattern = os.path.join(checkpoint_dir, "ckpt_iter_*.pt")
    files = glob.glob(pattern)
    if len(files) <= keep_last_n:
        return

    # Sort by the iteration number embedded in the filename
    def _extract_iter(path: str) -> int:
        basename = os.path.basename(path)
        # ckpt_iter_123.pt -> 123
        num = basename.replace("ckpt_iter_", "").replace(".pt", "")
        try:
            return int(num)
        except ValueError:
            return -1

    files_with_iter = [(f, _extract_iter(f)) for f in files if _extract_iter(f) >= 0]
    files_with_iter.sort(key=lambda x: x[1])

    # Delete the surplus oldest files
    to_delete = files_with_iter[:-keep_last_n]
    for path, it in to_delete:
        os.remove(path)
        print(f"[Checkpoint] removed old checkpoint: iter={it}")


def _optimizer_state_to_device(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    """Move every Tensor in the optimizer state to the given device.

    PyTorch's optimizer.load_state_dict() does not automatically move tensors
    to the device the parameters live on; that must be done by hand. This is
    essential when restoring a CPU checkpoint onto the GPU.

    Args:
        optimizer: PyTorch optimizer
        device: target device
    """
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)


def load_checkpoint(
    checkpoint_path: str,
    cfr_net: CFRNetwork,
    buffer: CFRBuffer,
) -> dict:
    """
    Restore training state from a checkpoint file.

    Args:
        checkpoint_path: path of the checkpoint file
        cfr_net: CFR policy network (weights will be loaded into it)
        buffer: CFR experience replay buffer (data will be loaded into it)

    Returns:
        dict: contains iteration and, if present, optimizer_state_dict
    """
    print(f"[Checkpoint] loading: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Restore model weights
    cfr_net.load_state_dict(ckpt["model_state_dict"])

    # Restore the buffer
    buffer.load_state_dict(ckpt["buffer_state_dict"])

    result = {
        "iteration": ckpt.get("iteration", 0),
    }
    if "optimizer_state_dict" in ckpt:
        result["optimizer_state_dict"] = ckpt["optimizer_state_dict"]

    print(f"[Checkpoint] restored: iter={result['iteration']}, buffer={len(buffer)}")
    return result


# ───────────────────── One training step ─────────────────────

def train_step(
    cfr_net: CFRNetwork,
    buffer: CFRBuffer,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float, float]:
    """
    Sample one mini-batch from the buffer and train the CFRNetwork for one step.

    Loss composition:
        1. Regret loss: MSE(predicted regrets, target regrets), computed only
           where legal_mask == 1
        2. Policy loss: MSE(predicted strategy, target strategy), fitting the
           strategy distribution stored in the buffer

    Total loss = regret loss + policy loss

    Args:
        cfr_net: CFR policy network
        buffer: CFR experience replay buffer
        optimizer: AdamW optimizer
        device: device to run on

    Returns:
        (total_loss, regret_loss, policy_loss): three float scalars
    """
    cfr_net.train()

    info_states, legal_masks, target_regrets, target_strategies = buffer.sample(
        TRAIN_BATCH_SIZE
    )

    info_states = info_states.to(device)
    legal_masks = legal_masks.to(device)
    target_regrets = target_regrets.to(device)
    target_strategies = target_strategies.to(device)

    card_features = info_states[:, :CARD_DIM]
    env_features = info_states[:, CARD_DIM:]

    pred_regrets, pred_policy_logits = cfr_net(card_features, env_features)

    # ── Loss 1: regret loss ──
    regret_diff = (pred_regrets - target_regrets) ** 2
    regret_loss = (regret_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0)

    # ── Loss 2: policy loss ──
    masked_logits = pred_policy_logits.masked_fill(legal_masks == 0, float("-inf"))
    pred_strategy = F.softmax(masked_logits, dim=-1)

    # Rows with no legal action would softmax over all -inf logits to NaN;
    # replace such rows with a (zero-contribution) uniform fallback
    all_illegal = (legal_masks.sum(dim=-1, keepdim=True) == 0)
    num_legal = legal_masks.sum(dim=-1, keepdim=True).clamp(min=1.0)
    uniform = legal_masks / num_legal
    pred_strategy = torch.where(all_illegal, uniform, pred_strategy)

    policy_diff = (pred_strategy - target_strategies) ** 2
    policy_loss = (policy_diff * legal_masks).sum() / legal_masks.sum().clamp(min=1.0)

    total_loss = regret_loss + policy_loss

    optimizer.zero_grad()
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(cfr_net.parameters(), max_norm=CLIP_GRAD_NORM)
    optimizer.step()

    return (
        total_loss.item(),
        regret_loss.item(),
        policy_loss.item(),
    )

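# Worked example of the masked regret loss (illustrative numbers): for a
# single sample with legal_mask = [1, 1, 0], pred_regrets = [0.2, -0.1, 0.9]
# and target_regrets = [0.5, 0.0, 0.0]:
#     masked squared diffs = [(0.2-0.5)^2, (-0.1-0.0)^2, ignored]
#                          = [0.09, 0.01, -]
#     regret_loss = (0.09 + 0.01) / 2 = 0.05
# The illegal slot contributes nothing, and the division is by the count of
# legal entries, not by the full action dimension.
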
# ───────────────────── Main training loop ─────────────────────

def main():
    """
    MCCFR training main loop (multi-process version).

    Each iteration has two phases:
        Phase A (data generation), multi-process:
            - Runs concurrently via ProcessPoolExecutor(mp_context=spawn)
            - The pool is reused across iterations; fresh weights are passed
              to workers directly as function arguments
            - Each worker handles a batch of games, reducing IPC messages
            - The main process merges all worker experiences into the global CFRBuffer

        Phase B (network training), single-process GPU:
            - Runs mini-batch training once the buffer holds enough data
            - Updates the Regret Head and the Policy Head

    Logging:
        - Prints buffer size and loss trends every iteration
    """
    # ── Force the spawn start method to prevent fork/CUDA-context deadlocks ──
    # Linux defaults to fork; a forked child inherits the parent's CUDA
    # context, but the CUDA runtime's internal lock state is inconsistent, so
    # any subsequent CUDA call deadlocks permanently. spawn boots a fresh
    # Python interpreter, fully isolating CUDA state. Must be called inside
    # the `if __name__ == "__main__"` guard, before any CUDA work.
    mp.set_start_method('spawn', force=True)

    print("=" * 70)
    print(" MCCFR Self-Play Training Pipeline (Multi-Process)")
    print(" External Sampling + Deep CFR")
    print(f" NUM_WORKERS = {NUM_WORKERS}")
    print(f" Start method = spawn")
    print("=" * 70)

    # ── 1. Initialize the OpenSpiel environment ──
    game = pyspiel.load_game(HUNL_FULLGAME_STRING)
    translator = BetTranslator()
    print(f"\n[Env] game loaded: {HUNL_FULLGAME_STRING}")

    # ── 2. Initialize the Card Model (kept on CPU; data generation infers on CPU) ──
    card_model = CardModel()  # starts on CPU
    if CARD_MODEL_CHECKPOINT and os.path.isfile(CARD_MODEL_CHECKPOINT):
        ckpt = torch.load(CARD_MODEL_CHECKPOINT, map_location="cpu", weights_only=False)
        card_model.load_state_dict(ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt)
        print(f"[CardModel] weights loaded: {CARD_MODEL_CHECKPOINT}")
    else:
        print(f"[CardModel] using random initialization (no checkpoint provided)")
    card_model.eval()

    # ── 3. Initialize the CFR Network (kept on CPU; data generation infers on CPU) ──
    cfr_net = CFRNetwork(
        card_dim=CARD_DIM,
        env_dim=ENV_DIM,
        num_actions=NUM_ACTIONS,
    )  # starts on CPU
    total_params = sum(p.numel() for p in cfr_net.parameters())
    print(f"[CFRNetwork] parameters: {total_params:,}, initial device: CPU")

    # ── 4. Initialize the buffer ──
    buffer = CFRBuffer(max_size=BUFFER_MAX_SIZE)
    print(f"[Buffer] capacity: {BUFFER_MAX_SIZE:,}")

    # The optimizer is created lazily at the first Phase B (once cfr_net has
    # moved to the GPU), unless a checkpoint restored below carries an
    # optimizer_state_dict, in which case it is initialized early.
    optimizer = None

    # ── Checkpoint restore ──
    # If latest_ckpt.pt exists, resume from it automatically
    start_iteration = 1
    latest_ckpt_path = os.path.join(CHECKPOINT_DIR, "latest_ckpt.pt")
    if os.path.isfile(latest_ckpt_path):
        ckpt_info = load_checkpoint(latest_ckpt_path, cfr_net, buffer)
        start_iteration = ckpt_info["iteration"] + 1

        # If the checkpoint carries an optimizer state, restore it now.
        # Reason: the optimizer's momentum buffers are bound to the parameter
        # objects, so they must be loaded after cfr_net moves to the GPU and
        # before any training.
        if "optimizer_state_dict" in ckpt_info:
            # Move the network to the GPU first (the optimizer's parameters
            # must be on the right device)
            cfr_net.to(DEVICE)
            optimizer = torch.optim.AdamW(
                cfr_net.parameters(),
                lr=LEARNING_RATE,
                weight_decay=WEIGHT_DECAY,
            )
            # Load the optimizer state (may contain GPU tensors; map to the right device)
            optimizer.load_state_dict(ckpt_info["optimizer_state_dict"])
            # Move every tensor in the optimizer state onto the parameters' device
            _optimizer_state_to_device(optimizer, DEVICE)
            # Move the network back to CPU until Phase A needs it
            cfr_net.to("cpu")
            print(f"[Optimizer] state restored from checkpoint, device={DEVICE}")

        print(f"[Checkpoint] resuming from iteration {start_iteration}")
    else:
        print("[Checkpoint] none found, training from scratch")

    # ── 5. Initial CPU state_dicts (for the workers' first initialization) ──
    # Must be moved to CPU before the pool is created, to avoid shipping CUDA
    # tensors across processes
    cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()}
    card_state_dict_cpu = {k: v.cpu() for k, v in card_model.state_dict().items()}

    # ── 6. Create the process pool (reused across iterations to avoid re-init cost) ──
    # mp_context=spawn guarantees children do not inherit the CUDA context; an
    # explicit spawn context is used for the ProcessPoolExecutor in addition
    # to the global set_start_method above.
    spawn_ctx = mp.get_context('spawn')

    print(f"\n{'='*70}")
    print(f" Training: {NUM_ITERATIONS} iterations, "
          f"{GAMES_PER_ITER} self-play games per iter, "
          f"{NUM_WORKERS} workers in parallel")
    print(f"{'='*70}")

    # Distribute GAMES_PER_ITER games evenly over NUM_WORKERS workers; each
    # worker handles one batch of games, reducing IPC message counts.
    game_indices_all = list(range(GAMES_PER_ITER))
    # Compute each worker's share of game indices (round-robin)
    chunks: List[List[int]] = [[] for _ in range(NUM_WORKERS)]
    for i, gi in enumerate(game_indices_all):
        chunks[i % NUM_WORKERS].append(gi)

    with ProcessPoolExecutor(
        max_workers=NUM_WORKERS,
        initializer=_init_worker,
        initargs=(cfr_state_dict_cpu, card_state_dict_cpu),
        mp_context=spawn_ctx,
    ) as executor:

        for iteration in range(start_iteration, NUM_ITERATIONS + 1):

            # ──────────── Phase A: data generation (multi-process) ────────────

            # Dynamic device switch: during data generation the models sit on
            # the CPU, avoiding batch-size-1 GPU PCIe latency
            card_model.to("cpu")
            cfr_net.to("cpu")
            cfr_net.eval()

            # Clear the CardModel cache to prevent a memory leak
            CARD_CACHE.clear()

            # Snapshot the current network weights; they go to workers
            # directly as a function argument (no temp files needed)
            cfr_state_dict_cpu = {k: v.cpu() for k, v in cfr_net.state_dict().items()}

            # Submit batched tasks (one chunk of games per worker, weights
            # passed directly as arguments)
            futures = {}
            for worker_id, chunk in enumerate(chunks):
                if not chunk:
                    continue
                future = executor.submit(worker_traverse_batch, chunk, cfr_state_dict_cpu)
                futures[future] = worker_id

            # Collect results and merge experiences
            game_results = []
            total_experiences = 0

            for future in as_completed(futures):
                worker_id = futures[future]
                try:
                    worker_game_results, exp_file_path = future.result()
                except Exception as e:
                    print(f"[Warning] worker {worker_id} failed: {e}")
                    continue

                # Record per-game utilities
                for game_idx, utility in worker_game_results:
                    traversing_player = game_idx % 2
                    game_results.append((traversing_player, utility))

                # Load the experiences from the temp file (avoids IPC pipe blow-up)
                try:
                    experiences = torch.load(exp_file_path, map_location="cpu", weights_only=False)
                except Exception as e:
                    print(f"[Warning] failed to load experience file from worker {worker_id}: {e}")
                    continue
                finally:
                    # Whether or not the load succeeded, delete the temp file
                    # immediately so /tmp does not fill up
                    if os.path.isfile(exp_file_path):
                        os.remove(exp_file_path)

                # Write the worker's experiences into the global buffer
                for info_state_list, legal_mask, regrets, strategy in experiences:
                    buffer.add(
                        info_state=info_state_list,
                        legal_mask=legal_mask,
                        regrets=regrets,
                        strategy=strategy,
                    )

                total_experiences += len(experiences)
                del experiences  # free immediately, limiting the main process's memory peak

            # Summarize this iteration's data generation
            p0_utils = [u for p, u in game_results if p == 0]
            p1_utils = [u for p, u in game_results if p == 1]
            avg_p0 = sum(p0_utils) / len(p0_utils) if p0_utils else 0.0
            avg_p1 = sum(p1_utils) / len(p1_utils) if p1_utils else 0.0

            # ──────────── Phase B: network training (GPU) ────────────

            # Dynamic device switch: move cfr_net to the GPU for efficient
            # large-batch training
            cfr_net.to(DEVICE)

            # Create the optimizer on first entering the training phase
            # (parameters are on the GPU by now)
            if optimizer is None:
                optimizer = torch.optim.AdamW(
                    cfr_net.parameters(),
                    lr=LEARNING_RATE,
                    weight_decay=WEIGHT_DECAY,
                )
                print(f"[Optimizer] AdamW lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}, device={DEVICE}")

            avg_total_loss = 0.0
            avg_regret_loss = 0.0
            avg_policy_loss = 0.0
            trained_steps = 0

            if len(buffer) >= MIN_BUFFER_SIZE_FOR_TRAIN:
                for step in range(TRAIN_STEPS_PER_ITER):
                    total_loss, regret_loss, policy_loss = train_step(
                        cfr_net, buffer, optimizer, DEVICE
                    )
                    avg_total_loss += total_loss
                    avg_regret_loss += regret_loss
                    avg_policy_loss += policy_loss
                    trained_steps += 1

                avg_total_loss /= trained_steps
                avg_regret_loss /= trained_steps
                avg_policy_loss /= trained_steps

            # ──────────── Logging ────────────

            buffer_size = len(buffer)
            loss_str = (
                f"Total={avg_total_loss:.6f} "
                f"Regret={avg_regret_loss:.6f} "
                f"Policy={avg_policy_loss:.6f}"
                if trained_steps > 0
                else "training not started (buffer below threshold)"
            )

            print(
                f"[Iter {iteration:3d}/{NUM_ITERATIONS}] "
                f"Buffer={buffer_size:6d} "
                f"Exp={total_experiences:6d} "
                f"P0_util={avg_p0:+.1f} P1_util={avg_p1:+.1f} "
                f"Loss: {loss_str}"
            )

            # ──────────── Periodic checkpointing ────────────

            if iteration % CHECKPOINT_INTERVAL == 0:
                # The network is on the GPU at save time; save_checkpoint
                # moves the weights to CPU
                save_checkpoint(
                    iteration=iteration,
                    cfr_net=cfr_net,
                    buffer=buffer,
                    optimizer=optimizer,
                    checkpoint_dir=CHECKPOINT_DIR,
                    keep_last_n=KEEP_LAST_N_CHECKPOINTS,
                )

    # ── 8. Training done: save the final checkpoint ──
    save_checkpoint(
        iteration=NUM_ITERATIONS,
        cfr_net=cfr_net,
        buffer=buffer,
        optimizer=optimizer,
        checkpoint_dir=CHECKPOINT_DIR,
        keep_last_n=KEEP_LAST_N_CHECKPOINTS,
    )
    # Also save a copy at the package root (backward compatibility)
    save_path = os.path.join(_POKER_DIR, "cfr_net_checkpoint.pt")
    torch.save({
        "model_state_dict": {k: v.cpu() for k, v in cfr_net.state_dict().items()},
        "optimizer_state_dict": optimizer.state_dict() if optimizer else {},
        "iteration": NUM_ITERATIONS,
        "buffer_size": len(buffer),
    }, save_path)
    print(f"[Save] model saved to: {save_path}")

    print(f"\n{'='*70}")
    print(f" Training complete: {NUM_ITERATIONS} iterations")
    print(f" Final buffer size: {len(buffer)}")
    print(f"{'='*70}")


# ───────────────────── Entry point ─────────────────────

if __name__ == "__main__":
    main()