Files
new/CODE_STRUCTURE.md
e2hang ed2fadb625 What
2026-04-20 20:25:35 +08:00

20 KiB
Raw Permalink Blame History

项目代码结构文档

概述

这是一个基于 Deep CFR (Deep CounterfactualRegret Minimization) 的德州扑克AI系统使用PyTorch和OpenSpiel实现。系统分为两大模块

  1. Card Model - 牌面胜率预测模型
  2. CFR Trainer - 策略训练器基于MCCFR算法

一、文件清单与功能说明

1.1 根目录核心文件

文件 功能
cfr_buffer.py CFR经验回放池存储自对弈产生的训练数据
cfr_net.py CFR策略网络双头MLPRegret Head + Policy Head
env_adapter.py OpenSpiel环境适配器动作转换与状态提取
mccfr_trainer.py MCCFR训练主流程自对弈+网络训练
train_card_model.py Card Model训练入口包装脚本

1.2 card_model 包

文件 功能
__init__.py 包入口,导出主要接口
config.py 配置常量(牌数、维度、路径等)
model.py CardModel神经网络双头输出scalar + histogram
data_generator.py Monte Carlo数据生成器
dataset.py PyTorch Dataset封装
train_card_model.py Card Model训练脚本

二、各文件详细说明

2.1 cfr_buffer.py

功能CFR经验回放池存储自我博弈过程中收集的训练数据。

核心数据结构

  • 使用 deque(maxlen=N) 实现固定容量的FIFO滑动窗口
  • 存储4类数据info_state、legal_mask、regrets、strategy

存储内容

- info_state: 信息集特征向量55维 = 牌面50维 + 局势5维
- legal_mask: 6个动作的合法性掩码 [0/1]
- regrets: 累计遗憾值用于训练Regret网络
- strategy: 平均策略概率分布用于训练Policy网络

主要方法

  • add(info_state, legal_mask, regrets, strategy) - 添加一条经验
  • sample(batch_size) - 随机采样一个batch
  • clear() - 清空缓冲区

数据流:被 mccfr_trainer.pytraverse() 函数填充,被 train_step() 函数消费。


2.2 cfr_net.py

功能Deep CFR的策略网络接受信息集特征输出6个动作的遗憾值和平均策略。

网络结构

输入: [card_features(50) + env_features(5)] = 55维
  ↓
MLP骨干: concat_dim(55) → 256 → 256 → 128
  ↓
双输出头:
  - regret_head: 6维无激活regret可为负数
  - policy_head: 6维Softmax输出概率分布

核心方法

  • forward(card_features, env_features)(regrets, policy_logits)
  • get_strategy(card_features, env_features, legal_mask)(current_strategy, avg_strategy)

Regret Matching算法

  1. 将负regret截断为0positive_regret = relu(regret)
  2. 只保留合法动作:masked_regret = positive_regret * legal_mask
  3. 归一化得到即时策略
  4. 若所有regret为0则使用均匀分布

数据流

  • 输入:来自 card_model 的50维胜率直方图 + env_adapter 提取的5维局势特征
  • 输出6维动作策略驱动 env_adapter 的动作选择

2.3 env_adapter.py

功能OpenSpiel引擎与CFR系统之间的桥梁负责

  1. 动作离散化将引擎的连续下注空间映射为6个CFR标准动作
  2. 状态提取:从引擎状态中提取局势特征
  3. 下注映射将CFR动作转换为引擎可执行的动作ID

CFR动作定义对应索引0-5

