"""
|
||
cfr_net.py — CFR 策略网络
|
||
|
||
Deep CFR 的核心网络,接受信息集 (Information Set) 特征,
|
||
输出 6 个动作的累计遗憾值 (Cumulative Regret) 和平均策略 (Average Strategy)。
|
||
|
||
输入:
|
||
- card_features: 牌面特征 [batch, 50](Card Model 预测的胜率直方图)
|
||
- env_features: 局势特征 [batch, 5](pot, p0_stack, p1_stack, street, position 归一化后)
|
||
|
||
网络结构:
|
||
拼接 → MLP(256→256→128) → 双输出头
|
||
- regret_head: 6 维,无激活(regret 可为负数)
|
||
- policy_head: 6 维,Softmax(平均策略概率分布)
|
||
|
||
核心方法 get_strategy():
|
||
实现 Regret Matching — 将负 regret 截断为 0,归一化得到即时策略;
|
||
若所有合法动作的 regret 均为 0,则在合法动作上均匀分布。
|
||
"""
|
||
|
||
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple


# ───────────────────── Network hyperparameters ─────────────────────

CARD_DIM = 50    # dimension of the Card Model's win-rate histogram output
ENV_DIM = 5      # situation-feature dimension: (pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position)
NUM_ACTIONS = 6  # number of CFR actions — the docstring and tests assume 6: FOLD, CALL, HALF_POT, FULL_POT, ALL_IN, plus presumably CHECK

# MLP hidden-layer dimensions
MLP_HIDDEN_DIMS = [256, 256, 128]
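

# ── Illustrative sketch (not part of the original module): assembling env_features ──
# A caller presumably builds the 5-dim situation vector with the normalizers noted
# in the ENV_DIM comment above; the helper name and signature here are hypothetical.
def encode_env_features(pot: float, p0_stack: float, p1_stack: float,
                        street: int, position: int) -> torch.Tensor:
    """Normalize raw situation values into a [1, ENV_DIM] feature tensor."""
    return torch.tensor([[pot / 20000.0, p0_stack / 20000.0, p1_stack / 20000.0,
                          street / 3.0, float(position)]])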


class CFRNetwork(nn.Module):
    """CFR policy network.

    Two inputs (cards + situation) → concat → 3-layer MLP → two output heads (regret + policy).

    Typical usage:
        net = CFRNetwork()
        strategy = net.get_strategy(card_feat, env_feat, legal_mask)
    """

    def __init__(
        self,
        card_dim: int = CARD_DIM,
        env_dim: int = ENV_DIM,
        num_actions: int = NUM_ACTIONS,
        hidden_dims: Optional[List[int]] = None,
    ) -> None:
        """
        Args:
            card_dim: card-feature dimension
            env_dim: situation-feature dimension
            num_actions: number of output actions
            hidden_dims: list of MLP hidden-layer dimensions
        """
        super().__init__()
        if hidden_dims is None:
            hidden_dims = MLP_HIDDEN_DIMS

        self.card_dim = card_dim
        self.env_dim = env_dim
        self.num_actions = num_actions

        # ── Input dimension after concatenation ──
        concat_dim = card_dim + env_dim

        # ── Shared MLP backbone ──
        # Built layer by layer: concat_dim → hidden_dims[0] → hidden_dims[1] → ... → hidden_dims[-1]
        layers: List[nn.Module] = []
        in_dim = concat_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            in_dim = h_dim
        self.backbone = nn.Sequential(*layers)

        # Backbone output dimension
        self._backbone_out_dim = hidden_dims[-1]

        # ── Regret head: no activation, since regrets can be negative ──
        self.regret_head = nn.Linear(self._backbone_out_dim, num_actions)

        # ── Policy head: raw logits here; Softmax is applied downstream ──
        self.policy_head = nn.Linear(self._backbone_out_dim, num_actions)

    def forward(
        self, card_features: torch.Tensor, env_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass; returns the raw regret and policy outputs.

        Args:
            card_features: [batch, card_dim] card features
            env_features: [batch, env_dim] situation features

        Returns:
            regrets: [batch, num_actions] regret values (unclipped, may be negative)
            policy_logits: [batch, num_actions] policy logits (pre-Softmax)
        """
        # Concatenate card and situation features
        x = torch.cat([card_features, env_features], dim=-1)

        # Shared backbone
        x = self.backbone(x)

        # Two output heads
        regrets = self.regret_head(x)        # [batch, num_actions], no activation
        policy_logits = self.policy_head(x)  # [batch, num_actions], Softmax applied later

        return regrets, policy_logits

    def get_strategy(
        self,
        card_features: torch.Tensor,
        env_features: torch.Tensor,
        legal_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the Current Policy (Regret Matching) and the Average Strategy.

        Core logic — Regret Matching:
            1. Clip negative regrets to 0: positive_regret = F.relu(regret)
            2. Keep only legal actions: masked_regret = positive_regret * legal_mask
            3. If masked_regret sums to > 0, normalize it to get the current strategy
            4. If masked_regret sums to 0, use a uniform distribution over legal actions

        The average strategy comes from policy_head + Softmax.

        Args:
            card_features: [batch, card_dim] card features
            env_features: [batch, env_dim] situation features
            legal_mask: [batch, num_actions] legality mask, 1 = legal, 0 = illegal

        Returns:
            current_strategy: [batch, num_actions] current strategy (from Regret Matching)
            avg_strategy: [batch, num_actions] average strategy (from Softmax)
        """
        # Forward pass
        regrets, policy_logits = self.forward(card_features, env_features)

        # ──────── Current strategy: Regret Matching ────────

        # Step 1: clip negative regrets to 0
        positive_regret = F.relu(regrets)  # [batch, num_actions]

        # Step 2: mask out illegal actions
        # legal_mask ensures FOLD is never chosen when folding is impossible,
        # and ALL_IN is never chosen when the stack is too small
        masked_regret = positive_regret * legal_mask  # [batch, num_actions]

        # Step 3: normalize, or fall back to a uniform distribution
        regret_sum = masked_regret.sum(dim=-1, keepdim=True)  # [batch, 1]

        # Uniform distribution over the legal actions
        # legal_mask.sum → number of legal actions per sample
        num_legal = legal_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)  # [batch, 1]
        uniform = legal_mask / num_legal  # [batch, num_actions]

        # Normalize when regret_sum > 0, otherwise use the uniform distribution.
        # The clamp keeps the untaken branch of torch.where free of 0/0 = NaN,
        # which would otherwise poison the backward pass.
        current_strategy = torch.where(
            regret_sum > 0,
            masked_regret / regret_sum.clamp(min=1e-9),
            uniform,
        )

        # ──────── Average strategy: Softmax + legality mask ────────

        # Set illegal logits to -inf so their post-Softmax probability is 0
        masked_logits = policy_logits.masked_fill(legal_mask == 0, float("-inf"))
        avg_strategy = F.softmax(masked_logits, dim=-1)  # [batch, num_actions]

        # Edge case: every action illegal (should never happen in practice).
        # With an all-zero legal_mask, masked_logits is all -inf and the Softmax
        # yields NaN; fall back to a uniform distribution over all actions
        # (note `uniform` itself is all zeros in this case, so it cannot serve here).
        all_illegal = (legal_mask.sum(dim=-1, keepdim=True) == 0)
        fallback = torch.full_like(avg_strategy, 1.0 / self.num_actions)
        avg_strategy = torch.where(all_illegal, fallback, avg_strategy)

        return current_strategy, avg_strategy
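

# ── Illustrative aside (not part of the original module): Regret Matching by hand ──
# With positive-clipped regrets [2, 0, 3] over three legal actions, Regret Matching
# yields [2/5, 0/5, 3/5] = [0.4, 0.0, 0.6]. This hypothetical helper reproduces
# steps 1-4 of get_strategy() on fixed numbers so the result can be checked by hand
# (note the illegal fourth action keeps probability 0 despite its large regret).
def _regret_matching_by_hand() -> torch.Tensor:
    regrets = torch.tensor([[2.0, -1.0, 3.0, 5.0, 0.0, 0.0]])
    legal = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0, 0.0]])
    clipped = F.relu(regrets) * legal          # [[2, 0, 3, 0, 0, 0]]
    total = clipped.sum(dim=-1, keepdim=True)  # [[5.]]
    return clipped / total                     # [[0.4, 0, 0.6, 0, 0, 0]]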


# ─────────────────────── Instantiation tests ───────────────────────
if __name__ == "__main__":
    print("=" * 60)
    print("CFRNetwork unit tests")
    print("=" * 60)

    torch.manual_seed(42)

    BATCH_SIZE = 4
    CARD_DIM_TEST = 50
    ENV_DIM_TEST = 5
    NUM_ACTIONS_TEST = 6

    # ── Test 1: network instantiation ──
    print("\n[Test 1] network instantiation")
    net = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST)
    total_params = sum(p.numel() for p in net.parameters())
    print(f"  parameter count: {total_params:,}")
    print(f"  backbone: {net.backbone}")
    print(f"  regret_head: {net.regret_head}")
    print(f"  policy_head: {net.policy_head}")
    print("  instantiation OK ✓")

    # ── Test 2: forward pass ──
    print("\n[Test 2] forward pass")
    card_feat = torch.randn(BATCH_SIZE, CARD_DIM_TEST)
    env_feat = torch.randn(BATCH_SIZE, ENV_DIM_TEST)
    regrets, policy_logits = net(card_feat, env_feat)
    assert regrets.shape == (BATCH_SIZE, NUM_ACTIONS_TEST), f"bad regrets shape: {regrets.shape}"
    assert policy_logits.shape == (BATCH_SIZE, NUM_ACTIONS_TEST), f"bad policy_logits shape: {policy_logits.shape}"
    print(f"  regrets: {regrets.shape} ✓")
    print(f"  policy_logits: {policy_logits.shape} ✓")
    # regrets may contain negative values
    print(f"  min regret: {regrets.min().item():.4f} (typically negative under random init) ✓")

    # ── Test 3: get_strategy — Regret Matching ──
    print("\n[Test 3] get_strategy — Regret Matching")

    # Scenario A: partially legal actions (first 3 legal)
    legal_mask_A = torch.tensor([
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0, 0],  # only fold and call are legal
        [1, 1, 1, 1, 0, 0],
    ], dtype=torch.float32)

    current_strat, avg_strat = net.get_strategy(card_feat, env_feat, legal_mask_A)

    # Check shapes
    assert current_strat.shape == (BATCH_SIZE, NUM_ACTIONS_TEST)
    assert avg_strat.shape == (BATCH_SIZE, NUM_ACTIONS_TEST)

    # Check the strategies: illegal positions must have probability 0
    for i in range(BATCH_SIZE):
        for j in range(NUM_ACTIONS_TEST):
            if legal_mask_A[i, j] == 0:
                assert current_strat[i, j].item() == 0.0, \
                    f"sample {i} action {j} is illegal but strategy={current_strat[i, j].item()}"
                assert avg_strat[i, j].item() == 0.0, \
                    f"sample {i} action {j} is illegal but avg strategy={avg_strat[i, j].item()}"

    # Check the strategies: probabilities must sum to ~1
    for i in range(BATCH_SIZE):
        cs_sum = current_strat[i].sum().item()
        as_sum = avg_strat[i].sum().item()
        assert abs(cs_sum - 1.0) < 1e-5, f"sample {i} current-strategy sum={cs_sum}"
        assert abs(as_sum - 1.0) < 1e-5, f"sample {i} avg-strategy sum={as_sum}"

    print(f"  current_strategy shape: {current_strat.shape} ✓")
    print(f"  avg_strategy shape: {avg_strat.shape} ✓")

    # Print concrete samples
    print("\n  sample 0 (legal=[1,1,1,0,0,0]):")
    print(f"    current: {current_strat[0].detach().numpy()}")
    print(f"    avg:     {avg_strat[0].detach().numpy()}")
    print(f"    sum(cur)={current_strat[0].sum().item():.6f}, sum(avg)={avg_strat[0].sum().item():.6f}")

    print("\n  sample 2 (legal=[1,1,0,0,0,0]):")
    print(f"    current: {current_strat[2].detach().numpy()}")
    print(f"    avg:     {avg_strat[2].detach().numpy()}")
    print(f"    sum(cur)={current_strat[2].sum().item():.6f}, sum(avg)={avg_strat[2].sum().item():.6f}")

    # ── Test 4: all-zero regrets → uniform distribution ──
    print("\n[Test 4] all-zero regrets → uniform distribution")
    # Force the network's regret_head to output exactly 0
    with torch.no_grad():
        # Zero the regret_head weights and bias so its output is all zeros
        nn.init.zeros_(net.regret_head.weight)
        nn.init.zeros_(net.regret_head.bias)

    legal_mask_B = torch.tensor([[1, 1, 1, 0, 0, 0]], dtype=torch.float32)
    card_B = torch.randn(1, CARD_DIM_TEST)
    env_B = torch.randn(1, ENV_DIM_TEST)

    current_B, avg_B = net.get_strategy(card_B, env_B, legal_mask_B)

    # With all regrets at 0, the strategy must be uniform over the legal actions:
    # 3 legal actions → probability 1/3 each
    expected_uniform = torch.tensor([[1/3, 1/3, 1/3, 0, 0, 0]])
    assert torch.allclose(current_B, expected_uniform, atol=1e-5), \
        f"all-zero regrets should yield a uniform distribution, got: {current_B.detach().numpy()}"
    print(f"  current_strategy: {current_B.detach().numpy()}")
    print(f"  expected uniform: {expected_uniform.numpy()}")
    print("  all-zero regrets → uniform distribution ✓")

    # ── Test 5: a single legal action ──
    print("\n[Test 5] single legal action")
    legal_mask_C = torch.tensor([[0, 1, 0, 0, 0, 0]], dtype=torch.float32)
    # Re-initialize a fresh network to get non-zero regrets
    net2 = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST)
    card_C = torch.randn(1, CARD_DIM_TEST)
    env_C = torch.randn(1, ENV_DIM_TEST)

    current_C, avg_C = net2.get_strategy(card_C, env_C, legal_mask_C)

    # Only CALL is legal, so all probability mass must land on CALL
    assert abs(current_C[0, 1].item() - 1.0) < 1e-5, f"single-legal-action probability={current_C[0, 1].item()}"
    assert abs(avg_C[0, 1].item() - 1.0) < 1e-5, f"single-legal-action avg probability={avg_C[0, 1].item()}"
    # All other positions must be 0
    for j in [0, 2, 3, 4, 5]:
        assert current_C[0, j].item() == 0.0, f"illegal action {j} probability={current_C[0, j].item()}"
        assert avg_C[0, j].item() == 0.0, f"illegal action {j} avg probability={avg_C[0, j].item()}"
    print(f"  current: {current_C.detach().numpy()}")
    print(f"  avg:     {avg_C.detach().numpy()}")
    print("  single legal action handled correctly ✓")
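
    # ── Illustrative aside (beyond the original tests): sampling an action ──
    # A traversal loop would typically draw an action index from the current
    # strategy, e.g. via torch.multinomial; this sketch reuses current_C, where
    # only index 1 carries probability mass.
    sampled_action = torch.multinomial(current_C, num_samples=1)
    print(f"  sampled action index (must be 1): {sampled_action.item()}")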

    # ── Test 6: gradient flow ──
    print("\n[Test 6] gradient flow")
    net3 = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST)
    card_D = torch.randn(2, CARD_DIM_TEST, requires_grad=True)
    env_D = torch.randn(2, ENV_DIM_TEST, requires_grad=True)
    legal_mask_D = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 0, 0, 1, 1]], dtype=torch.float32)

    current_D, avg_D = net3.get_strategy(card_D, env_D, legal_mask_D)

    # Use an MSE loss rather than .sum(): normalized strategies always sum to 1,
    # so the gradient of the plain sum would be approximately zero
    target = torch.zeros_like(current_D)
    loss = F.mse_loss(current_D, target) + F.mse_loss(avg_D, target)
    loss.backward()

    assert card_D.grad is not None, "card_features gradient is None"
    assert env_D.grad is not None, "env_features gradient is None"
    assert card_D.grad.norm().item() > 0, "card_features gradient norm should be positive"
    assert env_D.grad.norm().item() > 0, "env_features gradient norm should be positive"
    print(f"  card_features.grad norm: {card_D.grad.norm().item():.6f} ✓")
    print(f"  env_features.grad norm: {env_D.grad.norm().item():.6f} ✓")

    # Check that the network parameters received gradients
    has_grad = sum(1 for p in net3.parameters() if p.grad is not None)
    total_p = sum(1 for p in net3.parameters())
    print(f"  parameters with gradients: {has_grad}/{total_p} ✓")
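
    # ── Illustrative aside (beyond the original tests): one Deep CFR-style fitting step ──
    # Deep CFR fits the regret network by regressing its predicted regrets onto
    # sampled regret targets with an MSE loss; the random targets and optimizer
    # settings below are hypothetical stand-ins for a real traversal buffer.
    print("\n[Extra] one illustrative training step (hypothetical targets)")
    net4 = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST)
    optimizer = torch.optim.Adam(net4.parameters(), lr=1e-3)
    regret_targets = torch.randn(BATCH_SIZE, NUM_ACTIONS_TEST)  # stand-in for sampled regrets
    pred_regrets, _ = net4(card_feat, env_feat)
    train_loss = F.mse_loss(pred_regrets, regret_targets)
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    print(f"  regret MSE loss: {train_loss.item():.4f}")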
print("\n" + "=" * 60)
|
||
print("所有测试通过 ✓")
|
||
print("=" * 60)
|