653 lines
20 KiB
Markdown
653 lines
20 KiB
Markdown
# 项目代码结构文档
|
||
|
||
## 概述
|
||
|
||
这是一个基于 Deep CFR (Deep CounterfactualRegret Minimization) 的德州扑克AI系统,使用PyTorch和OpenSpiel实现。系统分为两大模块:
|
||
|
||
1. **Card Model** - 牌面胜率预测模型
|
||
2. **CFR Trainer** - 策略训练器(基于MCCFR算法)
|
||
|
||
---
|
||
|
||
## 一、文件清单与功能说明
|
||
|
||
### 1.1 根目录核心文件
|
||
|
||
| 文件 | 功能 |
|
||
|------|------|
|
||
| `cfr_buffer.py` | CFR经验回放池,存储自对弈产生的训练数据 |
|
||
| `cfr_net.py` | CFR策略网络(双头MLP:Regret Head + Policy Head) |
|
||
| `env_adapter.py` | OpenSpiel环境适配器,动作转换与状态提取 |
|
||
| `mccfr_trainer.py` | MCCFR训练主流程,自对弈+网络训练 |
|
||
| `train_card_model.py` | Card Model训练入口(包装脚本) |
|
||
|
||
### 1.2 card_model 包
|
||
|
||
| 文件 | 功能 |
|
||
|------|------|
|
||
| `__init__.py` | 包入口,导出主要接口 |
|
||
| `config.py` | 配置常量(牌数、维度、路径等) |
|
||
| `model.py` | CardModel神经网络(双头输出:scalar + histogram) |
|
||
| `data_generator.py` | Monte Carlo数据生成器 |
|
||
| `dataset.py` | PyTorch Dataset封装 |
|
||
| `train_card_model.py` | Card Model训练脚本 |
|
||
|
||
---
|
||
|
||
## 二、各文件详细说明
|
||
|
||
### 2.1 cfr_buffer.py
|
||
|
||
**功能**:CFR经验回放池,存储自我博弈过程中收集的训练数据。
|
||
|
||
**核心数据结构**:
|
||
- 使用 `deque(maxlen=N)` 实现固定容量的FIFO滑动窗口
|
||
- 存储4类数据:info_state、legal_mask、regrets、strategy
|
||
|
||
**存储内容**:
|
||
```
|
||
- info_state: 信息集特征向量(55维 = 牌面50维 + 局势5维)
|
||
- legal_mask: 6个动作的合法性掩码 [0/1]
|
||
- regrets: 累计遗憾值(用于训练Regret网络)
|
||
- strategy: 平均策略概率分布(用于训练Policy网络)
|
||
```
|
||
|
||
**主要方法**:
|
||
- `add(info_state, legal_mask, regrets, strategy)` - 添加一条经验
|
||
- `sample(batch_size)` - 随机采样一个batch
|
||
- `clear()` - 清空缓冲区
|
||
|
||
**数据流**:被 `mccfr_trainer.py` 的 `traverse()` 函数填充,被 `train_step()` 函数消费。
|
||
|
||
---
|
||
|
||
### 2.2 cfr_net.py
|
||
|
||
**功能**:Deep CFR的策略网络,接受信息集特征,输出6个动作的遗憾值和平均策略。
|
||
|
||
**网络结构**:
|
||
```
|
||
输入: [card_features(50) + env_features(5)] = 55维
|
||
↓
|
||
MLP骨干: concat_dim(55) → 256 → 256 → 128
|
||
↓
|
||
双输出头:
|
||
- regret_head: 6维,无激活(regret可为负数)
|
||
- policy_head: 6维,Softmax输出(概率分布)
|
||
```
|
||
|
||
**核心方法**:
|
||
- `forward(card_features, env_features)` → `(regrets, policy_logits)`
|
||
- `get_strategy(card_features, env_features, legal_mask)` → `(current_strategy, avg_strategy)`
|
||
|
||
**Regret Matching算法**:
|
||
1. 将负regret截断为0:`positive_regret = relu(regret)`
|
||
2. 只保留合法动作:`masked_regret = positive_regret * legal_mask`
|
||
3. 归一化得到即时策略
|
||
4. 若所有regret为0,则使用均匀分布
|
||
|
||
**数据流**:
|
||
- 输入:来自 `card_model` 的50维胜率直方图 + `env_adapter` 提取的5维局势特征
|
||
- 输出:6维动作策略,驱动 `env_adapter` 的动作选择
|
||
|
||
---
|
||
|
||
### 2.3 env_adapter.py
|
||
|
||
**功能**:OpenSpiel引擎与CFR系统之间的桥梁,负责:
|
||
|
||
1. **动作离散化**:将引擎的连续下注空间映射为6个CFR标准动作
|
||
2. **状态提取**:从引擎状态中提取局势特征
|
||
3. **下注映射**:将CFR动作转换为引擎可执行的动作ID
|
||
|
||
**CFR动作定义**(对应索引0-5):
|
||
```
|
||
0: FOLD - 弃牌
|
||
1: CALL - 跟注/过牌
|
||
2: MIN_RAISE - 最小加注
|
||
3: HALF_POT - 半池加注(目标 = max_contribution + 0.5 × pot)
|
||
4: FULL_POT - 满池加注(目标 = max_contribution + 1.0 × pot)
|
||
5: ALL_IN - 全押
|
||
```
|
||
|
||
**局势特征提取**(5维):
|
||
```
|
||
[pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position]
|
||
```
|
||
- pot: 当前底池(两玩家贡献额之和)
|
||
- p0_stack/p1_stack: 双方剩余筹码
|
||
- street: 当前轮次(0=Preflop, 1=Flop, 2=Turn, 3=River)
|
||
- position: 当前行动玩家(0或1)
|
||
|
||
**关键类**:
|
||
- `BetTranslator`: 负责CFR动作 ↔ 引擎动作的双向转换
|
||
- `get_cfr_legal_mask(state)` → 6维掩码
|
||
- `cfr_to_engine_action(state, cfr_action_idx)` → 引擎动作ID
|
||
- `extract_env_state(state)` → 局势字典
|
||
|
||
**Fallback机制**:当筹码不足无法执行目标加注时,自动降级到最大合法动作,保证不崩溃。
|
||
|
||
---
|
||
|
||
### 2.4 mccfr_trainer.py
|
||
|
||
**功能**:MCCFR训练主流程,协调各模块完成自对弈数据生成和网络训练。
|
||
|
||
**核心流程**:
|
||
```
|
||
主循环 for iteration in 1~N:
|
||
├── 阶段A: 数据生成
|
||
│ └── for game in 1~M:
|
||
│ traverse() 遍历博弈树
|
||
│ → 填充 CFRBuffer
|
||
│
|
||
└── 阶段B: 网络训练
|
||
└── for step in 1~K:
|
||
train_step() 从Buffer采样,训练网络
|
||
```
|
||
|
||
**外部采样MCCFR**:
|
||
- 对方回合:根据当前策略采样1个动作(大减少计算量)
|
||
- 遍历者回合:遍历所有合法动作,精确计算每个动作的regret
|
||
|
||
**traverse()函数**:
|
||
- 递归遍历博弈树
|
||
- 在遍历者节点计算反事实遗憾:`regret(I,a) = v(I,a) - v(I)`
|
||
- 将 `[info_state, legal_mask, regrets, strategy]` 存入Buffer
|
||
|
||
**train_step()函数**:
|
||
- 从Buffer采样mini-batch
|
||
- 计算两个Loss:
|
||
- Regret Loss: MSE(预测regret, 目标regret)
|
||
- Policy Loss: MSE(预测策略, 目标策略)
|
||
- 总Loss = Regret Loss + Policy Loss
|
||
|
||
**依赖模块**:
|
||
- `env_adapter` - 状态提取和动作转换
|
||
- `card_model` - 牌面特征提取
|
||
- `cfr_net` - 策略网络
|
||
- `cfr_buffer` - 经验池
|
||
|
||
---
|
||
|
||
### 2.5 train_card_model.py(根目录)
|
||
|
||
**功能**:Card Model训练的入口脚本,包装 `card_model/train_card_model.py`。
|
||
|
||
**作用**:确保 `poker/` 目录在sys.path中,使导入路径正确。
|
||
|
||
```python
|
||
from card_model.train_card_model import main
|
||
main()
|
||
```
|
||
|
||
---
|
||
|
||
### 2.6 card_model/__init__.py
|
||
|
||
**功能**:包入口,导出主要接口供外部使用。
|
||
|
||
```python
|
||
from .config import *
|
||
from .data_generator import extract_cards_from_state, generate_sample
|
||
from .dataset import PokerCardDataset
|
||
from .model import CardModel
|
||
```
|
||
|
||
---
|
||
|
||
### 2.7 card_model/config.py
|
||
|
||
**功能**:项目配置常量定义。
|
||
|
||
**关键常量**:
|
||
```python
|
||
# 牌面编码
|
||
NUM_CARDS = 52 # 标准52张牌
|
||
PAD_TOKEN = 52 # 填充token
|
||
VOCAB_SIZE = 53 # 52 + 1 padding
|
||
|
||
# 游戏配置(对应Botzone规则)
|
||
HUNL_GAME_STRING = "universal_poker(...)" # 2人无限德州
|
||
|
||
# 数据生成
|
||
NUM_ROLLOUTS = 1000 # 每样本MC rollout次数
|
||
NUM_BINS = 50 # 胜率直方图bin数
|
||
|
||
# 模型架构
|
||
EMBEDDING_DIM = 32
|
||
MLP_HIDDEN = [128, 128, 64]
|
||
|
||
# 训练超参数
|
||
NUM_TRAIN_SAMPLES = 2000000 # 200万训练样本
|
||
BATCH_SIZE = 16384
|
||
LEARNING_RATE = 5e-4
|
||
NUM_EPOCHS = 64
|
||
```
|
||
|
||
---
|
||
|
||
### 2.8 card_model/model.py
|
||
|
||
**功能**:CardModel神经网络,预测牌面胜率和胜率分布。
|
||
|
||
**网络架构**:
|
||
```
|
||
输入:
|
||
- x_hole: [batch, 2] 玩家手牌 IDs (0-51)
|
||
- x_board: [batch, 5] 公共牌 IDs (不足5则用PAD填充)
|
||
|
||
Embedding层: 53 tokens → 32维
|
||
↓
|
||
Hole Embedding: sum → [batch, 32]
|
||
Board Embedding: sum → [batch, 32]
|
||
↓
|
||
拼接: [batch, 64]
|
||
↓
|
||
MLP骨干: 64 → 128 → 128 → 64
|
||
↓
|
||
双输出头:
|
||
- equity_head: Sigmoid → scalar equity (0~1)
|
||
- histogram_head: Softmax → 50维分布
|
||
```
|
||
|
||
**输出**:
|
||
- `pred_equity`: 预测胜率(标量)
|
||
- `pred_histogram`: 胜率直方图(50维概率分布)
|
||
|
||
**<EFBFBD><EFBFBD><EFBFBD>据<EFBFBD><EFBFBD>**:
|
||
- 输入:`env_adapter` 提取的手牌+公共牌IDs
|
||
- 输出:50维胜率直方图 → 作为 `cfr_net` 的card_features输入
|
||
|
||
**训练目标**:
|
||
- 最小化 EMD loss(直方图之间的Wasserstein距离)
|
||
- 最小化 MSE loss(胜率预测误差)
|
||
|
||
---
|
||
|
||
### 2.9 card_model/data_generator.py
|
||
|
||
**功能**:使用Monte Carlo方法生成牌面训练数据。
|
||
|
||
**核心函数**:
|
||
|
||
1. **`extract_cards_from_state(state)`**
|
||
- 从OpenSpiel状态提取玩家手牌和公共牌
|
||
- 返回 `(hole_cards, board_cards)` 列表
|
||
|
||
2. **`_sample_random_state(game)`**
|
||
- 随机采样一个游戏状态(Preflop/Flop/Turn)
|
||
- 使用随机发牌 + check/call推进
|
||
|
||
3. **`_monte_carlo_rollout(game, hole_cards, board_cards, used_cards)`**
|
||
- 随机完成剩余公共牌和对手手牌
|
||
- 返回 1.0(赢)/0.5(平)/0.0(输)
|
||
|
||
4. **`generate_sample(game=None)`**
|
||
- 生成单个训练样本:
|
||
- `x_hole`: 手牌IDs [2]
|
||
- `x_board`: 公共牌IDs [5](padding填充)
|
||
- `y_equity`: 平均胜率(float)
|
||
- `y_histogram`: 50维直方图(归一化)
|
||
|
||
**数据流**:被 `dataset.py` 调用,生成训练/验证数据。
|
||
|
||
---
|
||
|
||
### 2.10 card_model/dataset.py
|
||
|
||
**功能**:PyTorch Dataset封装,支持多进程并行生成数据和磁盘缓存。
|
||
|
||
**PokerCardDataset类**:
|
||
```python
|
||
# 数据结构
|
||
x_hole: [num_samples, 2] int64 # 手牌IDs
|
||
x_board: [num_samples, 5] int64 # 公共牌IDs(padding)
|
||
y_equity: [num_samples, 1] float32 # 胜率
|
||
y_histogram:[num_samples, 50] float32 # 直方图分布
|
||
```
|
||
|
||
**多进程生成**:
|
||
- 使用 `multiprocessing.Pool` 并行调用 `generate_sample()`
|
||
- `num_workers` 参数控制并行度
|
||
|
||
**磁盘缓存**:
|
||
- 首次生成后保存为 `.npz` 文件
|
||
- 下次运行直接加载,避免重复生成
|
||
|
||
---
|
||
|
||
### 2.11 card_model/train_card_model.py
|
||
|
||
**功能**:Card Model训练脚本。
|
||
|
||
**训练流程**:
|
||
```python
|
||
for epoch in 1~NUM_EPOCHS:
|
||
# 训练一个epoch
|
||
train_one_epoch(model, train_loader, optimizer)
|
||
# 验证
|
||
validate(model, val_loader)
|
||
# 保存最佳模型
|
||
if val_loss < best_val_loss:
|
||
save best_card_model.pt
|
||
```
|
||
|
||
**Loss函数**:
|
||
```python
|
||
emd = emd_loss_1d(pred_histogram, target_histogram)
|
||
mse = MSE(pred_equity, target_equity)
|
||
loss = emd + lambda * mse
|
||
```
|
||
|
||
**EMD Loss(Earth Mover's Distance)**:
|
||
- 对于1D有序分布,EMD = CDF之间L1距离的平均值
|
||
- `cdf_pred = cumsum(pred_histogram)`
|
||
- `emd = mean(|cdf_pred - cdf_target|)`
|
||
|
||
**输出**:
|
||
- `best_card_model.pt` - 验证集最佳模型
|
||
- `final_card_model.pt` - 最终模型
|
||
|
||
---
|
||
|
||
## 三、文件依赖关系图
|
||
|
||
```
|
||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||
│ mccfr_trainer.py │
|
||
│ (主训练流程,协调者) │
|
||
└─────────────────────────────────────┬─────────────────────────────────────┘
|
||
│
|
||
┌───────────────────────────┼───────────────────────────┐
|
||
│ │ │
|
||
▼ ▼ ▼
|
||
┌─────────────────────┐ ┌─────────────────────┐ ┌────────────────────<E29480><E29480><EFBFBD>┐
|
||
│ env_adapter.py │ │ card_model/ │ │ cfr_buffer.py │
|
||
│ │ │ (牌面特征提取) │ │ (经验回放池) │
|
||
│ - 状态提取 │ │ │ │ │
|
||
│ - 动作转换 │◄──┤ model.py │───►│ 添加/采样数据 │
|
||
│ - BetTranslator │ │ data_generator.py│ │ │
|
||
└─────────────────────┘ │ dataset.py │ └─────────────────────┘
|
||
│ │
|
||
│ train_card │
|
||
│ _model.py │
|
||
└─────────────┘
|
||
│
|
||
▼
|
||
┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐
|
||
│ cfr_net.py │ │ card_model/ │ │ card_model/ │
|
||
│ (策略网络) │ │ config.py │ │ dataset.py │
|
||
│ │ │ │ │ │
|
||
│ - get_strategy │ │ 配置常量 │ │ 数据集封装 │
|
||
│ - forward │ │ │ │ │
|
||
└─────────────────────┘ └─────────────────────┘ └─────────────────────┘
|
||
```
|
||
|
||
---
|
||
|
||
## 四、数据流向图
|
||
|
||
### 4.1 Card Model训练数据流
|
||
|
||
```
|
||
OpenSpiel引擎
|
||
│
|
||
▼
|
||
data_generator.generate_sample()
|
||
├── _sample_random_state() → 随机游戏状态
|
||
├── _monte_carlo_rollout() × 1000 → 胜率列表
|
||
└── 返回: (x_hole, x_board, y_equity, y_histogram)
|
||
│
|
||
▼
|
||
dataset.PokerCardDataset
|
||
└── DataLoader → mini-batch
|
||
│
|
||
▼
|
||
model.CardModel (forward)
|
||
├── embedding + MLP
|
||
├── equity_head → pred_equity
|
||
└── histogram_head → pred_histogram
|
||
│
|
||
▼
|
||
Loss = EMD + λ*MSE → 反向传播 → 更新权重
|
||
```
|
||
|
||
### 4.2 CFR训练数据流
|
||
|
||
```
|
||
OpenSpiel引擎 (new_initial_state())
|
||
│
|
||
▼
|
||
env_adapter.extract_env_state()
|
||
├── pot, stacks, street, position
|
||
└── legal_mask
|
||
│
|
||
▼
|
||
card_model.model (eval模式)
|
||
├── extract_cards_from_state() → hole_cards, board_cards
|
||
└── forward() → 50维 histogram
|
||
│
|
||
▼
|
||
拼接: card_features(50) + env_features(5) = 55维 info_state
|
||
│
|
||
▼
|
||
cfr_net.get_strategy(info_state, legal_mask)
|
||
├── Regret Matching → current_strategy
|
||
└── Softmax → avg_strategy
|
||
│
|
||
▼
|
||
如果当前玩家=遍历者:
|
||
│ 遍历所有合法动作,计算regret
|
||
│ buffer.add(info_state, legal_mask, regrets, strategy)
|
||
│
|
||
如果当前玩家≠对方:
|
||
按策略采样1个动作(外部采样)
|
||
│
|
||
▼
|
||
state.apply_action(engine_action)
|
||
│
|
||
▼
|
||
递归 traverse() → 直到 terminal state
|
||
│
|
||
▼
|
||
returns = state.returns() → 收益
|
||
```
|
||
|
||
### 4.3 完整Pipeline
|
||
|
||
```
|
||
┌──────────────────────────────────────────────────────────────────────────────────────┐
|
||
│ 训练 pipeline 流程 │
|
||
├────────────────────────────────────────────────<EFBFBD><EFBFBD>─<EFBFBD><EFBFBD>───┤
|
||
│ │
|
||
│ 1. 初始化 │
|
||
│ ├── 加载 OpenSpiel 游戏 │
|
||
│ ├── 加载/初始化 CardModel (eval模式) │
|
||
│ ├── 加载/初始化 CFRNetwork (train模式) │
|
||
│ ├── 初始化 CFRBuffer │
|
||
│ └── 初始化 Optimizer │
|
||
│ │
|
||
│ 2. 迭代循环 (50000次) │
|
||
│ │ │
|
||
│ ├── 阶段A: 数据生成 (200局/iter) │
|
||
│ │ ├── traverse() 递归遍历博弈树 │
|
||
│ │ │ ├── 提取 info_state (card+env特征) │
|
||
│ │ │ ├── CardModel 预测胜率直方图 │
|
||
│ │ │ ├── CFRNetwork 计算策略 │
|
||
│ │ │ ├── 遍历/采样动作 │
|
||
│ │ │ └── 计算 regret,存入 Buffer │
|
||
│ │ └── 累积经验数据 │
|
||
│ │ │
|
||
│ └── 阶段B: 网络训练 (50步/iter) │
|
||
│ └── train_step() │
|
||
│ ├── Buffer.sample(batch_size=32768) │
|
||
│ ├── 计算 Regret Loss + Policy Loss │
|
||
│ └── 反向传播更新 CFRNetwork │
|
||
│ │
|
||
│ 3. 保存模型 │
|
||
│ └── cfr_net_checkpoint.pt │
|
||
└──────────────────────────────────────────────────────┘
|
||
```
|
||
|
||
---
|
||
|
||
## 五、模块交互接口
|
||
|
||
### 5.1 env_adapter → cfr_net
|
||
|
||
```python
|
||
# env_adapter.py 输出
|
||
env_info = extract_env_state(state)
|
||
# {
|
||
# 'pot': 500,
|
||
# 'p0_stack': 19500,
|
||
# 'p1_stack': 19500,
|
||
# 'street': 1,
|
||
# 'position': 0,
|
||
# 'legal_mask': [1, 1, 1, 0, 0, 0]
|
||
# }
|
||
|
||
# 构建 env_features
|
||
env_features = [
|
||
pot / 20000,
|
||
p0_stack / 20000,
|
||
p1_stack / 20000,
|
||
street / 3.0,
|
||
position
|
||
] # 5维
|
||
```
|
||
|
||
### 5.2 card_model → cfr_net
|
||
|
||
```python
|
||
# card_model 输出一 50维胜率直方图
|
||
_, pred_histogram = card_model(x_hole, x_board) # [1, 50]
|
||
|
||
# cfr_net 输入 = 拼接
|
||
info_state = [card_features(50) + env_features(5)] # 55维
|
||
```
|
||
|
||
### 5.3 cfr_net → env_adapter
|
||
|
||
```python
|
||
# cfr_net 输出
|
||
current_strategy = net.get_strategy(card_features, env_features, legal_mask)
|
||
# [1, 6] 概率分布
|
||
|
||
# 采样动作
|
||
action_idx = random.choices(legal_indices, weights=current_strategy)[0]
|
||
|
||
# env_adapter 转换
|
||
engine_action = translator.cfr_to_engine_action(state, action_idx)
|
||
|
||
# 执行
|
||
state.apply_action(engine_action)
|
||
```
|
||
|
||
### 5.4 cfr_net ↔ cfr_buffer
|
||
|
||
```python
|
||
# traverse() 存入
|
||
buffer.add(info_state, legal_mask, regrets, strategy)
|
||
|
||
# train_step() 取出
|
||
info_states, legal_masks, target_regrets, target_strategies = buffer.sample(32768)
|
||
```
|
||
|
||
---
|
||
|
||
## 六、游戏规则与配置对齐
|
||
|
||
### 6.1 Botzone规则
|
||
|
||
```
|
||
- 玩家数: 2人 (Heads-up)
|
||
- 初始筹码: 20000
|
||
- 盲注: SB=50, BB=100
|
||
- 轮次: Preflop → Flop(3张) → Turn(1张) → River(1张)
|
||
- 下注模式: No-Limit
|
||
```
|
||
|
||
### 6.2 OpenSpiel配置
|
||
|
||
```python
|
||
HUNL_FULLGAME_STRING = (
|
||
"universal_poker("
|
||
"betting=nolimit,"
|
||
"numPlayers=2,"
|
||
"numRanks=13,"
|
||
"numSuits=4,"
|
||
"numHoleCards=2,"
|
||
"numRounds=4,"
|
||
"numBoardCards=0 3 1 1," # Preflop/Flop/Turn/River
|
||
"stack=20000 20000,"
|
||
"blind=50 100,"
|
||
"bettingAbstraction=fullgame)"
|
||
)
|
||
```
|
||
|
||
---
|
||
|
||
## 七、关键超参数汇总
|
||
|
||
### 7.1 Card Model
|
||
|
||
| 参数 | 值 | 说明 |
|
||
|------|-----|------|
|
||
| EMBEDDING_DIM | 32 | 卡牌embedding维度 |
|
||
| MLP_HIDDEN | [128,128,64] | MLP隐藏层 |
|
||
| NUM_BINS | 50 | 胜率直方图bin数 |
|
||
| NUM_ROLLOUTS | 1000 | 每样本MC rollout次数 |
|
||
| BATCH_SIZE | 16384 | 训练batch size |
|
||
| NUM_TRAIN_SAMPLES | 2,000,000 | 训练样本数 |
|
||
| NUM_EPOCHS | 64 | 训练轮数 |
|
||
| LEARNING_RATE | 5e-4 | 学习率 |
|
||
|
||
### 7.2 CFR Trainer
|
||
|
||
| 参数 | 值 | 说明 |
|
||
|------|-----|------|
|
||
| CARD_DIM | 50 | Card Model输出维度 |
|
||
| ENV_DIM | 5 | 局势特征维度 |
|
||
| NUM_ACTIONS | 6 | CFR动作数 |
|
||
| MLP_HIDDEN | [256,256,128] | CFR网络隐藏层 |
|
||
| NUM_ITERATIONS | 50,000 | 训练迭代次数 |
|
||
| GAMES_PER_ITER | 200 | 每迭代对局数 |
|
||
| BUFFER_MAX_SIZE | 10,000,000 | 经验池容量 |
|
||
| TRAIN_BATCH_SIZE | 32,768 | 训练batch size |
|
||
| TRAIN_STEPS_PER_ITER | 50 | 每迭代训练步数 |
|
||
| LEARNING_RATE | 5e-4 | 学习率 |
|
||
|
||
---
|
||
|
||
## 八、训练产物
|
||
|
||
### 8.1 Card Model
|
||
|
||
```
|
||
card_model/data/
|
||
├── train_data.npz # 训练数据缓存
|
||
├── val_data.npz # 验证数据缓存
|
||
├── best_card_model.pt # 最佳验证模型
|
||
└── final_card_model.pt # 最终模型
|
||
```
|
||
|
||
### 8.2 CFR Trainer
|
||
|
||
```
|
||
cfr_net_checkpoint.pt # CFR网络检查点
|
||
```
|
||
|
||
---
|
||
|
||
## 九、总结
|
||
|
||
这是一个典型的模块化AI系统设计:
|
||
|
||
1. **数据驱动**:Card Model学习牌面→胜率的映射,提供特征表示
|
||
2. **算法核心**:MCCFR迭代计算 regret,学习博弈策略
|
||
3. **工程桥接**:env_adapter连接引擎和算法,处理动作空间离散化
|
||
4. **存储中转**:cfr_buffer作为数据中转站,解耦生成和训练
|
||
|
||
各模块职责清晰,接口明确,通过张量传递形成完整的数据闭环。 |