TexasPoker-AI/docs/structure.md
2026-05-07 17:38:05 +08:00
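# Neural Network Architecture & C++ Online Inference Guide
## 1. System Overview
The inference pipeline consists of **two neural networks** chained in series: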
```
Game state (OpenSpiel / Botzone)
├── extract cards ────→ CardModel ──→ pred_histogram [50] ──┐
│                                                           ├─ concat → [55]
└── extract situation ─→ env_features [5] ──────────────────┘
                                     CFRNetwork
                             ┌────────────┴────────────┐
                             │                         │
                      regret_head [5]           policy_head [5]
                             │                         │
                      Regret Matching           Softmax(masked)
                             │                         │
                   current_strategy [5]        avg_strategy [5]
```
**Online inference needs only `avg_strategy`**: a Softmax-normalized probability distribution over actions, from which the final action is sampled.
---
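## 2. CardModel Architecture
### 2.1 Purpose
Encodes a poker hand (2 hole cards + 0 to 5 community cards) as a **50-dimensional equity histogram**, which serves as one of the inputs to CFRNetwork.
### 2.2 Network Structure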
```
Input:
  x_hole:  [batch, 2] int64, IDs of the 2 hole cards (0-51)
  x_board: [batch, 5] int64, IDs of the 5 community cards (0-51, padded with 52 when fewer than 5)
Embedding: nn.Embedding(53, 64, padding_idx=52)
  - 53 tokens: 52 cards + 1 PAD (52)
  - padding_idx=52 keeps the PAD token's embedding fixed at 0
Encoding:
  hole_emb  = Embedding(x_hole).sum(dim=1)   → [batch, 64]
  board_emb = Embedding(x_board).sum(dim=1)  → [batch, 64]
  combined  = cat([hole_emb, board_emb])     → [batch, 128]
Backbone MLP (each layer: Linear → ReLU → LayerNorm):
  128 → 512 → ReLU → LayerNorm(512)
  512 → 512 → ReLU → LayerNorm(512)
  512 → 256 → ReLU → LayerNorm(256)
Equity Head (not needed for online inference):
  256 → 32 → ReLU → 1 → Sigmoid → [batch, 1] scalar equity
Histogram Head (the only head online inference needs):
  256 → 64 → ReLU → 50 → Softmax → [batch, 50] equity histogram
Output:
  pred_equity:    [batch, 1]  scalar equity (0 to 1)
  pred_histogram: [batch, 50] equity histogram (sums to 1) ← this is CFRNetwork's card_features
```
### 2.3 Weight Parameter Details
| Layer (state_dict key) | Shape | Parameters |
|---|---|---|
| `embedding.weight` | [53, 64] | 3,392 |
| `backbone.0.weight` (Linear 128→512) | [512, 128] | 65,536 |
| `backbone.0.bias` | [512] | 512 |
| `backbone.2.weight` (LayerNorm 512) | [512] | 512 |
| `backbone.2.bias` (LayerNorm 512) | [512] | 512 |
| `backbone.3.weight` (Linear 512→512) | [512, 512] | 262,144 |
| `backbone.3.bias` | [512] | 512 |
| `backbone.5.weight` (LayerNorm 512) | [512] | 512 |
| `backbone.5.bias` (LayerNorm 512) | [512] | 512 |
| `backbone.6.weight` (Linear 512→256) | [256, 512] | 131,072 |
| `backbone.6.bias` | [256] | 256 |
| `backbone.8.weight` (LayerNorm 256) | [256] | 256 |
| `backbone.8.bias` (LayerNorm 256) | [256] | 256 |
| `equity_head.0.weight` (Linear 256→32) | [32, 256] | 8,192 |
| `equity_head.0.bias` | [32] | 32 |
| `equity_head.2.weight` (Linear 32→1) | [1, 32] | 32 |
| `equity_head.2.bias` | [1] | 1 |
| `histogram_head.0.weight` (Linear 256→64) | [64, 256] | 16,384 |
| `histogram_head.0.bias` | [64] | 64 |
| `histogram_head.2.weight` (Linear 64→50) | [50, 64] | 3,200 |
| `histogram_head.2.bias` | [50] | 50 |
| **Total** | | **493,939** |
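
As a cross-check on the table, the total can be recomputed from the listed shapes. The sketch below assumes the shapes in the table are accurate; `CARD_MODEL_SHAPES` and `count_parameters` are illustrative names, and the dictionary is a transcription of the table rows rather than something read from a checkpoint.

```python
from math import prod

# Shapes transcribed from the CardModel table (assumed accurate).
CARD_MODEL_SHAPES = {
    "embedding.weight": (53, 64),
    "backbone.0.weight": (512, 128),
    "backbone.0.bias": (512,),
    "backbone.2.weight": (512,),
    "backbone.2.bias": (512,),
    "backbone.3.weight": (512, 512),
    "backbone.3.bias": (512,),
    "backbone.5.weight": (512,),
    "backbone.5.bias": (512,),
    "backbone.6.weight": (256, 512),
    "backbone.6.bias": (256,),
    "backbone.8.weight": (256,),
    "backbone.8.bias": (256,),
    "equity_head.0.weight": (32, 256),
    "equity_head.0.bias": (32,),
    "equity_head.2.weight": (1, 32),
    "equity_head.2.bias": (1,),
    "histogram_head.0.weight": (64, 256),
    "histogram_head.0.bias": (64,),
    "histogram_head.2.weight": (50, 64),
    "histogram_head.2.bias": (50,),
}


def count_parameters(shapes):
    """Sum the element counts of every tensor shape in the mapping."""
    return sum(prod(shape) for shape in shapes.values())


total = count_parameters(CARD_MODEL_SHAPES)
print(f"CardModel parameters: {total:,}")
```

With a real checkpoint, the same helper can be fed `{k: tuple(v.shape) for k, v in state_dict.items()}` to confirm that the file matches the table.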
### 2.4 Card ID Encoding
```
card_id = rank * 4 + suit
rank: 0=2, 1=3, 2=4, 3=5, 4=6, 5=7, 6=8, 7=9, 8=T, 9=J, 10=Q, 11=K, 12=A
suit: 0=c (clubs), 1=d (diamonds), 2=h (hearts), 3=s (spades)
Examples: Ac = 12*4+0 = 48, Ks = 11*4+3 = 47, 2c = 0*4+0 = 0
PAD_TOKEN = 52 (its embedding is always the zero vector)
```
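The encoding rule can be sketched as a pair of helper functions. The names `card_to_id` and `id_to_card` are illustrative and do not come from the project's code; the two-character card strings follow the rank and suit letters listed above.

```python
RANKS = "23456789TJQKA"  # index = rank (0 = deuce, 12 = ace)
SUITS = "cdhs"           # index = suit (clubs, diamonds, hearts, spades)
PAD_TOKEN = 52           # padding ID for missing community cards


def card_to_id(card):
    """Convert a two-character card string such as 'Ks' to its ID."""
    rank_char, suit_char = card[0].upper(), card[1].lower()
    return RANKS.index(rank_char) * 4 + SUITS.index(suit_char)


def id_to_card(card_id):
    """Convert an ID in 0..51 back to its two-character string."""
    if not 0 <= card_id < 52:
        raise ValueError(f"not a card ID: {card_id}")
    return RANKS[card_id // 4] + SUITS[card_id % 4]


print(card_to_id("Ac"), card_to_id("Ks"), card_to_id("2c"))  # 48 47 0
```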
---
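## 3. CFRNetwork Architecture
### 3.1 Purpose
Takes card features plus game-situation features and outputs regret values and policy logits for **5 actions**; Regret Matching and Softmax then turn these into action probability distributions.
### 3.2 Network Structure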
```
Input:
  card_features: [batch, 50], the equity histogram produced by CardModel
  env_features:  [batch, 5], the normalized game-situation features
Concatenation:
  x = cat([card_features, env_features]) → [batch, 55]
Backbone MLP (each layer: Linear → ReLU):
  55 → 256 → ReLU
  256 → 256 → ReLU
  256 → 128 → ReLU
Regret Head (can be skipped for online inference):
  128 → 5 (no activation; regrets may be negative)
Policy Head:
  128 → 5 (outputs logits, which are then passed through Softmax)
Output:
  regrets:       [batch, 5] raw regret values
  policy_logits: [batch, 5] policy logits
```
### 3.3 Weight Parameter Details
| Layer (state_dict key) | Shape | Parameters |
|---|---|---|
| `backbone.0.weight` (Linear 55→256) | [256, 55] | 14,080 |
| `backbone.0.bias` | [256] | 256 |
| `backbone.2.weight` (Linear 256→256) | [256, 256] | 65,536 |
| `backbone.2.bias` | [256] | 256 |
| `backbone.4.weight` (Linear 256→128) | [128, 256] | 32,768 |
| `backbone.4.bias` | [128] | 128 |
| `regret_head.weight` (Linear 128→5) | [5, 128] | 640 |
| `regret_head.bias` | [5] | 5 |
| `policy_head.weight` (Linear 128→5) | [5, 128] | 640 |
| `policy_head.bias` | [5] | 5 |
| **Total** | | **114,314** |
> Note: `backbone` is an `nn.Sequential`; indices 0, 2, 4 are Linear layers and indices 1, 3, 5 are ReLU (no parameters).
### 3.4 Action Space
| Index | Name | Meaning |
|---|---|---|
| 0 | FOLD | Fold |
| 1 | CALL | Call / check |
| 2 | HALF_POT | Raise = 1/2 of the pot after calling |
| 3 | FULL_POT | Raise = 1.0 × the pot after calling |
| 4 | ALL_IN | All-in |
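
One way to read the two sized raises is sketched below. It assumes that `pot` is the pot before the acting player calls and that "pot after calling" means `pot + to_call`; the document does not spell this out, so treat the formula, the `CFRAction` enum, and `raise_amount` as illustrative assumptions rather than the project's BetTranslator logic.

```python
from enum import IntEnum


class CFRAction(IntEnum):
    """The five abstract actions, indexed as in the table."""
    FOLD = 0
    CALL = 1
    HALF_POT = 2
    FULL_POT = 3
    ALL_IN = 4


# Pot fractions for the two sized raises, taken from the table.
POT_FRACTION = {CFRAction.HALF_POT: 0.5, CFRAction.FULL_POT: 1.0}


def raise_amount(action, pot, to_call):
    """Chips raised on top of the call for HALF_POT / FULL_POT (assumed formula)."""
    pot_after_call = pot + to_call
    return POT_FRACTION[action] * pot_after_call


print(raise_amount(CFRAction.HALF_POT, pot=300, to_call=100))  # 200.0
```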
---
## 4. Inference Pipeline in Detail
### 4.1 Building the Input Features
#### 4.1.1 Building env_features [5]
```
env_features = [
    pot / 20000.0,       // pot, normalized
    p0_stack / 20000.0,  // player 0's remaining stack, normalized
    p1_stack / 20000.0,  // player 1's remaining stack, normalized
    street / 3.0,        // betting round, normalized (0=Preflop, 1=Flop, 2=Turn, 3=River)
    position,            // player to act (0.0 or 1.0)
]
```
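A minimal sketch of this feature construction, assuming chip counts are given in the same units as the 20000 normalization constant. `build_env_features` is a hypothetical helper name, and the input validation is an addition for illustration.

```python
STACK_NORM = 20000.0  # normalization constant for pot and stacks
MAX_STREET = 3.0      # River


def build_env_features(pot, p0_stack, p1_stack, street, position):
    """Return the 5 normalized game-situation features in the documented order."""
    if street not in (0, 1, 2, 3):
        raise ValueError(f"street must be 0..3, got {street}")
    if position not in (0, 1):
        raise ValueError(f"position must be 0 or 1, got {position}")
    return [
        pot / STACK_NORM,
        p0_stack / STACK_NORM,
        p1_stack / STACK_NORM,
        street / MAX_STREET,
        float(position),
    ]


features = build_env_features(pot=300, p0_stack=19850, p1_stack=19850,
                              street=1, position=0)
print(features)
```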
#### 4.1.2 Building card_features [50]
```
1. Extract the current player's 2 hole-card IDs → hole_cards[2]
2. Extract the community-card IDs (0 to 5 cards) → board_cards[0..5]
3. If there are fewer than 5 community cards, pad with PAD_TOKEN=52 → x_board[5]
4. Feed them to CardModel:
   - hole_emb  = Embedding(hole_cards).sum (sum of the row vectors) → [64]
   - board_emb = Embedding(x_board).sum (sum of the row vectors)    → [64]
   - combined  = cat([hole_emb, board_emb]) → [128]
   - features  = Backbone(combined)         → [256]
   - histogram = HistogramHead(features)    → [50]
5. card_features = histogram (the 50-dimensional equity histogram)
```
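Steps 3 and 4 (padding and the embedding sum) can be illustrated with plain lists. This is a sketch only: `pad_board` and `embed_sum` are hypothetical helpers, and the toy table used in the example stands in for the real `[53, 64]` embedding matrix.

```python
PAD_TOKEN = 52


def pad_board(board_cards):
    """Pad a list of 0..5 community-card IDs to length 5 with PAD_TOKEN."""
    if len(board_cards) > 5:
        raise ValueError("at most 5 community cards are allowed")
    return list(board_cards) + [PAD_TOKEN] * (5 - len(board_cards))


def embed_sum(table, ids):
    """Sum the embedding rows selected by `ids`.

    `table[PAD_TOKEN]` is expected to be all zeros, so padded slots
    contribute nothing to the sum.
    """
    dim = len(table[0])
    out = [0.0] * dim
    for card_id in ids:
        for k in range(dim):
            out[k] += table[card_id][k]
    return out


# Toy 53 x 2 table: row i is [i, 1] for real cards, zeros for PAD.
toy_table = [[float(i), 1.0] for i in range(52)] + [[0.0, 0.0]]
flop = pad_board([0, 5, 9])
print(flop, embed_sum(toy_table, flop))
```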
#### 4.1.3 Building legal_mask [5]
```
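legal_mask = [fold_ok, call_ok, half_pot_ok, full_pot_ok, allin_ok]
Each element is 0 or 1 and indicates whether that CFR action is legal.
Construction rules:
  FOLD(0): the engine's legal actions include action 0
  CALL(1): the engine's legal actions include action 1
  ALL_IN(4): the engine offers a raise action (action ID > 1)
  HALF_POT(2)/FULL_POT(3): all of the following must hold:
    - a raise action exists
    - the number of raises on the current street < 2 (RAISE_CAP)
    - the computed target contribution >= the minimum raise amount
    - the target contribution, once mapped, does not coincide with the ALL-IN action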
```
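The rules above can be sketched as follows. Several details are assumptions made for illustration: the engine's legal actions are a list of integer IDs, the target amounts are precomputed by the caller, and "does not coincide with ALL-IN" is modelled as the target being strictly below the all-in amount. `build_legal_mask` is not the project's BetTranslator.

```python
RAISE_CAP = 2  # maximum number of raises per street, as stated in the rules


def build_legal_mask(engine_actions, raises_this_street,
                     half_pot_target, full_pot_target,
                     min_raise_target, all_in_target):
    """Return [fold_ok, call_ok, half_pot_ok, full_pot_ok, allin_ok] as 0/1 ints.

    All *_target arguments are total contribution amounts in chips; how they
    are derived from the engine state is left to the caller (assumption).
    """
    can_raise = any(action > 1 for action in engine_actions)

    def sized_raise_ok(target):
        return (can_raise
                and raises_this_street < RAISE_CAP
                and target >= min_raise_target
                and target < all_in_target)  # otherwise it collapses into ALL_IN

    return [
        int(0 in engine_actions),
        int(1 in engine_actions),
        int(sized_raise_ok(half_pot_target)),
        int(sized_raise_ok(full_pot_target)),
        int(can_raise),
    ]


print(build_legal_mask([0, 1, 2, 3], 0, 300, 600, 200, 20000))  # [1, 1, 1, 1, 1]
```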
### 4.2 From policy_logits to Action Probabilities (avg_strategy)
This is the core logic of online inference; it uses only the output of `policy_head`:
```
1. logits = policy_head(backbone_output) → [5]
2. masked_logits = logits
   For every position with legal_mask[i]==0, set masked_logits[i] = -1e9
3. avg_strategy = softmax(masked_logits) → [5]
   Every legal action gets probability > 0; illegal actions get probability ≈ 0
4. Sample from the legal actions of avg_strategy
```
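Steps 2 and 3 can be written as a small reference function, which is handy for checking a C++ port against known values. This is a sketch using only the standard library; `masked_softmax` is an illustrative name.

```python
import math


def masked_softmax(logits, legal_mask, neg=-1e9):
    """Softmax over `logits` after replacing illegal positions with `neg`."""
    masked = [x if m else neg for x, m in zip(logits, legal_mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax([1.0, 2.0, 3.0, 0.5, -1.0], [1, 1, 0, 0, 1])
print([round(p, 4) for p in probs])
```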
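### 4.3 Denoising (optional but recommended)
Practical experience from the training code: zero out any action whose probability is below 3% and renormalize, so that low-level network noise does not trigger spurious all-ins.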
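
A minimal sketch of this thresholding step. The fallback that returns the input unchanged when nothing survives the threshold is an added safeguard, not something the training code is known to do.

```python
def denoise(strategy, threshold=0.03):
    """Zero out probabilities below `threshold`, then renormalize."""
    kept = [p if p >= threshold else 0.0 for p in strategy]
    total = sum(kept)
    if total <= 0.0:
        return list(strategy)  # nothing survived; leave the input untouched
    return [p / total for p in kept]


cleaned = denoise([0.01, 0.59, 0.38, 0.02, 0.0])
print([round(p, 4) for p in cleaned])
```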
---
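## 5. C++ Implementation Guide
### 5.1 Recommended Approach: Hand-Written Forward Pass + Weight Loading
Because the networks are simple (pure MLPs, no convolution or attention), **LibTorch is not required**. Eigen or hand-written matrix multiplication is enough, which keeps the deployment small and inference fast.
### 5.2 Weight File Format
A PyTorch `.pt` file is essentially a dict serialized with Python pickle. Reading it directly from C++ is cumbersome, so a two-step conversion is recommended:
**Step 1: Export to binary from Python**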
```python
import os
import struct

import torch


def export_weights_bin(state_dict, output_path):
    """Export a state_dict to a binary file that C++ can read directly.

    Layout: num_tensors, then per tensor: name_len, name, ndim, dims..., data.
    Integers are little-endian uint32; data is little-endian float32.
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as f:
        # Number of tensors
        f.write(struct.pack("<I", len(state_dict)))
        for name, tensor in state_dict.items():
            # Name length + name
            name_bytes = name.encode("utf-8")
            f.write(struct.pack("<I", len(name_bytes)))
            f.write(name_bytes)
            # Number of dimensions
            shape = tensor.shape
            f.write(struct.pack("<I", len(shape)))
            # Size of each dimension
            for dim in shape:
                f.write(struct.pack("<I", dim))
            # Data (float32, contiguous, row-major)
            data = tensor.detach().cpu().contiguous().float().numpy().astype("<f4").tobytes()
            f.write(data)


# Export the CardModel weights
card_sd = torch.load("card_model/data/best_card_model.pt", map_location="cpu", weights_only=False)
if "model_state_dict" in card_sd:
    card_sd = card_sd["model_state_dict"]
export_weights_bin(card_sd, "weights/card_model.bin")
# Export the CFRNetwork weights
cfr_sd = torch.load("botzone_cfr_net.pt", map_location="cpu", weights_only=False)
if "model_state_dict" in cfr_sd:
    cfr_sd = cfr_sd["model_state_dict"]
export_weights_bin(cfr_sd, "weights/cfr_net.bin")
```
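Before moving to C++, it can help to read the file back in Python and compare it with the original `state_dict`. The sketch below assumes the layout written by `export_weights_bin` on a little-endian machine; `read_weights_bin` is a hypothetical helper that uses only the standard library.

```python
import struct


def read_weights_bin(path):
    """Read the binary layout back into {name: (shape, values)}."""
    tensors = {}
    with open(path, "rb") as f:
        (num_tensors,) = struct.unpack("<I", f.read(4))
        for _ in range(num_tensors):
            (name_len,) = struct.unpack("<I", f.read(4))
            name = f.read(name_len).decode("utf-8")
            (ndim,) = struct.unpack("<I", f.read(4))
            shape = struct.unpack(f"<{ndim}I", f.read(4 * ndim))
            count = 1
            for dim in shape:
                count *= dim
            # float32 values in row-major order
            values = struct.unpack(f"<{count}f", f.read(4 * count))
            tensors[name] = (shape, values)
    return tensors
```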
**Step 2: Read the binary weights in C++**
```cpp
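#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

struct Tensor {
    std::string name;
    std::vector<int> shape;
    std::vector<float> data;
};

// Reads one uint32 as written by the Python exporter (assumes a little-endian host).
static uint32_t read_u32(std::ifstream& f) {
    uint32_t v = 0;
    f.read(reinterpret_cast<char*>(&v), sizeof(v));
    return v;
}

std::unordered_map<std::string, Tensor> load_weights(const std::string& path) {
    std::unordered_map<std::string, Tensor> weights;
    std::ifstream f(path, std::ios::binary);
    if (!f) throw std::runtime_error("cannot open weight file: " + path);
    const uint32_t num_tensors = read_u32(f);
    for (uint32_t i = 0; i < num_tensors; i++) {
        Tensor t;
        // Name length + name
        t.name.resize(read_u32(f));
        f.read(&t.name[0], static_cast<std::streamsize>(t.name.size()));
        // Number of dimensions, then each dimension size
        t.shape.resize(read_u32(f));
        size_t total = 1;
        for (int& d : t.shape) {
            d = static_cast<int>(read_u32(f));
            total *= static_cast<size_t>(d);
        }
        // Raw float32 data
        t.data.resize(total);
        f.read(reinterpret_cast<char*>(t.data.data()),
               static_cast<std::streamsize>(total * sizeof(float)));
        if (!f) throw std::runtime_error("truncated weight file: " + path);
        const std::string key = t.name;  // copy the key before moving t
        weights[key] = std::move(t);
    }
    return weights;
}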
```
### 5.3 C++ Forward Pass Implementation
#### 5.3.1 Basic Operators
```cpp
#include <algorithm>  // std::max, std::max_element
#include <cmath>      // std::exp, std::sqrt

// Linear layer: y = W * x + b (W: [out, in] row-major, x: [in], b: [out])
// x and y must not alias.
void linear(const float* W, const float* b, const float* x,
            float* y, int in_dim, int out_dim) {
    for (int i = 0; i < out_dim; i++) {
        float sum = b[i];
        for (int j = 0; j < in_dim; j++) {
            sum += W[i * in_dim + j] * x[j];
        }
        y[i] = sum;
    }
}

// ReLU (in place)
void relu(float* x, int dim) {
    for (int i = 0; i < dim; i++)
        x[i] = std::max(0.0f, x[i]);
}

// LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta
// Safe to call with y == x: mean and variance are computed before any write.
void layer_norm(const float* gamma, const float* beta,
                const float* x, float* y, int dim, float eps = 1e-5f) {
    float mean = 0.0f;
    for (int i = 0; i < dim; i++) mean += x[i];
    mean /= dim;
    float var = 0.0f;
    for (int i = 0; i < dim; i++) var += (x[i] - mean) * (x[i] - mean);
    var /= dim;  // biased variance
    float inv_std = 1.0f / std::sqrt(var + eps);
    for (int i = 0; i < dim; i++)
        y[i] = gamma[i] * (x[i] - mean) * inv_std + beta[i];
}

// Sigmoid (in place)
void sigmoid(float* x, int dim) {
    for (int i = 0; i < dim; i++)
        x[i] = 1.0f / (1.0f + std::exp(-x[i]));
}

// Softmax (in place, with max subtraction for numerical stability)
void softmax(float* x, int dim) {
    float max_val = *std::max_element(x, x + dim);
    float sum = 0.0f;
    for (int i = 0; i < dim; i++) {
        x[i] = std::exp(x[i] - max_val);
        sum += x[i];
    }
    for (int i = 0; i < dim; i++) x[i] /= sum;
}
```
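When porting these operators, a tiny pure-Python reference makes it easy to compare outputs on hand-picked inputs. The sketch below mirrors the `linear` and `layer_norm` routines above (row-major weights, biased variance); the function names are reused for readability and are not part of the project.

```python
import math


def linear(W, b, x):
    """y = W x + b, with W stored row-major as a flat list of out_dim * in_dim values."""
    in_dim = len(x)
    return [b[i] + sum(W[i * in_dim + j] * x[j] for j in range(in_dim))
            for i in range(len(b))]


def layer_norm(x, gamma, beta, eps=1e-5):
    """Normalize x to zero mean and unit (biased) variance, then scale and shift."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n  # divide by n, not n - 1
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + b for v, g, b in zip(x, gamma, beta)]


y_lin = linear([1.0, 2.0, 3.0, 4.0], [0.5, -0.5], [1.0, 1.0])
y_ln = layer_norm([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)
print(y_lin, [round(v, 4) for v in y_ln])
```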
#### 5.3.2 CardModel Forward Pass
```cpp
#include <cstring>  // std::memset, std::memcpy

class CardModelInference {
public:
    // Pre-allocated buffers
    float hole_emb[64];          // sum of hole-card embeddings
    float board_emb[64];         // sum of community-card embeddings
    float combined[128];         // concat
    float backbone_buf[3][512];  // hidden-layer activations
    float hist_fc1[64];
    float hist_out[50];
    // Weight pointers (obtained from load_weights)
    const float* emb_weight;                                 // [53, 64]
    const float* backbone_w[3]; const float* backbone_b[3];  // Linear weights / biases
    const float* ln_gamma[3];   const float* ln_beta[3];     // LayerNorm parameters
    const float* hist_w1; const float* hist_b1;              // 256→64
    const float* hist_w2; const float* hist_b2;              // 64→50

    void forward(const int* hole_cards,   // [2] card IDs, 0-51
                 const int* board_cards,  // [5] card IDs, 0-51, or 52 = PAD
                 float* histogram) {      // output [50]
        // 1. Embedding lookup + sum
        //    hole_emb = emb_weight[hole_cards[0]] + emb_weight[hole_cards[1]]
        std::memset(hole_emb, 0, sizeof(hole_emb));
        for (int c = 0; c < 2; c++) {
            const float* emb = emb_weight + hole_cards[c] * 64;
            for (int i = 0; i < 64; i++)
                hole_emb[i] += emb[i];
        }
        // board_emb = sum of emb_weight[board_cards[i]]; the PAD (52) embedding is all zeros
        std::memset(board_emb, 0, sizeof(board_emb));
        for (int c = 0; c < 5; c++) {
            if (board_cards[c] == 52) continue;  // PAD, skip
            const float* emb = emb_weight + board_cards[c] * 64;
            for (int i = 0; i < 64; i++)
                board_emb[i] += emb[i];
        }
        // 2. Concat [hole_emb | board_emb]
        std::memcpy(combined, hole_emb, sizeof(hole_emb));
        std::memcpy(combined + 64, board_emb, sizeof(board_emb));
        // 3. Backbone: (Linear → ReLU → LayerNorm) × 3
        int in_dim = 128;
        const int hidden_dims[3] = {512, 512, 256};
        const float* input = combined;
        for (int layer = 0; layer < 3; layer++) {
            linear(backbone_w[layer], backbone_b[layer],
                   input, backbone_buf[layer], in_dim, hidden_dims[layer]);
            relu(backbone_buf[layer], hidden_dims[layer]);
            layer_norm(ln_gamma[layer], ln_beta[layer],
                       backbone_buf[layer], backbone_buf[layer], hidden_dims[layer]);
            in_dim = hidden_dims[layer];
            input = backbone_buf[layer];
        }
        // 4. Histogram head: 256 → 64 (ReLU) → 50 (Softmax)
        linear(hist_w1, hist_b1, backbone_buf[2], hist_fc1, 256, 64);
        relu(hist_fc1, 64);
        linear(hist_w2, hist_b2, hist_fc1, hist_out, 64, 50);
        softmax(hist_out, 50);
        std::memcpy(histogram, hist_out, sizeof(hist_out));
    }
};
```
#### 5.3.3 CFRNetwork Forward Pass (policy_head only)
```cpp
#include <cstring>  // std::memcpy

class CFRNetInference {
public:
    float concat_buf[55];  // card_features[50] + env_features[5]
    float h0[256];         // hidden layer 0
    float h1[256];         // hidden layer 1
    float h2[128];         // hidden layer 2
    float logits[5];       // policy_head output, overwritten with probabilities
    // Weights
    const float* backbone_w[3]; const float* backbone_b[3];  // three Linear layers
    const float* policy_w; const float* policy_b;            // 128→5

    void forward(const float* card_features,  // [50]
                 const float* env_features,   // [5]
                 const int* legal_mask,       // [5], 0 or 1
                 float* out_strategy) {       // output [5]
        // 1. Concat
        std::memcpy(concat_buf, card_features, 50 * sizeof(float));
        std::memcpy(concat_buf + 50, env_features, 5 * sizeof(float));
        // 2. Backbone: (Linear → ReLU) × 3
        // layer 0: 55 → 256
        linear(backbone_w[0], backbone_b[0], concat_buf, h0, 55, 256);
        relu(h0, 256);
        // layer 1: 256 → 256
        linear(backbone_w[1], backbone_b[1], h0, h1, 256, 256);
        relu(h1, 256);
        // layer 2: 256 → 128
        linear(backbone_w[2], backbone_b[2], h1, h2, 256, 128);
        relu(h2, 128);
        // 3. Policy head: 128 → 5
        linear(policy_w, policy_b, h2, logits, 128, 5);
        // 4. Masked softmax
        for (int i = 0; i < 5; i++) {
            if (legal_mask[i] == 0)
                logits[i] = -1e9f;  // large negative value for illegal actions
        }
        softmax(logits, 5);
        std::memcpy(out_strategy, logits, 5 * sizeof(float));
    }
};
```
#### 5.3.4 Complete Inference Flow
```cpp
// === Step 1: Extract information from the game state ===
int hole_cards[2] = { /* hole-card IDs, 0-51 */ };
int board_cards[5] = { /* community-card IDs, padded with 52 when fewer than 5 */ };
float env_features[5] = {
    pot / 20000.0f,
    p0_stack / 20000.0f,
    p1_stack / 20000.0f,
    street / 3.0f,
    (float)position
};
int legal_mask[5] = { /* computed by the BetTranslator logic */ };
// === Step 2: CardModel forward pass ===
float card_features[50];
card_model.forward(hole_cards, board_cards, card_features);
// === Step 3: CFRNetwork forward pass ===
float strategy[5];
cfr_net.forward(card_features, env_features, legal_mask, strategy);
// === Step 4: Sampling (optionally denoise first) ===
// Denoise: zero out probabilities below 3%, then renormalize
float threshold = 0.03f;
float sum = 0.0f;
for (int i = 0; i < 5; i++) {
    if (legal_mask[i] && strategy[i] < threshold)
        strategy[i] = 0.0f;
    sum += strategy[i];
}
if (sum > 0) for (int i = 0; i < 5; i++) strategy[i] /= sum;
// Sample among the legal actions according to their probabilities
int chosen = sample_from_distribution(strategy, legal_mask);
// === Step 5: CFR action → engine action (BetTranslator) ===
int engine_action = cfr_to_engine(state, chosen);
```
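The flow above calls `sample_from_distribution` without defining it. One possible implementation, sketched in Python for reference, walks the cumulative distribution; the `rng` parameter and the round-off fallback are assumptions added for testability, not the project's actual code.

```python
import random


def sample_from_distribution(strategy, legal_mask, rng=random):
    """Sample an action index proportionally to `strategy`, ignoring illegal actions."""
    weights = [p if m else 0.0 for p, m in zip(strategy, legal_mask)]
    total = sum(weights)
    if total <= 0.0:
        raise ValueError("no legal action has positive probability")
    r = rng.random() * total
    cumulative = 0.0
    last_positive = None
    for idx, w in enumerate(weights):
        if w <= 0.0:
            continue
        cumulative += w
        last_positive = idx
        if r < cumulative:
            return idx
    return last_positive  # guards against floating-point round-off


choice = sample_from_distribution([0.0, 0.6, 0.4, 0.0, 0.0], [1, 1, 1, 0, 0])
print(choice)
```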
### 5.4 Weight Loading Map
When loading weights in C++, map each PyTorch `state_dict` key to its layer:
#### CardModel weight keys → purpose
| state_dict key | Shape | Purpose |
|---|---|---|
| `embedding.weight` | [53, 64] | Embedding lookup table |
| `backbone.0.weight` | [512, 128] | Linear 128→512 |
| `backbone.0.bias` | [512] | |
| `backbone.2.weight` | [512] | LayerNorm gamma |
| `backbone.2.bias` | [512] | LayerNorm beta |
| `backbone.3.weight` | [512, 512] | Linear 512→512 |
| `backbone.3.bias` | [512] | |
| `backbone.5.weight` | [512] | LayerNorm gamma |
| `backbone.5.bias` | [512] | LayerNorm beta |
| `backbone.6.weight` | [256, 512] | Linear 512→256 |
| `backbone.6.bias` | [256] | |
| `backbone.8.weight` | [256] | LayerNorm gamma |
| `backbone.8.bias` | [256] | LayerNorm beta |
| `histogram_head.0.weight` | [64, 256] | Linear 256→64 |
| `histogram_head.0.bias` | [64] | |
| `histogram_head.2.weight` | [50, 64] | Linear 64→50 |
| `histogram_head.2.bias` | [50] | |
| `equity_head.*` | - | Not needed for online inference |
> **Important**: CardModel's `backbone` is an `nn.Sequential`; the layer indices map as follows:
> - `backbone.0` = Linear(128, 512)
> - `backbone.1` = ReLU (no parameters)
> - `backbone.2` = LayerNorm(512)
> - `backbone.3` = Linear(512, 512)
> - `backbone.4` = ReLU (no parameters)
> - `backbone.5` = LayerNorm(512)
> - `backbone.6` = Linear(512, 256)
> - `backbone.7` = ReLU (no parameters)
> - `backbone.8` = LayerNorm(256)
#### CFRNetwork weight keys → purpose
| state_dict key | Shape | Purpose |
|---|---|---|
| `backbone.0.weight` | [256, 55] | Linear 55→256 |
| `backbone.0.bias` | [256] | |
| `backbone.2.weight` | [256, 256] | Linear 256→256 |
| `backbone.2.bias` | [256] | |
| `backbone.4.weight` | [128, 256] | Linear 256→128 |
| `backbone.4.bias` | [128] | |
| `regret_head.weight` | [5, 128] | Not needed for online inference |
| `regret_head.bias` | [5] | Not needed for online inference |
| `policy_head.weight` | [5, 128] | Linear 128→5 |
| `policy_head.bias` | [5] | |
> **Important**: CFRNetwork's `backbone` indices:
> - `backbone.0` = Linear(55, 256)
> - `backbone.1` = ReLU (no parameters)
> - `backbone.2` = Linear(256, 256)
> - `backbone.3` = ReLU (no parameters)
> - `backbone.4` = Linear(256, 128)
> - `backbone.5` = ReLU (no parameters)
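
A loader can catch key or shape mismatches early by comparing what it read against the table. The sketch below does this for the CFRNetwork keys needed at inference time; `CFR_INFERENCE_SHAPES` is transcribed from the table, and `check_shapes` is a hypothetical helper.

```python
# Keys required for online inference, transcribed from the CFRNetwork table.
CFR_INFERENCE_SHAPES = {
    "backbone.0.weight": (256, 55),
    "backbone.0.bias": (256,),
    "backbone.2.weight": (256, 256),
    "backbone.2.bias": (256,),
    "backbone.4.weight": (128, 256),
    "backbone.4.bias": (128,),
    "policy_head.weight": (5, 128),
    "policy_head.bias": (5,),
}


def check_shapes(loaded_shapes, expected):
    """Return a list of human-readable problems; an empty list means all good."""
    problems = []
    for key, shape in expected.items():
        if key not in loaded_shapes:
            problems.append(f"missing key: {key}")
        elif tuple(loaded_shapes[key]) != shape:
            problems.append(
                f"{key}: expected {shape}, got {tuple(loaded_shapes[key])}")
    return problems


# Example: a loaded file in which one weight has an unexpected shape.
loaded = dict(CFR_INFERENCE_SHAPES)
loaded["policy_head.weight"] = (5, 256)
print(check_shapes(loaded, CFR_INFERENCE_SHAPES))
```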
---
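## 6. Optimization Options
### 6.1 Is CardModel needed?
**Yes**. CardModel is a required part of the inference pipeline: it encodes discrete card information into a 50-dimensional continuous vector that CFRNetwork depends on. Removing it would require designing a different card encoding, and the trained CFRNetwork weights would no longer be usable.
### 6.2 Is the Regret Head needed?
**Not for online inference**. The Regret Head drives Regret Matching during training; online inference uses only Policy Head + Softmax to obtain `avg_strategy`. You can skip loading the `regret_head` weights to save memory.
### 6.3 Is the Equity Head needed?
**Not for online inference**. The Equity Head only outputs a scalar equity for monitoring; inference needs only the 50-dimensional output of the Histogram Head.
### 6.4 Alternative: LibTorch
If you would rather not hand-write the forward pass, LibTorch (the PyTorch C++ API) can load the `.pt` weights and run inference directly, typically after the model has been exported to TorchScript. The upside is less code; the downside is a large dependency (~200MB+).
### 6.5 Alternative: ONNX Runtime
Both models can be exported to ONNX and run through the ONNX Runtime C++ API, which balances ease of use and performance.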
```python
import torch

# CardModel and CFRNetwork are the project's own model classes; import them
# from wherever the training code defines them before running this script.


def load_weights_into(model, path):
    """Load a checkpoint, unwrapping 'model_state_dict' if present (as in section 5.2)."""
    sd = torch.load(path, map_location="cpu", weights_only=False)
    if "model_state_dict" in sd:
        sd = sd["model_state_dict"]
    model.load_state_dict(sd)
    return model.eval()


# Export CardModel to ONNX
card_model = load_weights_into(CardModel(), "card_model/data/best_card_model.pt")
dummy_hole = torch.randint(0, 52, (1, 2))
dummy_board = torch.randint(0, 52, (1, 5))
torch.onnx.export(card_model, (dummy_hole, dummy_board), "card_model.onnx",
                  input_names=["x_hole", "x_board"],
                  output_names=["pred_equity", "pred_histogram"])
# Export CFRNetwork to ONNX
cfr_net = load_weights_into(CFRNetwork(), "botzone_cfr_net.pt")
dummy_card = torch.randn(1, 50)
dummy_env = torch.randn(1, 5)
torch.onnx.export(cfr_net, (dummy_card, dummy_env), "cfr_net.onnx",
                  input_names=["card_features", "env_features"],
                  output_names=["regrets", "policy_logits"])
```
---
## 7. Data Flow Summary
```
┌─────────────────────────────────────────────────────────────────┐
│                   Online inference data flow                    │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Game state                                                     │
│  ├── hole_cards[2]   (int, 0-51)                                │
│  ├── board_cards[5]  (int, 0-51 or 52=PAD)                      │
│  ├── pot, stacks, street, position                              │
│  └── legal_actions   (engine-native)                            │
│                       │                                         │
│                       ▼                                         │
│  ┌─────────────────┐      ┌──────────────────┐                  │
│  │ CardModel       │      │ Feature Builder  │                  │
│  │ Embedding(53,64)│      │ pot/20000        │                  │
│  │ sum → cat       │      │ p0_stack/20000   │                  │
│  │ MLP 128→512→    │      │ p1_stack/20000   │                  │
│  │ 512→256         │      │ street/3.0       │                  │
│  │ +LayerNorm      │      │ position (0|1)   │                  │
│  │ HistHead 256→   │      └────────┬─────────┘                  │
│  │ 64→50+Softmax   │               │                            │
│  └────────┬────────┘               │                            │
│           │                        │                            │
│   card_features[50]         env_features[5]                     │
│           │                        │                            │
│           └───────────┬────────────┘                            │
│                       ▼                                         │
│            ┌─────────────────────┐                              │
│            │ CFRNetwork          │                              │
│            │ cat → [55]          │                              │
│            │ MLP 55→256→256→128  │                              │
│            │ PolicyHead 128→5    │                              │
│            └──────────┬──────────┘                              │
│                       │                                         │
│               policy_logits[5]                                  │
│                       │                                         │
│          masked_fill(illegal → -1e9)                            │
│                       │                                         │
│           Softmax → avg_strategy[5]                             │
│                       │                                         │
│        Denoise (<3% → 0, renormalize)                           │
│                       │                                         │
│    Sample by probability → chosen_cfr_idx                       │
│                       │                                         │
│         BetTranslator → engine_action                           │
│                       │                                         │
│             Execute engine action                               │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘
```