"""
cfr_net_modular.py — Modular late-fusion CFR policy network.

Follows the architecture design document: it contains a dedicated Env Model
encoder that compresses the raw situation (env) features into ``z_env``, which
is then fused late with the card features before the shared CFR backbone.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple


class CFRNetworkModular(nn.Module):
    """Late-fusion CFR network: env encoder -> concat with cards -> backbone -> two heads.

    Args:
        card_dim: Width of the card-feature vector passed to ``forward``.
        env_dim: Width of the raw env-feature vector passed to ``forward``.
        z_env_dim: Width of the encoded env representation produced by
            ``env_encoder``.
        num_actions: Number of discrete actions output by each head.
    """

    def __init__(
        self,
        card_dim: int = 50,
        env_dim: int = 5,
        z_env_dim: int = 32,  # width of the compressed situation feature from the Env Model
        num_actions: int = 6,
    ):
        super().__init__()
        self.card_dim = card_dim
        self.env_dim = env_dim
        self.z_env_dim = z_env_dim
        self.num_actions = num_actions

        # ──────── 1. Standalone Env Model situation encoder ────────
        # Per the design notes, this distills physical quantities such as pot,
        # stacks and position into a compact situation representation (z_env).
        self.env_encoder = nn.Sequential(
            nn.Linear(env_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, z_env_dim),
            nn.ReLU(),
        )

        # Optional auxiliary head (e.g. EV prediction) can hang off z_env;
        # uncomment the line below to enable it.
        # self.ev_head = nn.Linear(z_env_dim, 1)

        # ──────── 2. CFR policy backbone (late fusion) ────────
        # Input = raw card features (card_dim) + encoded env features (z_env_dim).
        concat_dim = card_dim + z_env_dim
        self.cfr_backbone = nn.Sequential(
            nn.Linear(concat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        # ──────── 3. Two output heads ────────
        self.regret_head = nn.Linear(128, num_actions)
        self.policy_head = nn.Linear(128, num_actions)

    def forward(
        self, card_features: torch.Tensor, env_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network and return raw head outputs.

        Args:
            card_features: Tensor whose last dimension is ``card_dim``.
            env_features: Tensor whose last dimension is ``env_dim``; its
                leading dimensions must match ``card_features`` so the two
                can be concatenated on the last axis.

        Returns:
            ``(regrets, policy_logits)``, each with last dimension
            ``num_actions``.
        """
        # 1. Encode the situation features.
        z_env = self.env_encoder(env_features)
        # 2. Late fusion: concatenate card features with the encoded env features.
        x = torch.cat([card_features, z_env], dim=-1)
        # 3. Shared backbone, then the two heads.
        x = self.cfr_backbone(x)
        regrets = self.regret_head(x)
        policy_logits = self.policy_head(x)
        return regrets, policy_logits

    def get_strategy(
        self,
        card_features: torch.Tensor,
        env_features: torch.Tensor,
        legal_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the current (regret-matching) and average (softmax) strategies.

        Args:
            card_features: See ``forward``.
            env_features: See ``forward``.
            legal_mask: Action mask broadcastable to ``(..., num_actions)``.
                Entries equal to 0 are illegal; presumably 1 marks a legal
                action (bool/int masks are accepted and treated as 0/1).

        Returns:
            ``(current_strategy, avg_strategy)``. Illegal actions get
            probability 0. If no action in a row has positive regret, the
            current strategy is uniform over legal actions. Rows with no
            legal action at all are all zeros in both outputs.
        """
        regrets, policy_logits = self.forward(card_features, env_features)

        # Work in the network's float dtype so bool/int masks act like 0/1 floats.
        mask = legal_mask.to(dtype=regrets.dtype)
        illegal = mask == 0
        num_legal = mask.sum(dim=-1, keepdim=True)
        # Clamp avoids 0/0 for all-illegal rows (they become all-zero "uniform").
        uniform = mask / num_legal.clamp(min=1.0)

        # Current strategy: regret matching over positive, legal regrets.
        masked_regret = F.relu(regrets) * mask
        regret_sum = masked_regret.sum(dim=-1, keepdim=True)
        has_regret = regret_sum > 0
        # Divide by 1 where the sum is 0: torch.where still backpropagates
        # through the unselected branch, and a 0/0 there yields NaN gradients.
        safe_sum = torch.where(has_regret, regret_sum, torch.ones_like(regret_sum))
        current_strategy = torch.where(has_regret, masked_regret / safe_sum, uniform)

        # Average strategy: softmax over legal logits. The most negative finite
        # value (instead of -inf) still gives illegal actions exactly 0
        # probability when any action is legal, without producing NaN on
        # all-illegal rows.
        fill_value = torch.finfo(policy_logits.dtype).min
        masked_logits = policy_logits.masked_fill(illegal, fill_value)
        avg_strategy = F.softmax(masked_logits, dim=-1)
        all_illegal = num_legal == 0
        avg_strategy = torch.where(all_illegal, uniform, avg_strategy)

        return current_strategy, avg_strategy