446 lines
16 KiB
Python
446 lines
16 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
eval_elo.py — Deep CFR 模型对抗评估脚本
|
||
|
||
让两个不同轮次的 checkpoint 模型 (Model A / Model B) 进行 100,000 局对战,
|
||
计算 Model A 对 Model B 的百手赢率 (bb/100)。
|
||
|
||
=== 核心设计 ===
|
||
|
||
1. 多进程并发: ProcessPoolExecutor(mp_context='spawn'), 22 Workers
|
||
2. 绝对公平: 50,000 局 Model A 做 P0(小盲), 50,000 局 Model A 做 P1(大盲)
|
||
3. 严格推理: 使用 avg_strategy + legal_mask 归一化 + numpy.random.choice 采样
|
||
4. bb/100 指标: (总筹码收益 / BB) / (总局数 / 100)
|
||
|
||
用法:
|
||
python eval_elo.py --model_a ckpt_10000.pt --model_b ckpt_5000.pt
|
||
python eval_elo.py --model_a checkpoints/ckpt_iter_10000.pt --model_b checkpoints/ckpt_iter_5000.pt --num_games 10000
|
||
"""
|
||
|
||
# ── This must run before `import torch`: pin the native linear-algebra thread
#    pools (OpenMP / MKL / OpenBLAS) to one thread each. The original note says
#    this prevents OpenMP deadlocks when workers are started with `spawn`. ──
import os

os.environ.update(
    {
        "OMP_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
        "OPENBLAS_NUM_THREADS": "1",
    }
)
|
||
|
||
import argparse
|
||
import random
|
||
import multiprocessing as mp
|
||
from typing import Dict, List, Tuple
|
||
|
||
import numpy as np
|
||
import torch
|
||
import pyspiel
|
||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||
|
||
# Cap PyTorch's own intra-op thread pool at one thread as well; parallelism in
# this script comes from the process pool in main(), not from intra-op threads.
torch.set_num_threads(1)
|
||
|
||
|
||
# ── Make the poker/ root (the directory holding this script) importable, so the
#    local modules env_adapter, card_model and cfr_net resolve regardless of the
#    current working directory. ──
import sys

_POKER_DIR = os.path.abspath(os.path.dirname(__file__))

if _POKER_DIR not in sys.path:
    # Prepend rather than append, so these local modules win over any
    # same-named modules found later on the path.
    sys.path.insert(0, _POKER_DIR)
|
||
|
||
from env_adapter import (
|
||
HUNL_SB_BB_STRING,
|
||
BetTranslator,
|
||
extract_env_state,
|
||
CFR_ACTIONS,
|
||
NUM_CFR_ACTIONS,
|
||
STACK_NORMALIZE,
|
||
)
|
||
from card_model.config import PAD_TOKEN, BOARD_SIZE
|
||
from card_model.data_generator import extract_cards_from_state
|
||
from card_model.model import CardModel
|
||
from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS
|
||
|
||
|
||
# ───────────────────── Constants ─────────────────────

# Divisor applied to env_info["street"] in build_env_features
# (presumably street indices run 0..3, preflop..river — confirm in env_adapter).
STREET_NORMALIZE = 3.0

# Chips per big blind; this is the unit of the reported bb/100.
BIG_BLIND = 100

# Default number of parallel worker processes (overridable via --num_workers).
NUM_WORKERS = 22

# Default CardModel weights, located relative to the poker/ root.
CARD_MODEL_CHECKPOINT = os.path.join(
    _POKER_DIR, "card_model", "data", "best_card_model.pt"
)

# ───────────────────── Per-worker global state ─────────────────────

# Populated once per worker process by _init_worker (game, both CFR networks,
# the CardModel and the BetTranslator) and read by worker_play_games.
_WORKER_STATE: Dict = {}
|
||
|
||
|
||
# ───────────────────── 特征构建 ─────────────────────
|
||
|
||
def build_env_features(env_info: dict) -> torch.Tensor:
    """Turn the dict produced by ``extract_env_state`` into a 5-element feature tensor.

    Layout (float32): ``[pot, p0_stack, p1_stack, street, position]``.
    Chip quantities are divided by ``STACK_NORMALIZE``, the street by
    ``STREET_NORMALIZE``, and the position is passed through as a float.
    """
    chip_scale = STACK_NORMALIZE
    normalized = (
        env_info["pot"] / chip_scale,
        env_info["p0_stack"] / chip_scale,
        env_info["p1_stack"] / chip_scale,
        env_info["street"] / STREET_NORMALIZE,
        float(env_info["position"]),
    )
    return torch.tensor(normalized, dtype=torch.float32)
|
||
|
||
|
||
def build_card_features(card_model: CardModel, state) -> torch.Tensor:
    """Encode the acting player's cards with CardModel and return its win-rate histogram.

    The hole cards are always read for ``state.current_player()``, the player
    about to act, so the features can never be built from the opponent's hand.

    Args:
        card_model: CardModel whose call returns a pair; the second element is
            the histogram (50 bins per the original design).
        state: OpenSpiel state at a decision node.

    Returns:
        The histogram for this single example, batch dimension squeezed out,
        on CPU.
    """
    acting_seat = state.current_player()
    hole_cards, board_cards = extract_cards_from_state(state, player_id=acting_seat)

    device = next(card_model.parameters()).device

    # Right-pad the board with PAD_TOKEN up to BOARD_SIZE so every street
    # yields a fixed-length board row.
    board_row = board_cards + [PAD_TOKEN] * (BOARD_SIZE - len(board_cards))

    hole_batch = torch.tensor([hole_cards], dtype=torch.int64, device=device)  # [1, 2]
    board_batch = torch.tensor([board_row], dtype=torch.int64, device=device)  # [1, 5]

    with torch.no_grad():
        _, histogram = card_model(hole_batch, board_batch)  # [1, 50]

    return histogram.squeeze(0).cpu()
|
||
|
||
|
||
# ───────────────────── Worker 初始化 ─────────────────────
|
||
|
||
def _build_eval_cfr_net(state_dict: dict) -> CFRNetwork:
    """Build a CPU CFRNetwork, load ``state_dict`` into it and put it in eval mode."""
    net = CFRNetwork(card_dim=CARD_DIM, env_dim=ENV_DIM, num_actions=NUM_ACTIONS)
    net.load_state_dict(state_dict)
    net.eval()
    return net


def _init_worker(
    model_a_state_dict: dict,
    model_b_state_dict: dict,
    card_state_dict: dict,
) -> None:
    """Per-process initializer for the ProcessPoolExecutor.

    Runs once when each worker starts and stores the following in the
    module-level ``_WORKER_STATE`` dict:

    - ``game``: this worker's own OpenSpiel game (standard seating, P0=SB, P1=BB)
    - ``cfr_net_a`` / ``cfr_net_b``: CPU CFRNetworks for Model A / Model B, eval mode
    - ``card_model``: CPU CardModel in eval mode, shared by both sides
    - ``translator``: a BetTranslator

    All weights arrive as CPU state_dicts, so no CUDA tensors cross process
    boundaries.
    """
    # The OpenSpiel game object cannot be passed between processes, so each
    # worker loads its own instance. P0 = SB (50), P1 = BB (100).
    game = pyspiel.load_game(HUNL_SB_BB_STRING)

    net_a = _build_eval_cfr_net(model_a_state_dict)
    net_b = _build_eval_cfr_net(model_b_state_dict)

    # One CardModel serves both players; the two CFR nets share the same
    # card encoder.
    card_model = CardModel()
    card_model.load_state_dict(card_state_dict)
    card_model.eval()

    # Mutating (not rebinding) the module-level dict, so no `global` is needed.
    _WORKER_STATE.update(
        game=game,
        cfr_net_a=net_a,
        cfr_net_b=net_b,
        card_model=card_model,
        translator=BetTranslator(),  # described as stateless in the original
    )
|
||
|
||
|
||
# ───────────────────── 单步推理 ─────────────────────
|
||
|
||
def _masked_action_distribution(avg_strat: np.ndarray, legal_mask: list) -> np.ndarray:
    """Restrict ``avg_strat`` to legal actions and renormalize it into a valid distribution.

    Non-finite (NaN/±inf) and negative entries are treated as zero mass, so the
    result is always acceptable to ``np.random.choice``. Fallback order:
    renormalized masked strategy, then uniform over legal actions, then all
    mass on index 1 (CALL) when the mask has no legal action at all.
    """
    mask = np.asarray(legal_mask, dtype=np.float64)
    probs = np.nan_to_num(
        np.asarray(avg_strat, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0
    )
    # Without the clip, a negative entry would make np.random.choice raise
    # and abort the whole worker task.
    probs = np.clip(probs, 0.0, None) * mask

    total = probs.sum()
    if total > 1e-9:
        return probs / total

    # Extreme fallback: uniform over the legal actions.
    num_legal = int(mask.sum())
    if num_legal > 0:
        return mask / num_legal

    # Last resort: CALL.
    probs = np.zeros_like(mask)
    probs[1] = 1.0
    return probs


def _choose_action(
    cfr_net: CFRNetwork,
    card_model: CardModel,
    translator: BetTranslator,
    state,
) -> int:
    """Sample an action from the network's average strategy and return the engine action ID.

    Rules:
      1. Card features come from the acting player's own cards
         (``build_card_features`` reads ``state.current_player()``).
      2. ``get_strategy()`` returns (current_strategy, avg_strategy); only the
         second one, the average strategy, is used.
      3. No argmax: the average strategy is multiplied by the legal mask,
         renormalized, and sampled with ``numpy.random.choice``.
      4. The sampled CFR action index is mapped to the engine's native action
         via ``BetTranslator.cfr_to_engine_action()``.
    """
    env_info = extract_env_state(state, translator)
    env_features = build_env_features(env_info)  # [5]
    card_features = build_card_features(card_model, state)  # [50]
    legal_mask = env_info["legal_mask"]  # list[int], length 5

    # Add the batch dimension expected by the network.
    card_input = card_features.unsqueeze(0)  # [1, 50]
    env_input = env_features.unsqueeze(0)  # [1, 5]
    legal_mask_tensor = torch.tensor([legal_mask], dtype=torch.float32)  # [1, 5]

    with torch.no_grad():
        _, avg_strategy = cfr_net.get_strategy(card_input, env_input, legal_mask_tensor)

    probs = _masked_action_distribution(
        avg_strategy.squeeze(0).cpu().numpy(), legal_mask
    )
    cfr_action_idx = int(np.random.choice(NUM_CFR_ACTIONS, p=probs))

    return translator.cfr_to_engine_action(state, cfr_action_idx)
|
||
|
||
|
||
# ───────────────────── Worker 对战函数 ─────────────────────
|
||
|
||
def worker_play_games(
    num_games: int,
    model_a_is_p0: bool,
) -> float:
    """Play ``num_games`` hands inside a worker and return Model A's total chip result.

    The game always uses standard seating (P0 = SB, P1 = BB). Positions are
    swapped by reordering which network sits in which seat:

    - ``model_a_is_p0=True``:  Model A = P0 (SB), Model B = P1 (BB)
    - ``model_a_is_p0=False``: Model A = P1 (BB), Model B = P0 (SB)

    OpenSpiel's ``state.returns()`` is indexed by seat (``[0]`` is P0,
    ``[1]`` is P1), so Model A's result is read from the seat it occupies.
    """
    net_a = _WORKER_STATE["cfr_net_a"]
    net_b = _WORKER_STATE["cfr_net_b"]
    card_model = _WORKER_STATE["card_model"]
    translator = _WORKER_STATE["translator"]
    game = _WORKER_STATE["game"]

    # seat index -> network acting in that seat
    seat_nets = (net_a, net_b) if model_a_is_p0 else (net_b, net_a)
    a_seat = 0 if model_a_is_p0 else 1

    chips_for_a = 0.0
    for _hand in range(num_games):
        state = game.new_initial_state()

        while not state.is_terminal():
            if state.is_chance_node():
                # Deal cards by sampling the chance distribution.
                actions, probs = zip(*state.chance_outcomes())
                state.apply_action(random.choices(actions, weights=probs, k=1)[0])
            else:
                # Decision node: the network in the acting seat chooses.
                seat = state.current_player()
                state.apply_action(
                    _choose_action(seat_nets[seat], card_model, translator, state)
                )

        chips_for_a += state.returns()[a_seat]

    return chips_for_a
|
||
|
||
|
||
# ───────────────────── 加载模型权重 ─────────────────────
|
||
|
||
def load_cfr_state_dict(checkpoint_path: str) -> dict:
    """Load a CFRNetwork state_dict from ``checkpoint_path`` with all tensors on CPU.

    Supported checkpoint layouts:
      - training archive: ``{"model_state_dict": ..., "optimizer_state_dict": ..., "iteration": ...}``
      - bare state_dict (e.g. as exported by export_model.py)
      - a pickled ``torch.nn.Module`` (its ``state_dict()`` is used)

    Returns:
        A dict mapping parameter names to CPU tensors. Any non-tensor values
        are passed through unchanged.

    Raises:
        TypeError: if the file contains none of the supported layouts.

    Security: ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    objects, which can execute code embedded in the file. Only load
    checkpoints you trust.
    """
    print(f"[加载] 正在读取: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
        iter_info = ckpt.get("iteration", "未知")
        print(f"[加载] 训练存档格式, iteration={iter_info}")
    elif isinstance(ckpt, torch.nn.Module):
        # A whole module was saved; previously this crashed on `.items()`.
        state_dict = ckpt.state_dict()
        print("[加载] 完整 nn.Module 格式")
    elif isinstance(ckpt, dict):
        state_dict = ckpt
        print("[加载] 纯模型权重格式")
    else:
        raise TypeError(
            f"不支持的 checkpoint 格式 ({type(ckpt).__name__}): {checkpoint_path}"
        )

    # Make sure every tensor lives on CPU so it can be shipped to spawn workers.
    return {
        k: v.cpu() if isinstance(v, torch.Tensor) else v
        for k, v in state_dict.items()
    }
|
||
|
||
|
||
def load_card_state_dict(checkpoint_path: str) -> dict:
    """Load the CardModel state_dict from ``checkpoint_path`` with all tensors on CPU.

    Accepts a training archive containing ``"model_state_dict"``, a bare
    state_dict, or a pickled ``torch.nn.Module``.

    Raises:
        TypeError: if the file contains none of the supported layouts.

    Security: ``weights_only=False`` unpickles arbitrary objects. Only load
    trusted checkpoints.
    """
    print(f"[加载] CardModel: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    elif isinstance(ckpt, torch.nn.Module):
        state_dict = ckpt.state_dict()
    elif isinstance(ckpt, dict):
        state_dict = ckpt
    else:
        raise TypeError(
            f"不支持的 CardModel checkpoint 格式 ({type(ckpt).__name__}): {checkpoint_path}"
        )

    # CPU-only tensors so the weights can be pickled into spawn workers.
    return {
        k: v.cpu() if isinstance(v, torch.Tensor) else v
        for k, v in state_dict.items()
    }
|
||
|
||
|
||
# ───────────────────── 主函数 ─────────────────────
|
||
|
||
def _build_tasks(games_per_side: int, num_workers: int) -> List[Tuple[int, bool]]:
    """Split ``games_per_side`` hands per seat into ``(num_games, model_a_is_p0)`` tasks.

    Each seat is cut into up to ``num_workers`` chunks whose sizes differ by at
    most one; empty chunks are dropped. Both seats are split identically, and
    all Model-A-as-P0 tasks come first.
    """
    base, extra = divmod(games_per_side, num_workers)
    sizes = [base + (1 if i < extra else 0) for i in range(num_workers)]
    tasks: List[Tuple[int, bool]] = []
    for model_a_is_p0 in (True, False):
        tasks.extend((n, model_a_is_p0) for n in sizes if n > 0)
    return tasks


def main():
    """CLI entry point: pit Model A against Model B and report Model A's bb/100.

    Half of the hands seat Model A as P0 (SB) and half as P1 (BB). The hands
    are spread over a spawn-based process pool. Only hands from tasks that
    finish successfully count toward bb/100, and a seat imbalance caused by
    failed tasks is reported. Exits with status 1 if no hand completes.
    """
    parser = argparse.ArgumentParser(
        description="Deep CFR 模型对抗评估 — 计算 Model A 对 Model B 的 bb/100"
    )
    parser.add_argument(
        "--model_a", type=str, required=True,
        help="Model A 的 checkpoint 路径"
    )
    parser.add_argument(
        "--model_b", type=str, required=True,
        help="Model B 的 checkpoint 路径"
    )
    parser.add_argument(
        "--num_games", type=int, default=100_000,
        help="总对战局数 (默认 100,000)"
    )
    parser.add_argument(
        "--num_workers", type=int, default=NUM_WORKERS,
        help=f"并行 Worker 数 (默认 {NUM_WORKERS})"
    )
    parser.add_argument(
        "--card_model", type=str, default=CARD_MODEL_CHECKPOINT,
        help=f"CardModel 权重路径 (默认 {CARD_MODEL_CHECKPOINT})"
    )
    args = parser.parse_args()

    # Guard against inputs that previously ended in ZeroDivisionError: each seat
    # needs at least one hand, and the work split needs at least one worker.
    if args.num_games < 2:
        parser.error("--num_games 必须 >= 2 (两个座位各至少 1 局)")
    if args.num_workers < 1:
        parser.error("--num_workers 必须 >= 1")

    total_games = args.num_games
    num_workers = args.num_workers

    # ── 1. Load all weights in the parent process, on CPU ──
    model_a_state_dict = load_cfr_state_dict(args.model_a)
    model_b_state_dict = load_cfr_state_dict(args.model_b)
    card_state_dict = load_card_state_dict(args.card_model)

    # ── 2. Split work: half the hands with A as P0, half with A as P1 ──
    tasks = _build_tasks(total_games // 2, num_workers)

    planned_total = sum(n for n, _ in tasks)
    planned_a_as_p0 = sum(n for n, is_p0 in tasks if is_p0)
    planned_a_as_p1 = planned_total - planned_a_as_p0

    # Spawning (and loading models into) more processes than tasks is wasted work.
    pool_size = min(num_workers, len(tasks))

    print(f"\n{'='*70}")
    print(f" Deep CFR 模型对抗评估")
    print(f" Model A: {args.model_a}")
    print(f" Model B: {args.model_b}")
    print(f" 总局数: {planned_total} (A做P0: {planned_a_as_p0}, A做P1: {planned_a_as_p1})")
    print(f" Workers: {pool_size} | BB = {BIG_BLIND} 筹码")
    print(f"{'='*70}\n")

    # ── 3. Play the hands in a spawn-based process pool ──
    spawn_ctx = mp.get_context('spawn')

    total_chips_won_by_a = 0.0
    games_done = {True: 0, False: 0}  # completed hands, keyed by model_a_is_p0
    completed = 0
    failed = 0

    with ProcessPoolExecutor(
        max_workers=pool_size,
        initializer=_init_worker,
        initargs=(model_a_state_dict, model_b_state_dict, card_state_dict),
        mp_context=spawn_ctx,
    ) as executor:
        future_to_task = {
            executor.submit(worker_play_games, n, is_p0): (idx, n, is_p0)
            for idx, (n, is_p0) in enumerate(tasks)
        }

        # as_completed gives real progress instead of blocking on task 0 first.
        for future in as_completed(future_to_task):
            task_idx, num_games_task, model_a_is_p0 = future_to_task[future]
            try:
                chips = future.result()
            except Exception as e:
                # Skip the failed task, and keep its hands out of the denominator.
                failed += 1
                print(f" [警告] 任务 {task_idx} 执行失败: {e}")
                continue

            total_chips_won_by_a += chips
            games_done[model_a_is_p0] += num_games_task
            completed += 1

            position_str = "A=P0(SB)" if model_a_is_p0 else "A=P1(BB)"
            if completed % 5 == 0 or completed + failed == len(tasks):
                print(f" [进度] {completed}/{len(tasks)} 任务完成 "
                      f"({position_str}, {num_games_task}局, "
                      f"Model A 筹码: {chips:+.0f})")

    games_a_as_p0 = games_done[True]
    games_a_as_p1 = games_done[False]
    games_played = games_a_as_p0 + games_a_as_p1

    if failed:
        print(f" [警告] {failed}/{len(tasks)} 个任务失败, 仅统计已完成的 {games_played} 局")
    if games_played == 0:
        print(" [错误] 没有任何对局成功完成, 无法计算 bb/100")
        raise SystemExit(1)
    if games_a_as_p0 != games_a_as_p1:
        # Failed tasks can break the seat symmetry that cancels positional advantage.
        print(f" [警告] 座位不平衡 (A做P0: {games_a_as_p0}, A做P1: {games_a_as_p1}), "
              f"bb/100 可能含位置偏差")

    # ── 4. bb/100 = (chips won / BB) / (hands played / 100) ──
    bb_per_100 = (total_chips_won_by_a / BIG_BLIND) / (games_played / 100)

    # ── 5. Report ──
    print(f"\n{'='*70}")
    print(f" 评估完成!")
    print(f" 总局数: {games_played}")
    print(f" Model A 总筹码收益: {total_chips_won_by_a:+.0f}")
    print(f" 经过 {games_played:,} 局对抗,Model A 对 Model B 的百手赢率为: "
          f"{bb_per_100:+.1f} bb/100")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()
|