Files
new/CODE_STRUCTURE.md
e2hang ed2fadb625 What
2026-04-20 20:25:35 +08:00

653 lines
20 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 项目代码结构文档
## 概述
这是一个基于 Deep CFR (Deep CounterfactualRegret Minimization) 的德州扑克AI系统使用PyTorch和OpenSpiel实现。系统分为两大模块
1. **Card Model** - 牌面胜率预测模型
2. **CFR Trainer** - 策略训练器基于MCCFR算法
---
## 一、文件清单与功能说明
### 1.1 根目录核心文件
| 文件 | 功能 |
|------|------|
| `cfr_buffer.py` | CFR经验回放池存储自对弈产生的训练数据 |
| `cfr_net.py` | CFR策略网络双头MLPRegret Head + Policy Head |
| `env_adapter.py` | OpenSpiel环境适配器动作转换与状态提取 |
| `mccfr_trainer.py` | MCCFR训练主流程自对弈+网络训练 |
| `train_card_model.py` | Card Model训练入口包装脚本 |
### 1.2 card_model 包
| 文件 | 功能 |
|------|------|
| `__init__.py` | 包入口,导出主要接口 |
| `config.py` | 配置常量(牌数、维度、路径等) |
| `model.py` | CardModel神经网络双头输出scalar + histogram |
| `data_generator.py` | Monte Carlo数据生成器 |
| `dataset.py` | PyTorch Dataset封装 |
| `train_card_model.py` | Card Model训练脚本 |
---
## 二、各文件详细说明
### 2.1 cfr_buffer.py
**功能**CFR经验回放池存储自我博弈过程中收集的训练数据。
**核心数据结构**
- 使用 `deque(maxlen=N)` 实现固定容量的FIFO滑动窗口
- 存储4类数据info_state、legal_mask、regrets、strategy
**存储内容**
```
- info_state: 信息集特征向量55维 = 牌面50维 + 局势5维
- legal_mask: 6个动作的合法性掩码 [0/1]
- regrets: 累计遗憾值用于训练Regret网络
- strategy: 平均策略概率分布用于训练Policy网络
```
**主要方法**
- `add(info_state, legal_mask, regrets, strategy)` - 添加一条经验
- `sample(batch_size)` - 随机采样一个batch
- `clear()` - 清空缓冲区
**数据流**:被 `mccfr_trainer.py``traverse()` 函数填充,被 `train_step()` 函数消费。
---
### 2.2 cfr_net.py
**功能**Deep CFR的策略网络接受信息集特征输出6个动作的遗憾值和平均策略。
**网络结构**
```
输入: [card_features(50) + env_features(5)] = 55维
MLP骨干: concat_dim(55) → 256 → 256 → 128
双输出头:
- regret_head: 6维无激活regret可为负数
- policy_head: 6维Softmax输出概率分布
```
**核心方法**
- `forward(card_features, env_features)``(regrets, policy_logits)`
- `get_strategy(card_features, env_features, legal_mask)``(current_strategy, avg_strategy)`
**Regret Matching算法**
1. 将负regret截断为0`positive_regret = relu(regret)`
2. 只保留合法动作:`masked_regret = positive_regret * legal_mask`
3. 归一化得到即时策略
4. 若所有regret为0则使用均匀分布
**数据流**
- 输入:来自 `card_model` 的50维胜率直方图 + `env_adapter` 提取的5维局势特征
- 输出6维动作策略驱动 `env_adapter` 的动作选择
---
### 2.3 env_adapter.py
**功能**OpenSpiel引擎与CFR系统之间的桥梁负责
1. **动作离散化**将引擎的连续下注空间映射为6个CFR标准动作
2. **状态提取**:从引擎状态中提取局势特征
3. **下注映射**将CFR动作转换为引擎可执行的动作ID
**CFR动作定义**对应索引0-5
```
0: FOLD - 弃牌
1: CALL - 跟注/过牌
2: MIN_RAISE - 最小加注
3: HALF_POT - 半池加注(目标 = max_contribution + 0.5 × pot
4: FULL_POT - 满池加注(目标 = max_contribution + 1.0 × pot
5: ALL_IN - 全押
```
**局势特征提取**5维
```
[pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position]
```
- pot: 当前底池(两玩家贡献额之和)
- p0_stack/p1_stack: 双方剩余筹码
- street: 当前轮次0=Preflop, 1=Flop, 2=Turn, 3=River
- position: 当前行动玩家0或1
**关键类**
- `BetTranslator`: 负责CFR动作 ↔ 引擎动作的双向转换
- `get_cfr_legal_mask(state)` → 6维掩码
- `cfr_to_engine_action(state, cfr_action_idx)` → 引擎动作ID
- `extract_env_state(state)` → 局势字典
**Fallback机制**:当筹码不足无法执行目标加注时,自动降级到最大合法动作,保证不崩溃。
---
### 2.4 mccfr_trainer.py
**功能**MCCFR训练主流程协调各模块完成自对弈数据生成和网络训练。
**核心流程**
```
主循环 for iteration in 1~N:
├── 阶段A: 数据生成
│ └── for game in 1~M:
│ traverse() 遍历博弈树
│ → 填充 CFRBuffer
└── 阶段B: 网络训练
└── for step in 1~K:
train_step() 从Buffer采样训练网络
```
**外部采样MCCFR**
- 对方回合根据当前策略采样1个动作大减少计算量
- 遍历者回合遍历所有合法动作精确计算每个动作的regret
**traverse()函数**
- 递归遍历博弈树
- 在遍历者节点计算反事实遗憾:`regret(I,a) = v(I,a) - v(I)`
-`[info_state, legal_mask, regrets, strategy]` 存入Buffer
**train_step()函数**
- 从Buffer采样mini-batch
- 计算两个Loss
- Regret Loss: MSE(预测regret, 目标regret)
- Policy Loss: MSE(预测策略, 目标策略)
- 总Loss = Regret Loss + Policy Loss
**依赖模块**
- `env_adapter` - 状态提取和动作转换
- `card_model` - 牌面特征提取
- `cfr_net` - 策略网络
- `cfr_buffer` - 经验池
---
### 2.5 train_card_model.py根目录
**功能**Card Model训练的入口脚本包装 `card_model/train_card_model.py`
**作用**:确保 `poker/` 目录在sys.path中使导入路径正确。
```python
from card_model.train_card_model import main
main()
```
---
### 2.6 card_model/__init__.py
**功能**:包入口,导出主要接口供外部使用。
```python
from .config import *
from .data_generator import extract_cards_from_state, generate_sample
from .dataset import PokerCardDataset
from .model import CardModel
```
---
### 2.7 card_model/config.py
**功能**:项目配置常量定义。
**关键常量**
```python
# 牌面编码
NUM_CARDS = 52 # 标准52张牌
PAD_TOKEN = 52 # 填充token
VOCAB_SIZE = 53 # 52 + 1 padding
# 游戏配置对应Botzone规则
HUNL_GAME_STRING = "universal_poker(...)" # 2人无限德州
# 数据生成
NUM_ROLLOUTS = 1000 # 每样本MC rollout次数
NUM_BINS = 50 # 胜率直方图bin数
# 模型架构
EMBEDDING_DIM = 32
MLP_HIDDEN = [128, 128, 64]
# 训练超参数
NUM_TRAIN_SAMPLES = 2000000 # 200万训练样本
BATCH_SIZE = 16384
LEARNING_RATE = 5e-4
NUM_EPOCHS = 64
```
---
### 2.8 card_model/model.py
**功能**CardModel神经网络预测牌面胜率和胜率分布。
**网络架构**
```
输入:
- x_hole: [batch, 2] 玩家手牌 IDs (0-51)
- x_board: [batch, 5] 公共牌 IDs (不足5则用PAD填充)
Embedding层: 53 tokens → 32维
Hole Embedding: sum → [batch, 32]
Board Embedding: sum → [batch, 32]
拼接: [batch, 64]
MLP骨干: 64 → 128 → 128 → 64
双输出头:
- equity_head: Sigmoid → scalar equity (0~1)
- histogram_head: Softmax → 50维分布
```
**输出**
- `pred_equity`: 预测胜率(标量)
- `pred_histogram`: 胜率直方图50维概率分布
**<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>**
- 输入:`env_adapter` 提取的手牌+公共牌IDs
- 输出50维胜率直方图 → 作为 `cfr_net` 的card_features输入
**训练目标**
- 最小化 EMD loss直方图之间的Wasserstein距离
- 最小化 MSE loss胜率预测误差
---
### 2.9 card_model/data_generator.py
**功能**使用Monte Carlo方法生成牌面训练数据。
**核心函数**
1. **`extract_cards_from_state(state)`**
- 从OpenSpiel状态提取玩家手牌和公共牌
- 返回 `(hole_cards, board_cards)` 列表
2. **`_sample_random_state(game)`**
- 随机采样一个游戏状态Preflop/Flop/Turn
- 使用随机发牌 + check/call推进
3. **`_monte_carlo_rollout(game, hole_cards, board_cards, used_cards)`**
- 随机完成剩余公共牌和对手手牌
- 返回 1.0(赢)/0.5(平)/0.0(输)
4. **`generate_sample(game=None)`**
- 生成单个训练样本:
- `x_hole`: 手牌IDs [2]
- `x_board`: 公共牌IDs [5]padding填充
- `y_equity`: 平均胜率float
- `y_histogram`: 50维直方图归一化
**数据流**:被 `dataset.py` 调用,生成训练/验证数据。
---
### 2.10 card_model/dataset.py
**功能**PyTorch Dataset封装支持多进程并行生成数据和磁盘缓存。
**PokerCardDataset类**
```python
# 数据结构
x_hole: [num_samples, 2] int64 # 手牌IDs
x_board: [num_samples, 5] int64 # 公共牌IDspadding
y_equity: [num_samples, 1] float32 # 胜率
y_histogram:[num_samples, 50] float32 # 直方图分布
```
**多进程生成**
- 使用 `multiprocessing.Pool` 并行调用 `generate_sample()`
- `num_workers` 参数控制并行度
**磁盘缓存**
- 首次生成后保存为 `.npz` 文件
- 下次运行直接加载,避免重复生成
---
### 2.11 card_model/train_card_model.py
**功能**Card Model训练脚本。
**训练流程**
```python
for epoch in 1~NUM_EPOCHS:
# 训练一个epoch
train_one_epoch(model, train_loader, optimizer)
# 验证
validate(model, val_loader)
# 保存最佳模型
if val_loss < best_val_loss:
save best_card_model.pt
```
**Loss函数**
```python
emd = emd_loss_1d(pred_histogram, target_histogram)
mse = MSE(pred_equity, target_equity)
loss = emd + lambda * mse
```
**EMD LossEarth Mover's Distance**
- 对于1D有序分布EMD = CDF之间L1距离的平均值
- `cdf_pred = cumsum(pred_histogram)`
- `emd = mean(|cdf_pred - cdf_target|)`
**输出**
- `best_card_model.pt` - 验证集最佳模型
- `final_card_model.pt` - 最终模型
---
## 三、文件依赖关系图
```
┌─────────────────────────────────────────────────────────────────────────────┐
│ mccfr_trainer.py │
│ (主训练流程,协调者) │
└─────────────────────────────────────┬─────────────────────────────────────┘
┌───────────────────────────┼───────────────────────────┐
│ │ │
▼ ▼ ▼
┌─────────────────────┐ ┌─────────────────────┐ ┌────────────────────<E29480><E29480><EFBFBD>
│ env_adapter.py │ │ card_model/ │ │ cfr_buffer.py │
│ │ │ (牌面特征提取) │ │ (经验回放池) │
│ - 状态提取 │ │ │ │ │
│ - 动作转换 │◄──┤ model.py │───►│ 添加/采样数据 │
│ - BetTranslator │ │ data_generator.py│ │ │
└─────────────────────┘ │ dataset.py │ └─────────────────────┘
│ │
│ train_card │
│ _model.py │
└─────────────┘
┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐
│ cfr_net.py │ │ card_model/ │ │ card_model/ │
│ (策略网络) │ │ config.py │ │ dataset.py │
│ │ │ │ │ │
│ - get_strategy │ │ 配置常量 │ │ 数据集封装 │
│ - forward │ │ │ │ │
└─────────────────────┘ └─────────────────────┘ └─────────────────────┘
```
---
## 四、数据流向图
### 4.1 Card Model训练数据流
```
OpenSpiel引擎
data_generator.generate_sample()
├── _sample_random_state() → 随机游戏状态
├── _monte_carlo_rollout() × 1000 → 胜率列表
└── 返回: (x_hole, x_board, y_equity, y_histogram)
dataset.PokerCardDataset
└── DataLoader → mini-batch
model.CardModel (forward)
├── embedding + MLP
├── equity_head → pred_equity
└── histogram_head → pred_histogram
Loss = EMD + λ*MSE → 反向传播 → 更新权重
```
### 4.2 CFR训练数据流
```
OpenSpiel引擎 (new_initial_state())
env_adapter.extract_env_state()
├── pot, stacks, street, position
└── legal_mask
card_model.model (eval模式)
├── extract_cards_from_state() → hole_cards, board_cards
└── forward() → 50维 histogram
拼接: card_features(50) + env_features(5) = 55维 info_state
cfr_net.get_strategy(info_state, legal_mask)
├── Regret Matching → current_strategy
└── Softmax → avg_strategy
如果当前玩家=遍历者:
│ 遍历所有合法动作计算regret
│ buffer.add(info_state, legal_mask, regrets, strategy)
如果当前玩家≠对方:
按策略采样1个动作外部采样
state.apply_action(engine_action)
递归 traverse() → 直到 terminal state
returns = state.returns() → 收益
```
### 4.3 完整Pipeline
```
┌──────────────────────────────────────────────────────────────────────────────────────┐
│ 训练 pipeline 流程 │
├────────────────────────────────────────────────<EFBFBD><EFBFBD><EFBFBD><EFBFBD>───┤
│ │
│ 1. 初始化 │
│ ├── 加载 OpenSpiel 游戏 │
│ ├── 加载/初始化 CardModel (eval模式) │
│ ├── 加载/初始化 CFRNetwork (train模式) │
│ ├── 初始化 CFRBuffer │
│ └── 初始化 Optimizer │
│ │
│ 2. 迭代循环 (50000次) │
│ │ │
│ ├── 阶段A: 数据生成 (200局/iter) │
│ │ ├── traverse() 递归遍历博弈树 │
│ │ │ ├── 提取 info_state (card+env特征) │
│ │ │ ├── CardModel 预测胜率直方图 │
│ │ │ ├── CFRNetwork 计算策略 │
│ │ │ ├── 遍历/采样动作 │
│ │ │ └── 计算 regret存入 Buffer │
│ │ └── 累积经验数据 │
│ │ │
│ └── 阶段B: 网络训练 (50步/iter) │
│ └── train_step() │
│ ├── Buffer.sample(batch_size=32768) │
│ ├── 计算 Regret Loss + Policy Loss │
│ └── 反向传播更新 CFRNetwork │
│ │
│ 3. 保存模型 │
│ └── cfr_net_checkpoint.pt │
└──────────────────────────────────────────────────────┘
```
---
## 五、模块交互接口
### 5.1 env_adapter → cfr_net
```python
# env_adapter.py 输出
env_info = extract_env_state(state)
# {
# 'pot': 500,
# 'p0_stack': 19500,
# 'p1_stack': 19500,
# 'street': 1,
# 'position': 0,
# 'legal_mask': [1, 1, 1, 0, 0, 0]
# }
# 构建 env_features
env_features = [
pot / 20000,
p0_stack / 20000,
p1_stack / 20000,
street / 3.0,
position
] # 5维
```
### 5.2 card_model → cfr_net
```python
# card_model 输出一 50维胜率直方图
_, pred_histogram = card_model(x_hole, x_board) # [1, 50]
# cfr_net 输入 = 拼接
info_state = [card_features(50) + env_features(5)] # 55维
```
### 5.3 cfr_net → env_adapter
```python
# cfr_net 输出
current_strategy = net.get_strategy(card_features, env_features, legal_mask)
# [1, 6] 概率分布
# 采样动作
action_idx = random.choices(legal_indices, weights=current_strategy)[0]
# env_adapter 转换
engine_action = translator.cfr_to_engine_action(state, action_idx)
# 执行
state.apply_action(engine_action)
```
### 5.4 cfr_net ↔ cfr_buffer
```python
# traverse() 存入
buffer.add(info_state, legal_mask, regrets, strategy)
# train_step() 取出
info_states, legal_masks, target_regrets, target_strategies = buffer.sample(32768)
```
---
## 六、游戏规则与配置对齐
### 6.1 Botzone规则
```
- 玩家数: 2人 (Heads-up)
- 初始筹码: 20000
- 盲注: SB=50, BB=100
- 轮次: Preflop → Flop(3张) → Turn(1张) → River(1张)
- 下注模式: No-Limit
```
### 6.2 OpenSpiel配置
```python
HUNL_FULLGAME_STRING = (
"universal_poker("
"betting=nolimit,"
"numPlayers=2,"
"numRanks=13,"
"numSuits=4,"
"numHoleCards=2,"
"numRounds=4,"
"numBoardCards=0 3 1 1," # Preflop/Flop/Turn/River
"stack=20000 20000,"
"blind=50 100,"
"bettingAbstraction=fullgame)"
)
```
---
## 七、关键超参数汇总
### 7.1 Card Model
| 参数 | 值 | 说明 |
|------|-----|------|
| EMBEDDING_DIM | 32 | 卡牌embedding维度 |
| MLP_HIDDEN | [128,128,64] | MLP隐藏层 |
| NUM_BINS | 50 | 胜率直方图bin数 |
| NUM_ROLLOUTS | 1000 | 每样本MC rollout次数 |
| BATCH_SIZE | 16384 | 训练batch size |
| NUM_TRAIN_SAMPLES | 2,000,000 | 训练样本数 |
| NUM_EPOCHS | 64 | 训练轮数 |
| LEARNING_RATE | 5e-4 | 学习率 |
### 7.2 CFR Trainer
| 参数 | 值 | 说明 |
|------|-----|------|
| CARD_DIM | 50 | Card Model输出维度 |
| ENV_DIM | 5 | 局势特征维度 |
| NUM_ACTIONS | 6 | CFR动作数 |
| MLP_HIDDEN | [256,256,128] | CFR网络隐藏层 |
| NUM_ITERATIONS | 50,000 | 训练迭代次数 |
| GAMES_PER_ITER | 200 | 每迭代对局数 |
| BUFFER_MAX_SIZE | 10,000,000 | 经验池容量 |
| TRAIN_BATCH_SIZE | 32,768 | 训练batch size |
| TRAIN_STEPS_PER_ITER | 50 | 每迭代训练步数 |
| LEARNING_RATE | 5e-4 | 学习率 |
---
## 八、训练产物
### 8.1 Card Model
```
card_model/data/
├── train_data.npz # 训练数据缓存
├── val_data.npz # 验证数据缓存
├── best_card_model.pt # 最佳验证模型
└── final_card_model.pt # 最终模型
```
### 8.2 CFR Trainer
```
cfr_net_checkpoint.pt # CFR网络检查点
```
---
## 九、总结
这是一个典型的模块化AI系统设计
1. **数据驱动**Card Model学习牌面→胜率的映射提供特征表示
2. **算法核心**MCCFR迭代计算 regret学习博弈策略
3. **工程桥接**env_adapter连接引擎和算法处理动作空间离散化
4. **存储中转**cfr_buffer作为数据中转站解耦生成和训练
各模块职责清晰,接口明确,通过张量传递形成完整的数据闭环。