0: FOLD      - 弃牌
1: CALL      - 跟注/过牌
2: MIN_RAISE - 最小加注
3: HALF_POT  - 半池加注(目标 = max_contribution + 0.5 × pot
4: FULL_POT  - 满池加注(目标 = max_contribution + 1.0 × pot
5: ALL_IN    - 全押

局势特征提取5维

[pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position]
  • pot: 当前底池(两玩家贡献额之和)
  • p0_stack/p1_stack: 双方剩余筹码
  • street: 当前轮次0=Preflop, 1=Flop, 2=Turn, 3=River
  • position: 当前行动玩家0或1

关键类

  • BetTranslator: 负责CFR动作 ↔ 引擎动作的双向转换
    • get_cfr_legal_mask(state) → 6维掩码
    • cfr_to_engine_action(state, cfr_action_idx) → 引擎动作ID
  • extract_env_state(state) → 局势字典

Fallback机制:当筹码不足无法执行目标加注时,自动降级到最大合法动作,保证不崩溃。


2.4 mccfr_trainer.py

功能MCCFR训练主流程协调各模块完成自对弈数据生成和网络训练。

核心流程

主循环 for iteration in 1~N:
  ├── 阶段A: 数据生成
  │   └── for game in 1~M:
  │       traverse() 遍历博弈树
  │       → 填充 CFRBuffer
  │
  └── 阶段B: 网络训练
      └── for step in 1~K:
          train_step() 从Buffer采样训练网络

外部采样MCCFR

  • 对方回合根据当前策略采样1个动作大减少计算量
  • 遍历者回合遍历所有合法动作精确计算每个动作的regret

traverse()函数

  • 递归遍历博弈树
  • 在遍历者节点计算反事实遗憾:regret(I,a) = v(I,a) - v(I)
  • [info_state, legal_mask, regrets, strategy] 存入Buffer

train_step()函数

  • 从Buffer采样mini-batch
  • 计算两个Loss
    • Regret Loss: MSE(预测regret, 目标regret)
    • Policy Loss: MSE(预测策略, 目标策略)
  • 总Loss = Regret Loss + Policy Loss

依赖模块

  • env_adapter - 状态提取和动作转换
  • card_model - 牌面特征提取
  • cfr_net - 策略网络
  • cfr_buffer - 经验池

2.5 train_card_model.py根目录

功能Card Model训练的入口脚本包装 card_model/train_card_model.py

作用:确保 poker/ 目录在sys.path中使导入路径正确。

from card_model.train_card_model import main
main()

2.6 card_model/init.py

功能:包入口,导出主要接口供外部使用。

from .config import *
from .data_generator import extract_cards_from_state, generate_sample
from .dataset import PokerCardDataset
from .model import CardModel

2.7 card_model/config.py

功能:项目配置常量定义。

关键常量

# 牌面编码
NUM_CARDS = 52       # 标准52张牌
PAD_TOKEN = 52      # 填充token
VOCAB_SIZE = 53      # 52 + 1 padding

# 游戏配置对应Botzone规则
HUNL_GAME_STRING = "universal_poker(...)"  # 2人无限德州

# 数据生成
NUM_ROLLOUTS = 1000   # 每样本MC rollout次数
NUM_BINS = 50        # 胜率直方图bin数

# 模型架构
EMBEDDING_DIM = 32
MLP_HIDDEN = [128, 128, 64]

# 训练超参数
NUM_TRAIN_SAMPLES = 2000000  # 200万训练样本
BATCH_SIZE = 16384
LEARNING_RATE = 5e-4
NUM_EPOCHS = 64

2.8 card_model/model.py

功能CardModel神经网络预测牌面胜率和胜率分布。

网络架构

输入:
  - x_hole: [batch, 2]  玩家手牌 IDs (0-51)
  - x_board: [batch, 5]  公共牌 IDs (不足5则用PAD填充)

Embedding层: 53 tokens → 32维
  ↓
Hole Embedding: sum → [batch, 32]
Board Embedding: sum → [batch, 32]
  ↓
拼接: [batch, 64]
  ↓
MLP骨干: 64 → 128 → 128 → 64
  ↓
双输出头:
  - equity_head: Sigmoid → scalar equity (0~1)
  - histogram_head: Softmax → 50维分布

输出

  • pred_equity: 预测胜率(标量)
  • pred_histogram: 胜率直方图50维概率分布

<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>

  • 输入:env_adapter 提取的手牌+公共牌IDs
  • 输出50维胜率直方图 → 作为 cfr_net 的card_features输入

训练目标

  • 最小化 EMD loss直方图之间的Wasserstein距离
  • 最小化 MSE loss胜率预测误差

2.9 card_model/data_generator.py

功能使用Monte Carlo方法生成牌面训练数据。

核心函数

  1. extract_cards_from_state(state)

    • 从OpenSpiel状态提取玩家手牌和公共牌
    • 返回 (hole_cards, board_cards) 列表
  2. _sample_random_state(game)

    • 随机采样一个游戏状态Preflop/Flop/Turn
    • 使用随机发牌 + check/call推进
  3. _monte_carlo_rollout(game, hole_cards, board_cards, used_cards)

    • 随机完成剩余公共牌和对手手牌
    • 返回 1.0(赢)/0.5(平)/0.0(输)
  4. generate_sample(game=None)

    • 生成单个训练样本:
      • x_hole: 手牌IDs [2]
      • x_board: 公共牌IDs [5]padding填充
      • y_equity: 平均胜率float
      • y_histogram: 50维直方图归一化

数据流:被 dataset.py 调用,生成训练/验证数据。


2.10 card_model/dataset.py

功能PyTorch Dataset封装支持多进程并行生成数据和磁盘缓存。

PokerCardDataset类

# 数据结构
x_hole:     [num_samples, 2]      int64   # 手牌IDs
x_board:    [num_samples, 5]      int64   # 公共牌IDspadding
y_equity:   [num_samples, 1]      float32 # 胜率
y_histogram:[num_samples, 50]     float32 # 直方图分布

多进程生成

  • 使用 multiprocessing.Pool 并行调用 generate_sample()
  • num_workers 参数控制并行度

磁盘缓存

  • 首次生成后保存为 .npz 文件
  • 下次运行直接加载,避免重复生成

2.11 card_model/train_card_model.py

功能Card Model训练脚本。

训练流程

for epoch in 1~NUM_EPOCHS:
    # 训练一个epoch
    train_one_epoch(model, train_loader, optimizer)
    # 验证
    validate(model, val_loader)
    # 保存最佳模型
    if val_loss < best_val_loss:
        save best_card_model.pt

Loss函数

emd = emd_loss_1d(pred_histogram, target_histogram)
mse = MSE(pred_equity, target_equity)
loss = emd + lambda * mse

EMD LossEarth Mover's Distance

  • 对于1D有序分布EMD = CDF之间L1距离的平均值
  • cdf_pred = cumsum(pred_histogram)
  • emd = mean(|cdf_pred - cdf_target|)

输出

  • best_card_model.pt - 验证集最佳模型
  • final_card_model.pt - 最终模型

三、文件依赖关系图

┌─────────────────────────────────────────────────────────────────────────────┐
│                           mccfr_trainer.py                                │
│                         (主训练流程,协调者)                             │
└─────────────────────────────────────┬─────────────────────────────────────┘
                                      │
          ┌───────────────────────────┼───────────────────────────┐
          │                           │                           │
          ▼                           ▼                           ▼
┌─────────────────────┐   ┌─────────────────────┐   ┌────────────────────<E29480><E29480><EFBFBD>┐
│    env_adapter.py   │   │    card_model/     │   │    cfr_buffer.py  │
│                     │   │   (牌面特征提取)   │   │   (经验回放池)     │
│ - 状态提取          │   │                     │   │                   │
│ - 动作转换         │◄──┤  model.py         │───►│  添加/采样数据      │
│ - BetTranslator    │   │  data_generator.py│   │                   │
└─────────────────────┘   │  dataset.py      │   └─────────────────────┘
                           │               │
                           │  train_card   │
                           │  _model.py  │
                           └─────────────┘
                                      │
                                      ▼
┌─────────────────────┐   ┌─────────────────────┐   ┌─────────────────────┐
│     cfr_net.py    │   │    card_model/    │   │    card_model/    │
│   (策略网络)      │   │    config.py    │   │    dataset.py   │
│                  │   │               │   │               │
│ - get_strategy   │   │  配置常量      │   │  数据集封装    │
│ - forward       │   │               │   │               │
└─────────────────────┘   └─────────────────────┘   └─────────────────────┘

四、数据流向图

4.1 Card Model训练数据流

OpenSpiel引擎
     │
     ▼
data_generator.generate_sample()
  ├── _sample_random_state() → 随机游戏状态
  ├── _monte_carlo_rollout() × 1000 → 胜率列表
  └── 返回: (x_hole, x_board, y_equity, y_histogram)
     │
     ▼
dataset.PokerCardDataset
  └── DataLoader → mini-batch
     │
     ▼
model.CardModel (forward)
  ├── embedding + MLP
  ├── equity_head → pred_equity
  └── histogram_head → pred_histogram
     │
     ▼
Loss = EMD + λ*MSE → 反向传播 → 更新权重

4.2 CFR训练数据流

OpenSpiel引擎 (new_initial_state())
     │
     ▼
env_adapter.extract_env_state()
  ├── pot, stacks, street, position
  └── legal_mask
     │
     ▼
card_model.model (eval模式)
  ├── extract_cards_from_state() → hole_cards, board_cards
  └── forward() → 50维 histogram
     │
     ▼
拼接: card_features(50) + env_features(5) = 55维 info_state
     │
     ▼
cfr_net.get_strategy(info_state, legal_mask)
  ├── Regret Matching → current_strategy
  └── Softmax → avg_strategy
     │
     ▼
如果当前玩家=遍历者:
  │  遍历所有合法动作计算regret
  │  buffer.add(info_state, legal_mask, regrets, strategy)
  │
  如果当前玩家≠对方:
       按策略采样1个动作外部采样
     │
     ▼
state.apply_action(engine_action)
     │
     ▼
递归 traverse() → 直到 terminal state
     │
     ▼
returns = state.returns() → 收益

4.3 完整Pipeline

┌──────────────────────────────────────────────────────────────────────────────────────┐
│                      训练 pipeline 流程                   │
├────────────────────────────────────────────────<E29480><E29480><EFBFBD><E29480>───┤
│                                                      │
│  1. 初始化                                          │
│     ├── 加载 OpenSpiel 游戏                              │
│     ├── 加载/初始化 CardModel (eval模式)                   │
│     ├── 加载/初始化 CFRNetwork (train模式)                  │
│     ├── 初始化 CFRBuffer                               │
│     └── 初始化 Optimizer                            │
│                                                      │
│  2. 迭代循环 (50000次)                               │
│     │                                             │
│     ├── 阶段A: 数据生成 (200局/iter)                  │
│     │   ├── traverse() 递归遍历博弈树               │
│     │   │   ├── 提取 info_state (card+env特征)       │
│     │   │   ├── CardModel 预测胜率直方图           │
│     │   │   ├── CFRNetwork 计算策略                 │
│     │   │   ├── 遍历/采样动作                    │
│     │   │   └── 计算 regret存入 Buffer         │
│     │   └── 累积经验数据                          │
│     │                                             │
│     └── 阶段B: 网络训练 (50步/iter)                  │
│         └── train_step()                              │
│             ├── Buffer.sample(batch_size=32768)         │
│             ├── 计算 Regret Loss + Policy Loss        │
│             └── 反向传播更新 CFRNetwork          │
│                                                      │
│  3. 保存模型                                       │
│     └── cfr_net_checkpoint.pt                        │
└──────────────────────────────────────────────────────┘

五、模块交互接口

5.1 env_adapter → cfr_net

# env_adapter.py 输出
env_info = extract_env_state(state)
# {
#   'pot': 500,
#   'p0_stack': 19500,
#   'p1_stack': 19500,
#   'street': 1,
#   'position': 0,
#   'legal_mask': [1, 1, 1, 0, 0, 0]
# }

# 构建 env_features
env_features = [
    pot / 20000,
    p0_stack / 20000,
    p1_stack / 20000,
    street / 3.0,
    position
]  # 5维

5.2 card_model → cfr_net

# card_model 输出一 50维胜率直方图
_, pred_histogram = card_model(x_hole, x_board)  # [1, 50]

# cfr_net 输入 = 拼接
info_state = [card_features(50) + env_features(5)]  # 55维

5.3 cfr_net → env_adapter

# cfr_net 输出
current_strategy = net.get_strategy(card_features, env_features, legal_mask)
# [1, 6] 概率分布

# 采样动作
action_idx = random.choices(legal_indices, weights=current_strategy)[0]

# env_adapter 转换
engine_action = translator.cfr_to_engine_action(state, action_idx)

# 执行
state.apply_action(engine_action)

5.4 cfr_net ↔ cfr_buffer

# traverse() 存入
buffer.add(info_state, legal_mask, regrets, strategy)

# train_step() 取出
info_states, legal_masks, target_regrets, target_strategies = buffer.sample(32768)

六、游戏规则与配置对齐

6.1 Botzone规则

- 玩家数: 2人 (Heads-up)
- 初始筹码: 20000
- 盲注: SB=50, BB=100
- 轮次: Preflop → Flop(3张) → Turn(1张) → River(1张)
- 下注模式: No-Limit

6.2 OpenSpiel配置

HUNL_FULLGAME_STRING = (
    "universal_poker("
    "betting=nolimit,"
    "numPlayers=2,"
    "numRanks=13,"
    "numSuits=4,"
    "numHoleCards=2,"
    "numRounds=4,"
    "numBoardCards=0 3 1 1,"  # Preflop/Flop/Turn/River
    "stack=20000 20000,"
    "blind=50 100,"
    "bettingAbstraction=fullgame)"
)

七、关键超参数汇总

7.1 Card Model

参数 说明
EMBEDDING_DIM 32 卡牌embedding维度
MLP_HIDDEN [128,128,64] MLP隐藏层
NUM_BINS 50 胜率直方图bin数
NUM_ROLLOUTS 1000 每样本MC rollout次数
BATCH_SIZE 16384 训练batch size
NUM_TRAIN_SAMPLES 2,000,000 训练样本数
NUM_EPOCHS 64 训练轮数
LEARNING_RATE 5e-4 学习率

7.2 CFR Trainer

参数 说明
CARD_DIM 50 Card Model输出维度
ENV_DIM 5 局势特征维度
NUM_ACTIONS 6 CFR动作数
MLP_HIDDEN [256,256,128] CFR网络隐藏层
NUM_ITERATIONS 50,000 训练迭代次数
GAMES_PER_ITER 200 每迭代对局数
BUFFER_MAX_SIZE 10,000,000 经验池容量
TRAIN_BATCH_SIZE 32,768 训练batch size
TRAIN_STEPS_PER_ITER 50 每迭代训练步数
LEARNING_RATE 5e-4 学习率

八、训练产物

8.1 Card Model

card_model/data/
  ├── train_data.npz       # 训练数据缓存
  ├── val_data.npz       # 验证数据缓存
  ├── best_card_model.pt # 最佳验证模型
  └── final_card_model.pt # 最终模型

8.2 CFR Trainer

cfr_net_checkpoint.pt  # CFR网络检查点

九、总结

这是一个典型的模块化AI系统设计

  1. 数据驱动Card Model学习牌面→胜率的映射提供特征表示
  2. 算法核心MCCFR迭代计算 regret学习博弈策略
  3. 工程桥接env_adapter连接引擎和算法处理动作空间离散化
  4. 存储中转cfr_buffer作为数据中转站解耦生成和训练

各模块职责清晰,接口明确,通过张量传递形成完整的数据闭环。