# Neural Network Architecture & C++ Online Inference Guide

## 1. System Overview

The inference pipeline chains **two neural networks**:

```
Game state (OpenSpiel / Botzone)
    │
    ├─── extract cards ──→ CardModel ──→ pred_histogram [50] ──┐
    │                                                          ├─ concat → [55]
    └─── extract situation ──→ env_features [5] ───────────────┘
                                                                    │
                                                               CFRNetwork
                                                                    │
                                                       ┌────────────┴────────────┐
                                                       │                         │
                                                regret_head [5]           policy_head [5]
                                                       │                         │
                                                Regret Matching           Softmax(masked)
                                                       │                         │
                                             current_strategy [5]        avg_strategy [5]
```

**Online inference only needs `avg_strategy`**: the Softmax-normalized action probability distribution. Sampling from it yields the final action.

---

## 2. CardModel Architecture

### 2.1 Purpose

Encodes a poker hand (2 hole cards + 0 to 5 community cards) into a **50-dimensional equity histogram**, which is one of the inputs to CFRNetwork.

### 2.2 Network Structure

```
Inputs:
  x_hole:  [batch, 2]  int64   2 hole-card IDs (0-51)
  x_board: [batch, 5]  int64   5 community-card IDs (0-51, padded with 52 when fewer than 5)

Embedding: nn.Embedding(53, 64, padding_idx=52)
  - 53 tokens: 52 cards + 1 PAD (52)
  - padding_idx=52 keeps the PAD token's embedding fixed at zero

Encoding:
  hole_emb  = Embedding(x_hole).sum(dim=1)   → [batch, 64]
  board_emb = Embedding(x_board).sum(dim=1)  → [batch, 64]
  combined  = cat([hole_emb, board_emb])     → [batch, 128]

Backbone MLP (each layer: Linear → ReLU → LayerNorm):
  128 → 512 → ReLU → LayerNorm(512)
  512 → 512 → ReLU → LayerNorm(512)
  512 → 256 → ReLU → LayerNorm(256)

Equity Head (not needed for online inference):
  256 → 32 → ReLU → 1 → Sigmoid   → [batch, 1]   scalar equity

Histogram Head (the only head online inference needs):
  256 → 64 → ReLU → 50 → Softmax  → [batch, 50]  equity histogram

Outputs:
  pred_equity:    [batch, 1]   scalar equity (0 to 1)
  pred_histogram: [batch, 50]  equity histogram (sums to 1)  ← card_features for CFRNetwork
```

### 2.3 Weight Parameters

| Layer (state_dict key) | Shape | Parameters |
|---|---|---|
| `embedding.weight` | [53, 64] | 3,392 |
| `backbone.0.weight` (Linear 128→512) | [512, 128] | 65,536 |
| `backbone.0.bias` | [512] | 512 |
| `backbone.2.weight` (LayerNorm 512) | [512] | 512 |
| `backbone.2.bias` (LayerNorm 512) | [512] | 512 |
| `backbone.3.weight` (Linear 512→512) | [512, 512] | 262,144 |
| `backbone.3.bias` | [512] | 512 |
| `backbone.5.weight` (LayerNorm 512) | [512] | 512 |
| `backbone.5.bias` (LayerNorm 512) | [512] | 512 |
| `backbone.6.weight` (Linear 512→256) | [256, 512] | 131,072 |
| `backbone.6.bias` | [256] | 256 |
| `backbone.8.weight` (LayerNorm 256) | [256] | 256 |
| `backbone.8.bias` (LayerNorm 256) | [256] | 256 |
| `equity_head.0.weight` (Linear 256→32) | [32, 256] | 8,192 |
| `equity_head.0.bias` | [32] | 32 |
| `equity_head.2.weight` (Linear 32→1) | [1, 32] | 32 |
| `equity_head.2.bias` | [1] | 1 |
| `histogram_head.0.weight` (Linear 256→64) | [64, 256] | 16,384 |
| `histogram_head.0.bias` | [64] | 64 |
| `histogram_head.2.weight` (Linear 64→50) | [50, 64] | 3,200 |
| `histogram_head.2.bias` | [50] | 50 |
| **Total** | | **493,939** |

### 2.4 Card ID Encoding

```
card_id = rank * 4 + suit

rank: 0=2, 1=3, 2=4, 3=5, 4=6, 5=7, 6=8, 7=9, 8=T, 9=J, 10=Q, 11=K, 12=A
suit: 0=c (clubs), 1=d (diamonds), 2=h (hearts), 3=s (spades)

Examples: Ac = 12*4+0 = 48, Ks = 11*4+3 = 47, 2c = 0*4+0 = 0
PAD_TOKEN = 52 (its embedding is always the zero vector)
```

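
The encoding rule above can be wrapped in a few helper functions. The following is a minimal sketch; the names `encode_card`, `parse_card`, and `card_to_string` are illustrative and do not come from the original codebase.

```cpp
#include <cassert>
#include <string>

// Rank and suit characters in the order given by the encoding rule.
constexpr char kRanks[] = "23456789TJQKA";
constexpr char kSuits[] = "cdhs";
constexpr int kPadToken = 52;

// Encode rank (0-12) and suit (0-3) into a card ID in [0, 51].
int encode_card(int rank, int suit) {
    assert(rank >= 0 && rank < 13 && suit >= 0 && suit < 4);
    return rank * 4 + suit;
}

// Parse a two-character card such as "Ac" or "Ks"; returns -1 on bad input.
int parse_card(const std::string& s) {
    if (s.size() != 2) return -1;
    const std::string ranks(kRanks), suits(kSuits);
    const auto r = ranks.find(s[0]);
    const auto u = suits.find(s[1]);
    if (r == std::string::npos || u == std::string::npos) return -1;
    return encode_card(static_cast<int>(r), static_cast<int>(u));
}

// Decode a card ID back to text; the PAD token is rendered as "--".
std::string card_to_string(int card_id) {
    if (card_id == kPadToken) return "--";
    assert(card_id >= 0 && card_id < 52);
    return std::string{kRanks[card_id / 4], kSuits[card_id % 4]};
}
```

Round-tripping a few cards through `parse_card` and `card_to_string` is a cheap way to confirm that the C++ side uses the same numbering as the training data.
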
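---

## 3. CFRNetwork Architecture

### 3.1 Purpose

Takes card features plus situation features and outputs regret values and policy logits for **5 actions**. Regret Matching turns the regrets into `current_strategy`, and a masked Softmax turns the logits into `avg_strategy`.

### 3.2 Network Structure
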
```
Inputs:
  card_features: [batch, 50]  equity histogram produced by CardModel
  env_features:  [batch, 5]   normalized situation features

Concatenation:
  x = cat([card_features, env_features])  → [batch, 55]

Backbone MLP (each layer: Linear → ReLU):
  55  → 256 → ReLU
  256 → 256 → ReLU
  256 → 128 → ReLU

Regret Head (can be skipped for online inference):
  128 → 5  (no activation; regrets may be negative)

Policy Head:
  128 → 5  (outputs logits, followed by Softmax)

Outputs:
  regrets:       [batch, 5]  raw regret values
  policy_logits: [batch, 5]  policy logits
```

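
Online inference does not use the regret head, but for completeness here is a sketch of how Regret Matching conventionally turns `regrets` into `current_strategy`. It follows the textbook rule (normalize the positive regrets). The uniform fallback when no legal action has positive regret is an assumption, and the training code may handle that case differently.

```cpp
#include <array>
#include <cassert>

constexpr int kNumActions = 5;

// Regret matching: play each legal action in proportion to its positive regret.
// Fallback (assumed): uniform over legal actions when no regret is positive.
std::array<float, kNumActions> regret_matching(
        const std::array<float, kNumActions>& regrets,
        const std::array<int, kNumActions>& legal_mask) {
    std::array<float, kNumActions> strategy{};  // zero-initialized
    float positive_sum = 0.0f;
    int num_legal = 0;
    for (int i = 0; i < kNumActions; ++i) {
        if (!legal_mask[i]) continue;
        ++num_legal;
        if (regrets[i] > 0.0f) {
            strategy[i] = regrets[i];
            positive_sum += regrets[i];
        }
    }
    assert(num_legal > 0);  // at least one action must be legal
    for (int i = 0; i < kNumActions; ++i) {
        if (!legal_mask[i]) continue;
        strategy[i] = positive_sum > 0.0f ? strategy[i] / positive_sum
                                          : 1.0f / num_legal;
    }
    return strategy;
}
```
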
### 3.3 Weight Parameters

| Layer (state_dict key) | Shape | Parameters |
|---|---|---|
| `backbone.0.weight` (Linear 55→256) | [256, 55] | 14,080 |
| `backbone.0.bias` | [256] | 256 |
| `backbone.2.weight` (Linear 256→256) | [256, 256] | 65,536 |
| `backbone.2.bias` | [256] | 256 |
| `backbone.4.weight` (Linear 256→128) | [128, 256] | 32,768 |
| `backbone.4.bias` | [128] | 128 |
| `regret_head.weight` (Linear 128→5) | [5, 128] | 640 |
| `regret_head.bias` | [5] | 5 |
| `policy_head.weight` (Linear 128→5) | [5, 128] | 640 |
| `policy_head.bias` | [5] | 5 |
| **Total** | | **114,314** |

> Note: `backbone` is an `nn.Sequential`; indices 0, 2, 4 are Linear layers and indices 1, 3, 5 are ReLU (no parameters).

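
The parameter counts in the two weight tables follow directly from the layer widths, so they can be cross-checked in code. The sketch below recomputes them from the architecture described in sections 2.2 and 3.2; the function names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Parameters of a Linear(in, out) layer: weight [out, in] plus bias [out].
constexpr std::size_t linear_params(std::size_t in, std::size_t out) {
    return in * out + out;
}

// Parameters of a LayerNorm(dim) layer: gamma [dim] plus beta [dim].
constexpr std::size_t layer_norm_params(std::size_t dim) { return 2 * dim; }

// Total for an MLP given its layer widths, e.g. {55, 256, 256, 128}.
std::size_t mlp_params(const std::vector<std::size_t>& widths, bool with_layer_norm) {
    assert(widths.size() >= 2);
    std::size_t total = 0;
    for (std::size_t i = 0; i + 1 < widths.size(); ++i) {
        total += linear_params(widths[i], widths[i + 1]);
        if (with_layer_norm) total += layer_norm_params(widths[i + 1]);
    }
    return total;
}

// CFRNetwork: backbone plus the two 128 -> 5 heads.
std::size_t cfr_network_params() {
    return mlp_params({55, 256, 256, 128}, false) + 2 * linear_params(128, 5);
}

// CardModel: embedding table, LayerNorm backbone, equity head, histogram head.
std::size_t card_model_params() {
    return 53 * 64 + mlp_params({128, 512, 512, 256}, true)
         + mlp_params({256, 32, 1}, false) + mlp_params({256, 64, 50}, false);
}
```

Comparing these values with the number of floats actually read from the exported weight files is a quick sanity check that no tensor was dropped or duplicated during export.
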
### 3.4 Action Space

| Index | Name | Meaning |
|---|---|---|
| 0 | FOLD | Fold |
| 1 | CALL | Call / check |
| 2 | HALF_POT | Raise by 1/2 of the pot after calling |
| 3 | FULL_POT | Raise by 1.0× the pot after calling |
| 4 | ALL_IN | All-in |

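
As an illustration of the two pot-fraction actions, the sketch below computes the raise size from the pot after calling. The inputs `pot` and `to_call`, the integer rounding, and the function name are assumptions; the exact bet sizing used by the project's BetTranslator may differ.

```cpp
#include <cassert>

// CFR action indices from the action table.
enum CfrAction { FOLD = 0, CALL = 1, HALF_POT = 2, FULL_POT = 3, ALL_IN = 4 };

// Raise size (chips added on top of the call) for a pot-fraction action.
// `pot` is the current pot and `to_call` the amount needed to call; both are
// hypothetical inputs whose exact definition depends on the engine.
int pot_fraction_raise(CfrAction action, int pot, int to_call) {
    assert(action == HALF_POT || action == FULL_POT);
    assert(pot >= 0 && to_call >= 0);
    const int pot_after_call = pot + to_call;
    // Integer division rounds half-pot raises down (an assumed convention).
    return action == HALF_POT ? pot_after_call / 2 : pot_after_call;
}
```
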
---

## 4. Inference Pipeline in Detail

### 4.1 Building the Input Features

#### 4.1.1 Building env_features [5]

```
env_features = [
    pot / 20000.0,        // pot, normalized
    p0_stack / 20000.0,   // player 0 remaining stack, normalized
    p1_stack / 20000.0,   // player 1 remaining stack, normalized
    street / 3.0,         // betting round, normalized (0=Preflop, 1=Flop, 2=Turn, 3=River)
    position,             // player to act (0.0 or 1.0)
]
```

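
A small builder function keeps the normalization constants in one place. This is a sketch: the `PublicState` struct is a hypothetical container for values that would come from the game engine.

```cpp
#include <array>
#include <cassert>

constexpr float kChipScale = 20000.0f;  // chip normalization constant from the text
constexpr float kMaxStreet = 3.0f;      // River

// Hypothetical snapshot of the public game state.
struct PublicState {
    int pot;
    int p0_stack;
    int p1_stack;
    int street;    // 0=Preflop, 1=Flop, 2=Turn, 3=River
    int position;  // player to act: 0 or 1
};

// Build the 5 normalized situation features in the order CFRNetwork expects.
std::array<float, 5> build_env_features(const PublicState& s) {
    assert(s.street >= 0 && s.street <= 3);
    assert(s.position == 0 || s.position == 1);
    std::array<float, 5> f{};
    f[0] = s.pot / kChipScale;
    f[1] = s.p0_stack / kChipScale;
    f[2] = s.p1_stack / kChipScale;
    f[3] = s.street / kMaxStreet;
    f[4] = static_cast<float>(s.position);
    return f;
}
```

The feature order matters: the first Linear layer of CFRNetwork was trained with exactly this layout appended after the 50 histogram values.
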
#### 4.1.2 Building card_features [50]

```
1. Extract the acting player's 2 hole-card IDs → hole_cards[2]
2. Extract the community-card IDs (0 to 5 cards) → board_cards[0..5]
3. If there are fewer than 5 community cards, pad with PAD_TOKEN=52 → x_board[5]
4. Feed them to CardModel:
   - hole_emb  = Embedding(hole_cards).sum(over the card rows)  → [64]
   - board_emb = Embedding(x_board).sum(over the card rows)     → [64]
   - combined  = cat([hole_emb, board_emb])                     → [128]
   - features  = Backbone(combined)                             → [256]
   - histogram = HistogramHead(features)                        → [50]
5. card_features = histogram (the 50-dimensional equity histogram)
```

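
Step 3 (padding the board to a fixed length) is easy to get wrong on early streets. A minimal sketch, assuming the community cards arrive as a `std::vector<int>`; the function name `pad_board` is illustrative:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kPadToken = 52;

// Pad 0-5 community-card IDs to a fixed-length array of 5 IDs.
std::array<int, 5> pad_board(const std::vector<int>& board_cards) {
    assert(board_cards.size() <= 5);
    std::array<int, 5> x_board;
    x_board.fill(kPadToken);  // unused slots carry the PAD token
    for (std::size_t i = 0; i < board_cards.size(); ++i) {
        assert(board_cards[i] >= 0 && board_cards[i] < 52);
        x_board[i] = board_cards[i];
    }
    return x_board;
}
```

Because the PAD embedding is the zero vector and the board embeddings are summed, the position of the padding within the array does not affect the result.
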
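#### 4.1.3 Building legal_mask [5]

```
legal_mask = [fold_ok, call_ok, half_pot_ok, full_pot_ok, allin_ok]
Each element is 0 or 1 and marks whether that CFR action is legal.

Rules:
  FOLD(0):    the engine's legal actions include action 0
  CALL(1):    the engine's legal actions include action 1
  ALL_IN(4):  the engine's legal actions include a raise (an action ID > 1)
  HALF_POT(2) / FULL_POT(3): all of the following must hold:
    - a raise action exists
    - the number of raises on the current street is < 2 (RAISE_CAP)
    - the computed target contribution is >= the minimum raise
    - the target contribution, once mapped to an engine action, is not the ALL-IN action
```
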
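
The rules above might be coded roughly as follows. This is only a sketch under stated assumptions: `EngineView` and all of its fields are hypothetical stand-ins for whatever the engine exposes, and "does not map to the ALL-IN action" is approximated as "target contribution is below the all-in contribution".

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

constexpr int kRaiseCap = 2;  // RAISE_CAP from the rules

// Hypothetical view of the engine state needed to build the mask.
struct EngineView {
    std::vector<int> legal_actions;  // engine action IDs: 0 = fold, 1 = call, >1 = raise
    int raises_this_street;
    int min_raise_total;             // smallest legal total contribution for a raise
    int all_in_total;                // total contribution when going all-in
    int half_pot_total;              // target total contribution for HALF_POT
    int full_pot_total;              // target total contribution for FULL_POT
};

std::array<int, 5> build_legal_mask(const EngineView& e) {
    assert(e.raises_this_street >= 0);
    const auto& acts = e.legal_actions;
    const auto has = [&acts](int id) {
        return std::find(acts.begin(), acts.end(), id) != acts.end();
    };
    const bool can_raise =
        std::any_of(acts.begin(), acts.end(), [](int id) { return id > 1; });
    // A sized raise needs a raise action, room under the cap, a target that
    // meets the minimum raise, and a target that is not effectively all-in.
    const auto sized_raise_ok = [&](int target_total) {
        return can_raise && e.raises_this_street < kRaiseCap &&
               target_total >= e.min_raise_total &&
               target_total < e.all_in_total;
    };
    std::array<int, 5> mask{};
    mask[0] = has(0) ? 1 : 0;                             // FOLD
    mask[1] = has(1) ? 1 : 0;                             // CALL
    mask[2] = sized_raise_ok(e.half_pot_total) ? 1 : 0;   // HALF_POT
    mask[3] = sized_raise_ok(e.full_pot_total) ? 1 : 0;   // FULL_POT
    mask[4] = can_raise ? 1 : 0;                          // ALL_IN
    return mask;
}
```
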
### 4.2 From policy_logits to Action Probabilities (avg_strategy)

This is the core of online inference, and it uses only the output of `policy_head`:

```
1. logits = policy_head(backbone_output)     → [5]
2. masked_logits = logits
   For every i with legal_mask[i] == 0, set masked_logits[i] = -1e9
3. avg_strategy = softmax(masked_logits)     → [5]
   Every legal action gets probability > 0; illegal actions get ≈ 0
4. Sample among the legal actions of avg_strategy
```

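### 4.3 Denoising (Optional but Recommended)

A practice carried over from the training code: set every action whose probability is below 3% to zero and renormalize, so that low-level network noise cannot trigger a stray all-in.

---
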
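## 5. C++ Implementation Guide

### 5.1 Recommended Approach: Hand-Written Forward Pass + Loaded Weights

Because the networks are simple (pure MLPs, no convolution or attention), **LibTorch is not required**. Eigen or a hand-written matrix-vector product is enough, which keeps the deployed binary small and inference fast.

### 5.2 Weight File Format

A PyTorch `.pt` file is essentially a dict serialized with Python pickle, which is awkward to read directly from C++. A two-step conversion is recommended:

**Step 1: Export to binary from Python**
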
```python
import os
import struct

import torch


def export_weights_bin(state_dict, output_path):
    """Export a state_dict to a flat binary file that C++ can read directly.

    All integers are written as little-endian uint32; tensor data is float32.
    """
    with open(output_path, "wb") as f:
        # Number of tensors
        f.write(struct.pack("<I", len(state_dict)))

        for name, tensor in state_dict.items():
            # Name length + UTF-8 name
            name_bytes = name.encode("utf-8")
            f.write(struct.pack("<I", len(name_bytes)))
            f.write(name_bytes)

            # Number of dimensions
            shape = tensor.shape
            f.write(struct.pack("<I", len(shape)))

            # Size of each dimension
            for dim in shape:
                f.write(struct.pack("<I", dim))

            # Tensor data (float32, little-endian, row-major)
            data = tensor.detach().cpu().float().numpy().astype("<f4").tobytes()
            f.write(data)


os.makedirs("weights", exist_ok=True)

# Export CardModel weights
card_sd = torch.load("card_model/data/best_card_model.pt", map_location="cpu", weights_only=False)
if "model_state_dict" in card_sd:  # unwrap full checkpoints
    card_sd = card_sd["model_state_dict"]
export_weights_bin(card_sd, "weights/card_model.bin")

# Export CFRNetwork weights
cfr_sd = torch.load("botzone_cfr_net.pt", map_location="cpu", weights_only=False)
if "model_state_dict" in cfr_sd:  # unwrap full checkpoints
    cfr_sd = cfr_sd["model_state_dict"]
export_weights_bin(cfr_sd, "weights/cfr_net.bin")
```

**Step 2: Read the binary weights in C++**

```cpp
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

struct Tensor {
    std::string name;
    std::vector<int> shape;
    std::vector<float> data;
};

// Reads the binary format written by export_weights_bin.
// Assumes the host has the same byte order as the exporting machine
// (little-endian on x86 and ARM).
std::unordered_map<std::string, Tensor> load_weights(const std::string& path) {
    std::unordered_map<std::string, Tensor> weights;
    std::ifstream f(path, std::ios::binary);
    if (!f) throw std::runtime_error("cannot open weight file: " + path);

    uint32_t num_tensors = 0;
    f.read(reinterpret_cast<char*>(&num_tensors), 4);

    for (uint32_t i = 0; i < num_tensors; i++) {
        Tensor t;
        uint32_t name_len = 0;
        f.read(reinterpret_cast<char*>(&name_len), 4);
        t.name.resize(name_len);
        f.read(&t.name[0], name_len);

        uint32_t ndim = 0;
        f.read(reinterpret_cast<char*>(&ndim), 4);
        t.shape.resize(ndim);
        for (uint32_t d = 0; d < ndim; d++) {
            uint32_t dim = 0;
            f.read(reinterpret_cast<char*>(&dim), 4);
            t.shape[d] = static_cast<int>(dim);
        }

        std::size_t total = 1;
        for (int d : t.shape) total *= static_cast<std::size_t>(d);
        t.data.resize(total);
        f.read(reinterpret_cast<char*>(t.data.data()), total * sizeof(float));
        if (!f) throw std::runtime_error("truncated weight file: " + path);

        std::string key = t.name;
        weights.emplace(std::move(key), std::move(t));
    }
    return weights;
}
```

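
Once the tensors are loaded, each weight pointer still has to be fetched by key. A lookup helper that also checks the shape catches export mismatches early. The sketch below repeats the `Tensor` struct so that it compiles on its own; `require_tensor` is an illustrative name, not part of the original code.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

struct Tensor {  // same layout as the Tensor struct used by the loader
    std::string name;
    std::vector<int> shape;
    std::vector<float> data;
};

// Look up a tensor by its state_dict key and verify its shape.
// Throws std::runtime_error if the key is missing or the shape differs.
const float* require_tensor(const std::unordered_map<std::string, Tensor>& weights,
                            const std::string& key,
                            const std::vector<int>& expected_shape) {
    assert(!key.empty());
    const auto it = weights.find(key);
    if (it == weights.end())
        throw std::runtime_error("missing tensor: " + key);
    if (it->second.shape != expected_shape)
        throw std::runtime_error("unexpected shape for tensor: " + key);
    return it->second.data.data();
}
```

For example, `require_tensor(weights, "policy_head.weight", {5, 128})` returns the pointer to pass as the weight matrix of the policy head. The returned pointer stays valid only as long as the map that owns the tensor is alive and unmodified.
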
### 5.3 C++ Forward Pass Implementation

#### 5.3.1 Basic Operators

```cpp
#include <algorithm>
#include <cmath>

// Matrix-vector product: y = W * x + b  (W: [out, in] row-major, x: [in], b: [out])
void linear(const float* W, const float* b, const float* x,
            float* y, int in_dim, int out_dim) {
    for (int i = 0; i < out_dim; i++) {
        float sum = b[i];
        for (int j = 0; j < in_dim; j++) {
            sum += W[i * in_dim + j] * x[j];
        }
        y[i] = sum;
    }
}

// ReLU (in place)
void relu(float* x, int dim) {
    for (int i = 0; i < dim; i++)
        x[i] = std::max(0.0f, x[i]);
}

// LayerNorm: y = (x - mean) / sqrt(var + eps) * gamma + beta
// Uses the biased variance (divide by dim), as PyTorch does. x and y may alias.
void layer_norm(const float* gamma, const float* beta,
                const float* x, float* y, int dim, float eps = 1e-5f) {
    float mean = 0.0f;
    for (int i = 0; i < dim; i++) mean += x[i];
    mean /= dim;

    float var = 0.0f;
    for (int i = 0; i < dim; i++) var += (x[i] - mean) * (x[i] - mean);
    var /= dim;

    float inv_std = 1.0f / std::sqrt(var + eps);
    for (int i = 0; i < dim; i++)
        y[i] = gamma[i] * (x[i] - mean) * inv_std + beta[i];
}

// Sigmoid (in place)
void sigmoid(float* x, int dim) {
    for (int i = 0; i < dim; i++)
        x[i] = 1.0f / (1.0f + std::exp(-x[i]));
}

// Softmax (in place); subtracts the maximum for numerical stability
void softmax(float* x, int dim) {
    float max_val = *std::max_element(x, x + dim);
    float sum = 0.0f;
    for (int i = 0; i < dim; i++) {
        x[i] = std::exp(x[i] - max_val);
        sum += x[i];
    }
    for (int i = 0; i < dim; i++) x[i] /= sum;
}
```

#### 5.3.2 CardModel Forward Pass

```cpp
#include <cstring>

// Uses linear, relu, layer_norm and softmax from section 5.3.1.
class CardModelInference {
public:
    // Preallocated buffers
    float hole_emb[64];          // embedding sum
    float board_emb[64];         // embedding sum
    float combined[128];         // concat
    float backbone_buf[3][512];  // one buffer per hidden layer
    float hist_fc1[64];
    float hist_out[50];

    // Weight pointers (obtained from load_weights)
    const float* emb_weight;                                 // [53, 64]
    const float* backbone_w[3]; const float* backbone_b[3];  // Linear weights / biases
    const float* ln_gamma[3];   const float* ln_beta[3];     // LayerNorm parameters
    const float* hist_w1; const float* hist_b1;              // 256→64
    const float* hist_w2; const float* hist_b2;              // 64→50

    void forward(const int* hole_cards,    // [2] int, 0-51
                 const int* board_cards,   // [5] int, 0-51 (padded with 52)
                 float* histogram) {       // output [50]

        // 1. Embedding lookup + sum
        // hole_emb = emb_weight[hole_cards[0]] + emb_weight[hole_cards[1]]
        memset(hole_emb, 0, 64 * sizeof(float));
        for (int c = 0; c < 2; c++) {
            const float* emb = emb_weight + hole_cards[c] * 64;
            for (int i = 0; i < 64; i++)
                hole_emb[i] += emb[i];
        }
        // board_emb = sum of emb_weight[board_cards[i]]; the PAD(52) embedding is all zeros
        memset(board_emb, 0, 64 * sizeof(float));
        for (int c = 0; c < 5; c++) {
            if (board_cards[c] == 52) continue;  // PAD, skip
            const float* emb = emb_weight + board_cards[c] * 64;
            for (int i = 0; i < 64; i++)
                board_emb[i] += emb[i];
        }

        // 2. Concat [hole_emb | board_emb]
        memcpy(combined, hole_emb, 64 * sizeof(float));
        memcpy(combined + 64, board_emb, 64 * sizeof(float));

        // 3. Backbone: (Linear → ReLU → LayerNorm) × 3
        int in_dim = 128;
        int hidden_dims[3] = {512, 512, 256};
        const float* input = combined;

        for (int layer = 0; layer < 3; layer++) {
            linear(backbone_w[layer], backbone_b[layer],
                   input, backbone_buf[layer], in_dim, hidden_dims[layer]);
            relu(backbone_buf[layer], hidden_dims[layer]);
            layer_norm(ln_gamma[layer], ln_beta[layer],
                       backbone_buf[layer], backbone_buf[layer], hidden_dims[layer]);
            in_dim = hidden_dims[layer];
            input = backbone_buf[layer];
        }

        // 4. Histogram head: 256 → 64 (ReLU) → 50 (Softmax)
        linear(hist_w1, hist_b1, backbone_buf[2], hist_fc1, 256, 64);
        relu(hist_fc1, 64);
        linear(hist_w2, hist_b2, hist_fc1, hist_out, 64, 50);
        softmax(hist_out, 50);

        memcpy(histogram, hist_out, 50 * sizeof(float));
    }
};
```

#### 5.3.3 CFRNetwork Forward Pass (policy_head Only)

```cpp
#include <cstring>

// Uses linear, relu and softmax from section 5.3.1.
class CFRNetInference {
public:
    // Preallocated buffers
    float concat_buf[55];  // card_features[50] + env_features[5]
    float h0[256];         // hidden layer 0
    float h1[256];         // hidden layer 1
    float h2[128];         // hidden layer 2
    float logits[5];       // policy_head output, overwritten with probabilities

    // Weights
    const float* backbone_w[3]; const float* backbone_b[3];  // three Linear layers
    const float* policy_w; const float* policy_b;            // 128→5

    void forward(const float* card_features,  // [50]
                 const float* env_features,   // [5]
                 const int* legal_mask,       // [5], 0 or 1
                 float* out_strategy) {       // output [5]

        // 1. Concat
        memcpy(concat_buf, card_features, 50 * sizeof(float));
        memcpy(concat_buf + 50, env_features, 5 * sizeof(float));

        // 2. Backbone: (Linear → ReLU) × 3
        // layer 0: 55 → 256
        linear(backbone_w[0], backbone_b[0], concat_buf, h0, 55, 256);
        relu(h0, 256);

        // layer 1: 256 → 256
        linear(backbone_w[1], backbone_b[1], h0, h1, 256, 256);
        relu(h1, 256);

        // layer 2: 256 → 128
        linear(backbone_w[2], backbone_b[2], h1, h2, 256, 128);
        relu(h2, 128);

        // 3. Policy head: 128 → 5
        linear(policy_w, policy_b, h2, logits, 128, 5);

        // 4. Masked Softmax
        for (int i = 0; i < 5; i++) {
            if (legal_mask[i] == 0)
                logits[i] = -1e9f;  // large negative value for illegal actions
        }
        softmax(logits, 5);

        memcpy(out_strategy, logits, 5 * sizeof(float));
    }
};
```

#### 5.3.4 Complete Inference Flow

```cpp
// === Step 1: Extract information from the game state ===
int hole_cards[2]  = { /* hole-card IDs, 0-51 */ };
int board_cards[5] = { /* community-card IDs, padded with 52 when fewer than 5 */ };
float env_features[5] = {
    pot / 20000.0f,
    p0_stack / 20000.0f,
    p1_stack / 20000.0f,
    street / 3.0f,
    (float)position
};
int legal_mask[5] = { /* computed by the BetTranslator logic */ };

// === Step 2: CardModel forward pass ===
float card_features[50];
card_model.forward(hole_cards, board_cards, card_features);

// === Step 3: CFRNetwork forward pass ===
float strategy[5];
cfr_net.forward(card_features, env_features, legal_mask, strategy);

// === Step 4: Sampling (optionally denoise first) ===
// Denoise: zero out probabilities below 3%, then renormalize
float threshold = 0.03f;
float sum = 0.0f;
for (int i = 0; i < 5; i++) {
    if (legal_mask[i] && strategy[i] < threshold)
        strategy[i] = 0.0f;
    sum += strategy[i];
}
if (sum > 0) for (int i = 0; i < 5; i++) strategy[i] /= sum;

// Sample among the legal actions according to their probabilities
int chosen = sample_from_distribution(strategy, legal_mask);

// === Step 5: CFR action → engine action (BetTranslator) ===
int engine_action = cfr_to_engine(state, chosen);
```

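
The flow above calls `sample_from_distribution` without defining it. One possible implementation is inverse-CDF sampling over the legal actions, sketched below. The three-argument overload that takes the random number `u` explicitly is an addition made here for testability, and the choice of `std::mt19937` as the generator is an assumption.

```cpp
#include <cassert>
#include <random>

// Draw an action index from `strategy`, considering only legal actions.
// `u` is a uniform random number in [0, 1); passing it in keeps the function
// deterministic and easy to test.
int sample_from_distribution(const float* strategy, const int* legal_mask,
                             float u, int num_actions = 5) {
    float total = 0.0f;
    for (int i = 0; i < num_actions; ++i)
        if (legal_mask[i]) total += strategy[i];
    assert(total > 0.0f);  // at least one legal action must have probability mass

    float cumulative = 0.0f;
    int last_candidate = -1;
    for (int i = 0; i < num_actions; ++i) {
        if (!legal_mask[i] || strategy[i] <= 0.0f) continue;
        cumulative += strategy[i] / total;
        last_candidate = i;
        if (u < cumulative) return i;
    }
    return last_candidate;  // guards against rounding when u is close to 1
}

// Convenience overload matching the two-argument call in the flow above.
int sample_from_distribution(const float* strategy, const int* legal_mask) {
    static std::mt19937 rng{std::random_device{}()};
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    return sample_from_distribution(strategy, legal_mask, uniform(rng));
}
```

Renormalizing by `total` inside the sampler means the function still behaves sensibly if the caller skipped the denoising step or passed probabilities that do not sum exactly to 1.
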
### 5.4 Weight Loading Map

When loading weights in C++, map each PyTorch `state_dict` key to its layer as follows:

#### CardModel: state_dict key → purpose

| state_dict key | Shape | Purpose |
|---|---|---|
| `embedding.weight` | [53, 64] | Embedding lookup table |
| `backbone.0.weight` | [512, 128] | Linear 128→512 |
| `backbone.0.bias` | [512] | |
| `backbone.2.weight` | [512] | LayerNorm gamma |
| `backbone.2.bias` | [512] | LayerNorm beta |
| `backbone.3.weight` | [512, 512] | Linear 512→512 |
| `backbone.3.bias` | [512] | |
| `backbone.5.weight` | [512] | LayerNorm gamma |
| `backbone.5.bias` | [512] | LayerNorm beta |
| `backbone.6.weight` | [256, 512] | Linear 512→256 |
| `backbone.6.bias` | [256] | |
| `backbone.8.weight` | [256] | LayerNorm gamma |
| `backbone.8.bias` | [256] | LayerNorm beta |
| `histogram_head.0.weight` | [64, 256] | Linear 256→64 |
| `histogram_head.0.bias` | [64] | |
| `histogram_head.2.weight` | [50, 64] | Linear 64→50 |
| `histogram_head.2.bias` | [50] | |
| `equity_head.*` | - | Not needed for online inference |

> **Important**: CardModel's `backbone` is an `nn.Sequential`, so the indices map to layers as follows:
> - `backbone.0` = Linear(128, 512)
> - `backbone.1` = ReLU (no parameters)
> - `backbone.2` = LayerNorm(512)
> - `backbone.3` = Linear(512, 512)
> - `backbone.4` = ReLU (no parameters)
> - `backbone.5` = LayerNorm(512)
> - `backbone.6` = Linear(512, 256)
> - `backbone.7` = ReLU (no parameters)
> - `backbone.8` = LayerNorm(256)

#### CFRNetwork: state_dict key → purpose

| state_dict key | Shape | Purpose |
|---|---|---|
| `backbone.0.weight` | [256, 55] | Linear 55→256 |
| `backbone.0.bias` | [256] | |
| `backbone.2.weight` | [256, 256] | Linear 256→256 |
| `backbone.2.bias` | [256] | |
| `backbone.4.weight` | [128, 256] | Linear 256→128 |
| `backbone.4.bias` | [128] | |
| `regret_head.weight` | [5, 128] | Not needed for online inference |
| `regret_head.bias` | [5] | Not needed for online inference |
| `policy_head.weight` | [5, 128] | Linear 128→5 |
| `policy_head.bias` | [5] | |

> **Important**: CFRNetwork's `backbone` indices:
> - `backbone.0` = Linear(55, 256)
> - `backbone.1` = ReLU (no parameters)
> - `backbone.2` = Linear(256, 256)
> - `backbone.3` = ReLU (no parameters)
> - `backbone.4` = Linear(256, 128)
> - `backbone.5` = ReLU (no parameters)

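
Because both backbones are `nn.Sequential` containers with a fixed block pattern, the key names can be generated instead of typed by hand. The helpers below are a sketch with illustrative names; they assume exactly the layouts listed above (3 modules per block in CardModel, 2 per block in CFRNetwork).

```cpp
#include <cassert>
#include <string>

// Modules per block inside nn.Sequential:
// CardModel blocks are Linear, ReLU, LayerNorm (3); CFRNetwork blocks are Linear, ReLU (2).
constexpr int kCardModelStride = 3;
constexpr int kCfrNetStride = 2;

// state_dict key of the Linear layer in block `layer` (0-based).
std::string linear_key(int layer, int stride, const std::string& param) {
    assert(layer >= 0);
    assert(param == "weight" || param == "bias");
    return "backbone." + std::to_string(layer * stride) + "." + param;
}

// state_dict key of the LayerNorm in block `layer` (CardModel only);
// it sits two positions after the Linear layer of the same block.
std::string layer_norm_key(int layer, const std::string& param) {
    assert(layer >= 0);
    assert(param == "weight" || param == "bias");
    return "backbone." + std::to_string(layer * kCardModelStride + 2) + "." + param;
}
```

In a loop over the three blocks, `linear_key(layer, kCardModelStride, "weight")` and `layer_norm_key(layer, "weight")` yield the keys for the Linear weight and the LayerNorm gamma of that block.
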
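---

## 6. Possible Optimizations

### 6.1 Is CardModel Needed?

**Yes.** CardModel is an essential part of the inference pipeline: it encodes the discrete card information into a 50-dimensional continuous vector that CFRNetwork depends on. Without it you would have to design another card encoding, and the trained CFRNetwork weights would no longer be usable.

### 6.2 Is the Regret Head Needed?

**Not for online inference.** The Regret Head serves Regret Matching during training; online inference uses only Policy Head + Softmax to obtain `avg_strategy`. You can skip loading the `regret_head` weights to save memory.

### 6.3 Is the Equity Head Needed?

**Not for online inference.** The Equity Head only outputs a scalar equity for monitoring; inference needs only the 50-dimensional output of the Histogram Head.
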
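
If the loader should skip the two unused heads, a prefix check on the `state_dict` key is enough. A minimal sketch; the function name is illustrative:

```cpp
#include <cassert>
#include <string>

// True if a state_dict key belongs to a head that online inference never evaluates.
bool is_training_only(const std::string& key) {
    assert(!key.empty());
    static const char* const kSkippedPrefixes[] = {"regret_head.", "equity_head."};
    for (const char* prefix : kSkippedPrefixes) {
        if (key.rfind(prefix, 0) == 0) return true;  // key starts with prefix
    }
    return false;
}
```

Note that a loader reading a flat binary file still has to step over the data of a skipped tensor so that the next tensor is read from the right offset.
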
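### 6.4 Alternative: LibTorch

If you would rather not hand-write the forward pass, LibTorch (the PyTorch C++ API) can load the model and run inference, typically after the model has been exported to TorchScript. The upside is less code; the downside is a large dependency (~200 MB or more).

### 6.5 Alternative: ONNX Runtime

Both models can be exported to ONNX and run through the ONNX Runtime C++ API, which balances ease of use and performance.
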
```python
import torch

# CardModel and CFRNetwork are the project's model classes; import them from
# wherever they are defined in the training code.

# Export CardModel to ONNX
card_model = CardModel()
card_sd = torch.load("card_model/data/best_card_model.pt", map_location="cpu", weights_only=False)
if "model_state_dict" in card_sd:  # unwrap full checkpoints, as in the export script
    card_sd = card_sd["model_state_dict"]
card_model.load_state_dict(card_sd)
card_model.eval()
dummy_hole = torch.randint(0, 52, (1, 2))
dummy_board = torch.randint(0, 52, (1, 5))
torch.onnx.export(card_model, (dummy_hole, dummy_board), "card_model.onnx",
                  input_names=["x_hole", "x_board"],
                  output_names=["pred_equity", "pred_histogram"])

# Export CFRNetwork to ONNX
cfr_net = CFRNetwork()
cfr_sd = torch.load("botzone_cfr_net.pt", map_location="cpu", weights_only=False)
if "model_state_dict" in cfr_sd:  # unwrap full checkpoints, as in the export script
    cfr_sd = cfr_sd["model_state_dict"]
cfr_net.load_state_dict(cfr_sd)
cfr_net.eval()
dummy_card = torch.randn(1, 50)
dummy_env = torch.randn(1, 5)
torch.onnx.export(cfr_net, (dummy_card, dummy_env), "cfr_net.onnx",
                  input_names=["card_features", "env_features"],
                  output_names=["regrets", "policy_logits"])
```

---

## 7. Data Flow Summary

```
┌────────────────────────────────────────────────────────┐
│               Online inference data flow               │
├────────────────────────────────────────────────────────┤
│                                                        │
│  Game state                                            │
│    ├── hole_cards[2]   (int, 0-51)                     │
│    ├── board_cards[5]  (int, 0-51 or 52=PAD)           │
│    ├── pot, stacks, street, position                   │
│    └── legal_actions   (engine-native)                 │
│           │                                            │
│           ▼                                            │
│  ┌─────────────────┐     ┌──────────────────┐          │
│  │ CardModel       │     │ Feature Builder  │          │
│  │ Embedding(53,64)│     │ pot/20000        │          │
│  │ sum → cat       │     │ p0_stack/20000   │          │
│  │ MLP 128→512→    │     │ p1_stack/20000   │          │
│  │ 512→256         │     │ street/3.0       │          │
│  │ +LayerNorm      │     │ position (0|1)   │          │
│  │ HistHead 256→   │     └────────┬─────────┘          │
│  │ 64→50+Softmax   │              │                    │
│  └────────┬────────┘              │                    │
│           │                       │                    │
│   card_features[50]        env_features[5]             │
│           │                       │                    │
│           └───────────┬───────────┘                    │
│                       ▼                                │
│            ┌─────────────────────┐                     │
│            │ CFRNetwork          │                     │
│            │ cat → [55]          │                     │
│            │ MLP 55→256→256→128  │                     │
│            │ PolicyHead 128→5    │                     │
│            └──────────┬──────────┘                     │
│                       │                                │
│               policy_logits[5]                         │
│                       │                                │
│          masked_fill(illegal → -1e9)                   │
│                       │                                │
│           Softmax → avg_strategy[5]                    │
│                       │                                │
│        Denoise (<3% → 0, renormalize)                  │
│                       │                                │
│            Sample → chosen_cfr_idx                     │
│                       │                                │
│         BetTranslator → engine_action                  │
│                       │                                │
│             Execute engine action                      │
│                                                        │
└────────────────────────────────────────────────────────┘
```