"""
|
||
cfr_net_modular.py — 模块化后期融合 CFR 策略网络
|
||
(完全契合架构设计文档:包含独立的 Env Model 编码器)
|
||
"""
|
||
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from typing import List, Tuple
|
||
|
||
class CFRNetworkModular(nn.Module):
|
||
def __init__(
|
||
self,
|
||
card_dim: int = 50,
|
||
env_dim: int = 5,
|
||
z_env_dim: int = 32, # Env Model 压缩后的局势特征维度
|
||
num_actions: int = 6,
|
||
):
|
||
super().__init__()
|
||
|
||
self.card_dim = card_dim
|
||
self.env_dim = env_dim
|
||
self.z_env_dim = z_env_dim
|
||
self.num_actions = num_actions
|
||
|
||
# ──────── 1. 独立的 Env Model 局势编码器 ────────
|
||
# 负责把底池、筹码、位置等物理量,升维/降维提炼成局势概念 (z_env)
|
||
self.env_encoder = nn.Sequential(
|
||
nn.Linear(env_dim, 64),
|
||
nn.LayerNorm(64),
|
||
nn.ReLU(),
|
||
nn.Linear(64, z_env_dim),
|
||
nn.ReLU()
|
||
)
|
||
|
||
# (保留挂载辅助任务的能力) 如果以后你想预测 EV,随时把这两行取消注释
|
||
# self.ev_head = nn.Linear(z_env_dim, 1)
|
||
|
||
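        # One possible wiring for that head (hypothetical, not from the
        # design doc): inside forward(), predict EV from the situation code
        # and train it with an auxiliary MSE term, e.g.
        #     ev_pred = self.ev_head(z_env).squeeze(-1)
        #     aux_loss = F.mse_loss(ev_pred, realized_returns)
        # where `realized_returns` is an illustrative name for the targets.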

        # ──────── 2. CFR policy backbone (late fusion) ────────
        # Consumes the clean card features (50-d by default) concatenated
        # with the distilled situation features (32-d by default).
        concat_dim = card_dim + z_env_dim

        self.cfr_backbone = nn.Sequential(
            nn.Linear(concat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        # ──────── 3. Dual output heads ────────
        self.regret_head = nn.Linear(128, num_actions)
        self.policy_head = nn.Linear(128, num_actions)

    def forward(
        self, card_features: torch.Tensor, env_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Distill the situation code
        z_env = self.env_encoder(env_features)

        # 2. Modular late fusion
        x = torch.cat([card_features, z_env], dim=-1)

        # 3. CFR backbone and output heads
        x = self.cfr_backbone(x)
        regrets = self.regret_head(x)
        policy_logits = self.policy_head(x)

        return regrets, policy_logits
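
    # Regret matching in brief: play each action in proportion to its
    # positive cumulative regret,
    #     sigma(a) = max(R(a), 0) / sum_b max(R(b), 0),
    # falling back to uniform over the legal actions when no regret is
    # positive. `legal_mask` is assumed to be a {0, 1} float tensor of
    # shape (..., num_actions).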
    def get_strategy(
        self,
        card_features: torch.Tensor,
        env_features: torch.Tensor,
        legal_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Regret-matching logic identical to the original code."""
        regrets, policy_logits = self.forward(card_features, env_features)

        # Instantaneous strategy
        positive_regret = F.relu(regrets)
        masked_regret = positive_regret * legal_mask
        regret_sum = masked_regret.sum(dim=-1, keepdim=True)

        num_legal = legal_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
        uniform = legal_mask / num_legal

        # Guard the denominator: torch.where evaluates both branches, so an
        # unguarded division by regret_sum == 0 would produce NaN gradients.
        safe_sum = torch.where(regret_sum > 0, regret_sum, torch.ones_like(regret_sum))
        current_strategy = torch.where(regret_sum > 0, masked_regret / safe_sum, uniform)

        # Average strategy: softmax restricted to legal actions
        masked_logits = policy_logits.masked_fill(legal_mask == 0, float("-inf"))
        avg_strategy = F.softmax(masked_logits, dim=-1)

        # Degenerate case: with no legal action, softmax yields NaN; fall
        # back to `uniform` (an all-zero row in that case).
        all_illegal = legal_mask.sum(dim=-1, keepdim=True) == 0
        avg_strategy = torch.where(all_illegal, uniform, avg_strategy)

        return current_strategy, avg_strategy
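

# ──────── Minimal smoke test (illustrative sketch, not from the design doc) ────────
# Builds the network with default dimensions, masks out one action, and
# checks that both strategies are valid distributions over the legal
# actions. Batch size, seed, and the masked action are arbitrary assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    net = CFRNetworkModular()

    batch = 4
    cards = torch.randn(batch, net.card_dim)
    env = torch.randn(batch, net.env_dim)
    legal = torch.ones(batch, net.num_actions)
    legal[:, 0] = 0.0  # pretend action 0 is illegal

    current, average = net.get_strategy(cards, env, legal)
    assert current.shape == (batch, net.num_actions)
    assert torch.allclose(current.sum(dim=-1), torch.ones(batch))
    assert torch.allclose(average.sum(dim=-1), torch.ones(batch))
    assert torch.all(current[:, 0] == 0) and torch.all(average[:, 0] == 0)
    print("current:", current[0].tolist())
    print("average:", average[0].tolist())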