"""
cfr_net.py — CFR 策略网络
Deep CFR 的核心网络,接受信息集 (Information Set) 特征,
输出 6 个动作的累计遗憾值 (Cumulative Regret) 和平均策略 (Average Strategy)。
输入:
- card_features: 牌面特征 [batch, 50]Card Model 预测的胜率直方图)
- env_features: 局势特征 [batch, 5]pot, p0_stack, p1_stack, street, position 归一化后)
网络结构:
拼接 → MLP(256→256→128) → 双输出头
- regret_head: 6 维无激活regret 可为负数)
- policy_head: 6 维Softmax平均策略概率分布
核心方法 get_strategy():
实现 Regret Matching — 将负 regret 截断为 0归一化得到即时策略
若所有合法动作的 regret 均为 0则在合法动作上均匀分布。
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple
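
# Worked example of Regret Matching (a minimal sketch for orientation; the
# numbers are illustrative, not from the training pipeline): with regrets
# [2., -1., 3.] and all three actions legal, the positive parts are
# [2., 0., 3.], which normalise to the strategy [0.4, 0.0, 0.6]:
#     r = torch.tensor([2.0, -1.0, 3.0])
#     strategy = r.clamp(min=0) / r.clamp(min=0).sum()  # tensor([0.4, 0.0, 0.6])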
# ───────────────────── Network hyperparameters ─────────────────────
CARD_DIM = 50    # dimensionality of the Card Model's win-probability histogram
ENV_DIM = 5      # game-state feature dims: (pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position)
NUM_ACTIONS = 5  # CFR actions: FOLD, CALL, HALF_POT, FULL_POT, ALL_IN
# MLP hidden-layer dimensions
MLP_HIDDEN_DIMS = [256, 256, 128]
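
# A minimal sketch of packing a raw game state into env_features, assuming the
# normalisation constants documented in the ENV_DIM comment above. The helper
# name and signature are hypothetical, for illustration only; the real feature
# pipeline may live elsewhere.
def make_env_features(
    pot: float, p0_stack: float, p1_stack: float, street: int, position: int
) -> torch.Tensor:
    """Pack one raw state into a [1, ENV_DIM] feature tensor."""
    return torch.tensor(
        [[pot / 20000.0, p0_stack / 20000.0, p1_stack / 20000.0,
          street / 3.0, float(position)]],
        dtype=torch.float32,
    )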
class CFRNetwork(nn.Module):
    """CFR policy network.
    Two inputs (cards + game state) → concat → 3-layer MLP → two output heads
    (regret + policy).
    Typical usage:
        net = CFRNetwork()
        strategy = net.get_strategy(card_feat, env_feat, legal_mask)
    """
    def __init__(
        self,
        card_dim: int = CARD_DIM,
        env_dim: int = ENV_DIM,
        num_actions: int = NUM_ACTIONS,
        hidden_dims: Optional[List[int]] = None,
    ) -> None:
        """
        Args:
            card_dim: card-feature dimensionality
            env_dim: game-state feature dimensionality
            num_actions: number of output actions
            hidden_dims: list of MLP hidden-layer sizes
        """
        super().__init__()
        if hidden_dims is None:
            hidden_dims = MLP_HIDDEN_DIMS
        self.card_dim = card_dim
        self.env_dim = env_dim
        self.num_actions = num_actions
        # ── Input dimensionality after concatenation ──
        concat_dim = card_dim + env_dim
        # ── Shared MLP backbone ──
        # Built layer by layer: concat_dim → hidden_dims[0] → hidden_dims[1] → ... → hidden_dims[-1]
        layers: List[nn.Module] = []
        in_dim = concat_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        self.backbone = nn.Sequential(*layers)
        # Backbone output dimensionality
        self._backbone_out_dim = hidden_dims[-1]
        # ── Regret head: no activation, since regrets can be negative ──
        self.regret_head = nn.Linear(self._backbone_out_dim, num_actions)
        # ── Policy head: raw logits here; Softmax is applied downstream ──
        self.policy_head = nn.Linear(self._backbone_out_dim, num_actions)
    def forward(
        self, card_features: torch.Tensor, env_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass; returns the raw regret and policy outputs.
        Args:
            card_features: [batch, card_dim] card features
            env_features: [batch, env_dim] game-state features
        Returns:
            regrets: [batch, num_actions] regret values (unclipped, may be negative)
            policy_logits: [batch, num_actions] policy logits (pre-Softmax)
        """
        # Concatenate card features with game-state features
        x = torch.cat([card_features, env_features], dim=-1)
        # Shared backbone
        x = self.backbone(x)
        # Two output heads
        regrets = self.regret_head(x)        # [batch, num_actions], no activation
        policy_logits = self.policy_head(x)  # [batch, num_actions], Softmax applied later
        return regrets, policy_logits
    def get_strategy(
        self,
        card_features: torch.Tensor,
        env_features: torch.Tensor,
        legal_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the current policy and the average strategy for a state.
        Core logic, Regret Matching:
            1. Clip negative regrets to 0: positive_regret = F.relu(regret)
            2. Keep only legal actions: masked_regret = positive_regret * legal_mask
            3. If masked_regret sums to > 0, normalise to get the current strategy
            4. If masked_regret sums to 0, use a uniform distribution over legal actions
        The average strategy comes from policy_head + Softmax.
        Args:
            card_features: [batch, card_dim] card features
            env_features: [batch, env_dim] game-state features
            legal_mask: [batch, num_actions] legality mask (1 = legal, 0 = illegal)
        Returns:
            current_strategy: [batch, num_actions] current strategy (Regret Matching)
            avg_strategy: [batch, num_actions] average strategy (Softmax)
        """
        # Forward pass
        regrets, policy_logits = self.forward(card_features, env_features)
        # ──────── Current strategy: Regret Matching ────────
        # Step 1: clip negative regrets to 0
        positive_regret = F.relu(regrets)  # [batch, num_actions]
        # Step 2: mask out illegal actions
        # legal_mask guarantees e.g. that FOLD is never chosen when folding is
        # impossible and ALL_IN is never chosen with insufficient chips
        masked_regret = positive_regret * legal_mask  # [batch, num_actions]
        # Step 3: normalise, or fall back to uniform
        regret_sum = masked_regret.sum(dim=-1, keepdim=True)  # [batch, 1]
        # Uniform distribution over the legal actions
        # legal_mask.sum → number of legal actions per sample
        num_legal = legal_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)  # [batch, 1]
        uniform = legal_mask / num_legal  # [batch, num_actions]
        # Normalise when regret_sum > 0, otherwise use the uniform distribution.
        # The denominator is made safe because torch.where evaluates BOTH
        # branches: an unguarded 0/0 would put NaN in the unselected branch and
        # poison the gradients flowing back through masked_regret.
        safe_sum = torch.where(regret_sum > 0, regret_sum, torch.ones_like(regret_sum))
        current_strategy = torch.where(
            regret_sum > 0,
            masked_regret / safe_sum,
            uniform,
        )
        # ──────── Average strategy: Softmax + legality mask ────────
        # Set illegal logits to -inf so their post-Softmax probability is 0
        masked_logits = policy_logits.masked_fill(legal_mask == 0, float("-inf"))
        avg_strategy = F.softmax(masked_logits, dim=-1)  # [batch, num_actions]
        # Edge case: all actions illegal (should not happen in practice).
        # With legal_mask all zero, masked_logits is all -inf and the Softmax
        # yields NaN; substitute the uniform distribution instead.
        all_illegal = (legal_mask.sum(dim=-1, keepdim=True) == 0)
        avg_strategy = torch.where(all_illegal, uniform, avg_strategy)
        return current_strategy, avg_strategy
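
# ── Usage sketch: sampling from the current strategy ──
# A minimal sketch of how a Deep CFR traverser might pick an action from the
# regret-matched strategy. `sample_action` is a hypothetical helper, not part
# of this module's API; in Deep CFR the regret head is typically trained by
# regression (e.g. MSE) against sampled counterfactual regrets.
def sample_action(
    net: CFRNetwork,
    card_feat: torch.Tensor,
    env_feat: torch.Tensor,
    legal_mask: torch.Tensor,
) -> int:
    """Sample one action index from net.get_strategy()'s current strategy."""
    with torch.no_grad():
        current, _ = net.get_strategy(card_feat, env_feat, legal_mask)
    # torch.multinomial draws an index from each row's probability distribution
    return int(torch.multinomial(current, num_samples=1).item())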
# ─────────────────────── Unit tests ───────────────────────
if __name__ == "__main__":
    print("=" * 60)
    print("CFRNetwork unit tests")
    print("=" * 60)
    torch.manual_seed(42)
    BATCH_SIZE = 4
    CARD_DIM_TEST = 50
    ENV_DIM_TEST = 5
    NUM_ACTIONS_TEST = 6
    # ── Test 1: network instantiation ──
    print("\n[Test 1] Network instantiation")
    net = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST)
    total_params = sum(p.numel() for p in net.parameters())
    print(f"  parameter count: {total_params:,}")
    print(f"  backbone: {net.backbone}")
    print(f"  regret_head: {net.regret_head}")
    print(f"  policy_head: {net.policy_head}")
    print("  instantiation OK ✓")
    # ── Test 2: forward pass ──
    print("\n[Test 2] Forward pass")
    card_feat = torch.randn(BATCH_SIZE, CARD_DIM_TEST)
    env_feat = torch.randn(BATCH_SIZE, ENV_DIM_TEST)
    regrets, policy_logits = net(card_feat, env_feat)
    assert regrets.shape == (BATCH_SIZE, NUM_ACTIONS_TEST), f"bad regrets shape: {regrets.shape}"
    assert policy_logits.shape == (BATCH_SIZE, NUM_ACTIONS_TEST), f"bad policy_logits shape: {policy_logits.shape}"
    print(f"  regrets: {regrets.shape}")
    print(f"  policy_logits: {policy_logits.shape}")
    # Regrets may be negative
    print(f"  min regret: {regrets.min().item():.4f} (should be negative) ✓")
    # ── Test 3: get_strategy / Regret Matching ──
    print("\n[Test 3] get_strategy: Regret Matching")
    # Scenario A: masks with varying sets of legal actions
    legal_mask_A = torch.tensor([
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0, 0],  # only fold and call are legal
        [1, 1, 1, 1, 0, 0],
    ], dtype=torch.float32)
    current_strat, avg_strat = net.get_strategy(card_feat, env_feat, legal_mask_A)
    # Check shapes
    assert current_strat.shape == (BATCH_SIZE, NUM_ACTIONS_TEST)
    assert avg_strat.shape == (BATCH_SIZE, NUM_ACTIONS_TEST)
    # Current strategy: illegal positions must have probability 0
    for i in range(BATCH_SIZE):
        for j in range(NUM_ACTIONS_TEST):
            if legal_mask_A[i, j] == 0:
                assert current_strat[i, j].item() == 0.0, \
                    f"sample {i} action {j} illegal but strategy={current_strat[i, j].item()}"
                assert avg_strat[i, j].item() == 0.0, \
                    f"sample {i} action {j} illegal but avg strategy={avg_strat[i, j].item()}"
    # Current strategy: probabilities must sum to ~1
    for i in range(BATCH_SIZE):
        cs_sum = current_strat[i].sum().item()
        as_sum = avg_strat[i].sum().item()
        assert abs(cs_sum - 1.0) < 1e-5, f"sample {i} current-strategy sum={cs_sum}"
        assert abs(as_sum - 1.0) < 1e-5, f"sample {i} avg-strategy sum={as_sum}"
    print(f"  current_strategy shape: {current_strat.shape}")
    print(f"  avg_strategy shape: {avg_strat.shape}")
    # Print concrete samples
    print("\n  sample 0 (legal=[1,1,1,0,0,0]):")
    print(f"    current: {current_strat[0].detach().numpy()}")
    print(f"    avg:     {avg_strat[0].detach().numpy()}")
    print(f"    sum(cur)={current_strat[0].sum().item():.6f}, sum(avg)={avg_strat[0].sum().item():.6f}")
    print("\n  sample 2 (legal=[1,1,0,0,0,0]):")
    print(f"    current: {current_strat[2].detach().numpy()}")
    print(f"    avg:     {avg_strat[2].detach().numpy()}")
    print(f"    sum(cur)={current_strat[2].sum().item():.6f}, sum(avg)={avg_strat[2].sum().item():.6f}")
    # ── Test 4: all-zero regrets → uniform distribution ──
    print("\n[Test 4] All-zero regrets → uniform distribution")
    # Zero out regret_head's weights and bias so its output is exactly 0
    with torch.no_grad():
        nn.init.zeros_(net.regret_head.weight)
        nn.init.zeros_(net.regret_head.bias)
    legal_mask_B = torch.tensor([[1, 1, 1, 0, 0, 0]], dtype=torch.float32)
    card_B = torch.randn(1, CARD_DIM_TEST)
    env_B = torch.randn(1, ENV_DIM_TEST)
    current_B, avg_B = net.get_strategy(card_B, env_B, legal_mask_B)
    # With all regrets at 0, the strategy must be uniform over legal actions:
    # 3 legal actions → probability 1/3 each
    expected_uniform = torch.tensor([[1/3, 1/3, 1/3, 0, 0, 0]])
    assert torch.allclose(current_B, expected_uniform, atol=1e-5), \
        f"all-zero regrets should yield a uniform distribution, got: {current_B.detach().numpy()}"
    print(f"  current_strategy: {current_B.detach().numpy()}")
    print(f"  expected uniform: {expected_uniform.numpy()}")
    print("  all-zero regrets → uniform distribution ✓")
    # ── Test 5: single legal action ──
    print("\n[Test 5] Single legal action")
    legal_mask_C = torch.tensor([[0, 1, 0, 0, 0, 0]], dtype=torch.float32)
    # Fresh network so the regrets are non-zero again
    net2 = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST)
    card_C = torch.randn(1, CARD_DIM_TEST)
    env_C = torch.randn(1, ENV_DIM_TEST)
    current_C, avg_C = net2.get_strategy(card_C, env_C, legal_mask_C)
    # Only CALL is legal, so all probability mass must land on CALL
    assert abs(current_C[0, 1].item() - 1.0) < 1e-5, f"single-legal-action prob={current_C[0, 1].item()}"
    assert abs(avg_C[0, 1].item() - 1.0) < 1e-5, f"single-legal-action avg prob={avg_C[0, 1].item()}"
    # All other positions must be 0
    for j in [0, 2, 3, 4, 5]:
        assert current_C[0, j].item() == 0.0, f"illegal action {j} prob={current_C[0, j].item()}"
        assert avg_C[0, j].item() == 0.0, f"illegal action {j} avg prob={avg_C[0, j].item()}"
    print(f"  current: {current_C.detach().numpy()}")
    print(f"  avg:     {avg_C.detach().numpy()}")
    print("  single-legal-action probabilities correct ✓")
    # ── Test 6: gradient flow ──
    print("\n[Test 6] Gradient flow")
    net3 = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST)
    card_D = torch.randn(2, CARD_DIM_TEST, requires_grad=True)
    env_D = torch.randn(2, ENV_DIM_TEST, requires_grad=True)
    legal_mask_D = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 0, 0, 1, 1]], dtype=torch.float32)
    current_D, avg_D = net3.get_strategy(card_D, env_D, legal_mask_D)
    # Use an MSE loss rather than .sum(): the normalised strategies always sum
    # to 1, so the gradient of a plain sum would be (nearly) zero
    target = torch.zeros_like(current_D)
    loss = F.mse_loss(current_D, target) + F.mse_loss(avg_D, target)
    loss.backward()
    assert card_D.grad is not None, "card_features gradient is None"
    assert env_D.grad is not None, "env_features gradient is None"
    assert card_D.grad.norm().item() > 0, "card_features gradient norm should be positive"
    assert env_D.grad.norm().item() > 0, "env_features gradient norm should be positive"
    print(f"  card_features.grad norm: {card_D.grad.norm().item():.6f}")
    print(f"  env_features.grad norm: {env_D.grad.norm().item():.6f}")
    # Check network-parameter gradients
    has_grad = sum(1 for p in net3.parameters() if p.grad is not None)
    total_p = sum(1 for p in net3.parameters())
    print(f"  parameters with gradients: {has_grad}/{total_p}")
    print("\n" + "=" * 60)
    print("All tests passed ✓")
    print("=" * 60)