""" cfr_net.py — CFR 策略网络 Deep CFR 的核心网络,接受信息集 (Information Set) 特征, 输出 6 个动作的累计遗憾值 (Cumulative Regret) 和平均策略 (Average Strategy)。 输入: - card_features: 牌面特征 [batch, 50](Card Model 预测的胜率直方图) - env_features: 局势特征 [batch, 5](pot, p0_stack, p1_stack, street, position 归一化后) 网络结构: 拼接 → MLP(256→256→128) → 双输出头 - regret_head: 6 维,无激活(regret 可为负数) - policy_head: 6 维,Softmax(平均策略概率分布) 核心方法 get_strategy(): 实现 Regret Matching — 将负 regret 截断为 0,归一化得到即时策略; 若所有合法动作的 regret 均为 0,则在合法动作上均匀分布。 """ import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Tuple # ───────────────────── 网络超参数 ───────────────────── CARD_DIM = 50 # Card Model 输出的胜率直方图维度 ENV_DIM = 5 # 局势特征维度: (pot/20000, p0_stack/20000, p1_stack/20000, street/3.0, position) NUM_ACTIONS = 5 # CFR 动作数: FOLD, CALL, HALF_POT, FULL_POT, ALL_IN # MLP 隐藏层维度 MLP_HIDDEN_DIMS = [256, 256, 128] class CFRNetwork(nn.Module): """CFR 策略网络 双输入(牌面 + 局势)→ 拼接 → 3 层 MLP → 双输出头(regret + policy) 典型用法: net = CFRNetwork() strategy = net.get_strategy(card_feat, env_feat, legal_mask) """ def __init__( self, card_dim: int = CARD_DIM, env_dim: int = ENV_DIM, num_actions: int = NUM_ACTIONS, hidden_dims: List[int] = None, ) -> None: """ Args: card_dim: 牌面特征维度 env_dim: 局势特征维度 num_actions: 输出动作数 hidden_dims: MLP 隐藏层维度列表 """ super().__init__() if hidden_dims is None: hidden_dims = MLP_HIDDEN_DIMS self.card_dim = card_dim self.env_dim = env_dim self.num_actions = num_actions # ── 拼接后的输入维度 ── concat_dim = card_dim + env_dim # ── 共享 MLP 骨干 ── # 逐层构建: concat_dim → hidden_dims[0] → hidden_dims[1] → ... → hidden_dims[-1] layers: List[nn.Module] = [] in_dim = concat_dim for h_dim in hidden_dims: layers.append(nn.Linear(in_dim, h_dim)) layers.append(nn.ReLU()) in_dim = h_dim self.backbone = nn.Sequential(*layers) # 骨干输出维度 self._backbone_out_dim = hidden_dims[-1] # ── 遗憾输出头: 无激活函数,因为 regret 可以为负数 ── self.regret_head = nn.Linear(self._backbone_out_dim, num_actions) # ── 策略输出头: 后续过 Softmax,此处输出原始 logits ── self.policy_head = nn.Linear(self._backbone_out_dim, num_actions) def forward( self, card_features: torch.Tensor, env_features: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """前向传播,返回 regret 和 policy 的原始输出 Args: card_features: [batch, card_dim] 牌面特征 env_features: [batch, env_dim] 局势特征 Returns: regrets: [batch, num_actions] 遗憾值(未经截断,可为负数) policy_logits: [batch, num_actions] 策略 logits(未经 Softmax) """ # 拼接牌面特征和局势特征 x = torch.cat([card_features, env_features], dim=-1) # 共享骨干 x = self.backbone(x) # 双头输出 regrets = self.regret_head(x) # [batch, 6] 无激活 policy_logits = self.policy_head(x) # [batch, 6] 后续过 Softmax return regrets, policy_logits def get_strategy( self, card_features: torch.Tensor, env_features: torch.Tensor, legal_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """获取当前局势的即时策略 (Current Policy) 和平均策略 (Average Strategy) 核心逻辑 — Regret Matching: 1. 将负数 regret 截断为 0: positive_regret = F.relu(regret) 2. 只保留合法动作: masked_regret = positive_regret * legal_mask 3. 若 masked_regret 之和 > 0,归一化得到即时策略 4. 若 masked_regret 之和 == 0,在合法动作上均匀分布 同时通过 policy_head + Softmax 得到平均策略。 Args: card_features: [batch, card_dim] 牌面特征 env_features: [batch, env_dim] 局势特征 legal_mask: [batch, num_actions] 合法性掩码,1=合法,0=非法 Returns: current_strategy: [batch, num_actions] 即时策略(基于 Regret Matching) avg_strategy: [batch, num_actions] 平均策略(基于 Softmax) """ # 前向传播 regrets, policy_logits = self.forward(card_features, env_features) # ──────── 即时策略: Regret Matching ──────── # 步骤 1: 截断负 regret 为 0 positive_regret = F.relu(regrets) # [batch, 6] # 步骤 2: 掩码掉非法动作 # legal_mask 确保 FOLD 在不能 fold 时不会被选, # 以及 ALL_IN 在筹码不足时不会被选 masked_regret = positive_regret * legal_mask # [batch, 6] # 步骤 3: 归一化或均匀分布 regret_sum = masked_regret.sum(dim=-1, keepdim=True) # [batch, 1] # 构造均匀分布: 在合法动作上均匀分配概率 # legal_mask.sum → 每个样本的合法动作数 num_legal = legal_mask.sum(dim=-1, keepdim=True).clamp(min=1.0) # [batch, 1] uniform = legal_mask / num_legal # [batch, 6] # 当 regret_sum > 0 时归一化,否则使用均匀分布 # 使用 where 避免 division by zero current_strategy = torch.where( regret_sum > 0, masked_regret / regret_sum, uniform, ) # ──────── 平均策略: Softmax + 合法动作掩码 ──────── # 将非法动作的 logit 设为 -inf,Softmax 后概率为 0 masked_logits = policy_logits.masked_fill(legal_mask == 0, float("-inf")) avg_strategy = F.softmax(masked_logits, dim=-1) # [batch, 6] # 处理全非法动作的边界情况(理论上不应发生) # 当 legal_mask 全为 0 时,masked_logits 全为 -inf,Softmax 结果为 NaN # 替换为均匀分布(虽然合法动作为 0 不应出现) all_illegal = (legal_mask.sum(dim=-1, keepdim=True) == 0) avg_strategy = torch.where(all_illegal, uniform, avg_strategy) return current_strategy, avg_strategy # ─────────────────────── 实例化测试 ─────────────────────── if __name__ == "__main__": print("=" * 60) print("CFRNetwork 单元测试") print("=" * 60) torch.manual_seed(42) BATCH_SIZE = 4 CARD_DIM_TEST = 50 ENV_DIM_TEST = 5 NUM_ACTIONS_TEST = 6 # ── 测试 1: 网络实例化 ── print("\n[测试1] 网络实例化") net = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST) total_params = sum(p.numel() for p in net.parameters()) print(f" 参数量: {total_params:,}") print(f" backbone: {net.backbone}") print(f" regret_head: {net.regret_head}") print(f" policy_head: {net.policy_head}") print(" 实例化成功 ✓") # ── 测试 2: forward 前向传播 ── print("\n[测试2] forward 前向传播") card_feat = torch.randn(BATCH_SIZE, CARD_DIM_TEST) env_feat = torch.randn(BATCH_SIZE, ENV_DIM_TEST) regrets, policy_logits = net(card_feat, env_feat) assert regrets.shape == (BATCH_SIZE, NUM_ACTIONS_TEST), f"regrets 形状错误: {regrets.shape}" assert policy_logits.shape == (BATCH_SIZE, NUM_ACTIONS_TEST), f"policy_logits 形状错误: {policy_logits.shape}" print(f" regrets: {regrets.shape} ✓") print(f" policy_logits: {policy_logits.shape} ✓") # regret 可以有负数 print(f" regret 最小值: {regrets.min().item():.4f}(应为负数)✓") # ── 测试 3: get_strategy — Regret Matching ── print("\n[测试3] get_strategy — Regret Matching") # 场景 A: 部分合法动作(前 3 个合法) legal_mask_A = torch.tensor([ [1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1], [1, 1, 0, 0, 0, 0], # 只有 fold 和 call 合法 [1, 1, 1, 1, 0, 0], ], dtype=torch.float32) current_strat, avg_strat = net.get_strategy(card_feat, env_feat, legal_mask_A) # 检查形状 assert current_strat.shape == (BATCH_SIZE, NUM_ACTIONS_TEST) assert avg_strat.shape == (BATCH_SIZE, NUM_ACTIONS_TEST) # 检查即时策略:非法位置概率应为 0 for i in range(BATCH_SIZE): for j in range(NUM_ACTIONS_TEST): if legal_mask_A[i, j] == 0: assert current_strat[i, j].item() == 0.0, \ f"样本{i}动作{j}非法但策略值={current_strat[i, j].item()}" assert avg_strat[i, j].item() == 0.0, \ f"样本{i}动作{j}非法但平均策略值={avg_strat[i, j].item()}" # 检查即时策略:概率之和应约为 1 for i in range(BATCH_SIZE): cs_sum = current_strat[i].sum().item() as_sum = avg_strat[i].sum().item() assert abs(cs_sum - 1.0) < 1e-5, f"样本{i}即时策略和={cs_sum}" assert abs(as_sum - 1.0) < 1e-5, f"样本{i}平均策略和={as_sum}" print(f" current_strategy 形状: {current_strat.shape} ✓") print(f" avg_strategy 形状: {avg_strat.shape} ✓") # 打印具体样本 print(f"\n 样本 0 (legal=[1,1,1,0,0,0]):") print(f" current: {current_strat[0].detach().numpy()}") print(f" avg: {avg_strat[0].detach().numpy()}") print(f" sum(cur)={current_strat[0].sum().item():.6f}, sum(avg)={avg_strat[0].sum().item():.6f}") print(f"\n 样本 2 (legal=[1,1,0,0,0,0]):") print(f" current: {current_strat[2].detach().numpy()}") print(f" avg: {avg_strat[2].detach().numpy()}") print(f" sum(cur)={current_strat[2].sum().item():.6f}, sum(avg)={avg_strat[2].sum().item():.6f}") # ── 测试 4: 全零 regret → 均匀分布 ── print("\n[测试4] 全零 regret → 均匀分布") # 手动构造一个网络,让 regret_head 输出接近 0 或负数 with torch.no_grad(): # 将 regret_head 的权重和偏置置零,使输出全为 0 nn.init.zeros_(net.regret_head.weight) nn.init.zeros_(net.regret_head.bias) legal_mask_B = torch.tensor([[1, 1, 1, 0, 0, 0]], dtype=torch.float32) card_B = torch.randn(1, CARD_DIM_TEST) env_B = torch.randn(1, ENV_DIM_TEST) current_B, avg_B = net.get_strategy(card_B, env_B, legal_mask_B) # 当所有 regret 为 0 时,应在合法动作上均匀分布 # 合法动作 3 个 → 每个概率 1/3 expected_uniform = torch.tensor([[1/3, 1/3, 1/3, 0, 0, 0]]) assert torch.allclose(current_B, expected_uniform, atol=1e-5), \ f"全零 regret 应产生均匀分布,实际: {current_B.detach().numpy()}" print(f" current_strategy: {current_B.detach().numpy()}") print(f" 期望均匀分布: {expected_uniform.numpy()}") print(" 全零 regret → 均匀分布 ✓") # ── 测试 5: 单一合法动作 ── print("\n[测试5] 单一合法动作") legal_mask_C = torch.tensor([[0, 1, 0, 0, 0, 0]], dtype=torch.float32) # 重新初始化网络以获得非零 regret net2 = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST) card_C = torch.randn(1, CARD_DIM_TEST) env_C = torch.randn(1, ENV_DIM_TEST) current_C, avg_C = net2.get_strategy(card_C, env_C, legal_mask_C) # 只有 CALL 合法,概率应全部集中在 CALL assert abs(current_C[0, 1].item() - 1.0) < 1e-5, f"单一合法动作概率={current_C[0, 1].item()}" assert abs(avg_C[0, 1].item() - 1.0) < 1e-5, f"单一合法动作平均策略概率={avg_C[0, 1].item()}" # 其他位置应为 0 for j in [0, 2, 3, 4, 5]: assert current_C[0, j].item() == 0.0, f"非法动作{j}概率={current_C[0, j].item()}" assert avg_C[0, j].item() == 0.0, f"非法动作{j}平均策略概率={avg_C[0, j].item()}" print(f" current: {current_C.detach().numpy()}") print(f" avg: {avg_C.detach().numpy()}") print(" 单一合法动作概率正确 ✓") # ── 测试 6: 梯度流 ── print("\n[测试6] 梯度流") net3 = CFRNetwork(card_dim=CARD_DIM_TEST, env_dim=ENV_DIM_TEST, num_actions=NUM_ACTIONS_TEST) card_D = torch.randn(2, CARD_DIM_TEST, requires_grad=True) env_D = torch.randn(2, ENV_DIM_TEST, requires_grad=True) legal_mask_D = torch.tensor([[1, 1, 1, 1, 1, 1], [1, 1, 0, 0, 1, 1]], dtype=torch.float32) current_D, avg_D = net3.get_strategy(card_D, env_D, legal_mask_D) # 使用 MSE loss 而非 .sum(),因为归一化后策略之和恒为 1,梯度近似为零 target = torch.zeros_like(current_D) loss = F.mse_loss(current_D, target) + F.mse_loss(avg_D, target) loss.backward() assert card_D.grad is not None, "card_features 梯度为 None" assert env_D.grad is not None, "env_features 梯度为 None" assert card_D.grad.norm().item() > 0, "card_features 梯度范数应为正数" assert env_D.grad.norm().item() > 0, "env_features 梯度范数应为正数" print(f" card_features.grad norm: {card_D.grad.norm().item():.6f} ✓") print(f" env_features.grad norm: {env_D.grad.norm().item():.6f} ✓") # 检查网络参数梯度 has_grad = sum(1 for p in net3.parameters() if p.grad is not None) total_p = sum(1 for p in net3.parameters()) print(f" 有梯度的参数: {has_grad}/{total_p} ✓") print("\n" + "=" * 60) print("所有测试通过 ✓") print("=" * 60)