Files
new/cfr_net_modular.py
2026-04-21 15:52:59 +08:00

89 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
cfr_net_modular.py — 模块化后期融合 CFR 策略网络
(完全契合架构设计文档:包含独立的 Env Model 编码器)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
class CFRNetworkModular(nn.Module):
    """Modular late-fusion CFR policy network.

    A standalone environment encoder (the "Env Model") compresses raw
    table quantities (pot, stacks, position, ...) into a compact
    situation embedding ``z_env``.  That embedding is concatenated with
    the raw card features and fed to the CFR backbone (late fusion),
    which drives two output heads: per-action regrets and average-policy
    logits.
    """

    def __init__(
        self,
        card_dim: int = 50,
        env_dim: int = 5,
        z_env_dim: int = 32,   # width of the compressed situation embedding
        num_actions: int = 6,
    ):
        super().__init__()
        self.card_dim = card_dim
        self.env_dim = env_dim
        self.z_env_dim = z_env_dim
        self.num_actions = num_actions

        # ──────── 1. Standalone Env Model situation encoder ────────
        # Distills physical quantities (pot size, stacks, position, ...)
        # into a learned situation concept (z_env).
        self.env_encoder = nn.Sequential(
            nn.Linear(env_dim, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, z_env_dim),
            nn.ReLU(),
        )
        # (Hook for auxiliary tasks) Uncomment to predict EV from z_env:
        # self.ev_head = nn.Linear(z_env_dim, 1)

        # ──────── 2. CFR policy backbone (late fusion) ────────
        # Consumes raw card features (card_dim) concatenated with the
        # distilled situation features (z_env_dim).
        concat_dim = card_dim + z_env_dim
        self.cfr_backbone = nn.Sequential(
            nn.Linear(concat_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        # ──────── 3. Dual output heads ────────
        self.regret_head = nn.Linear(128, num_actions)
        self.policy_head = nn.Linear(128, num_actions)

    def forward(
        self, card_features: torch.Tensor, env_features: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(regrets, policy_logits)``, each ``(..., num_actions)``."""
        # 1. Situation distillation
        z_env = self.env_encoder(env_features)
        # 2. Modular late-fusion concatenation
        x = torch.cat([card_features, z_env], dim=-1)
        # 3. CFR backbone and output heads
        x = self.cfr_backbone(x)
        regrets = self.regret_head(x)
        policy_logits = self.policy_head(x)
        return regrets, policy_logits

    def get_strategy(
        self,
        card_features: torch.Tensor,
        env_features: torch.Tensor,
        legal_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Regret-matching current strategy plus masked average strategy.

        ``legal_mask`` is a 0/1 tensor of shape ``(..., num_actions)``.
        Returns ``(current_strategy, avg_strategy)``; rows with no legal
        action come back all-zero (the original behavior).
        """
        regrets, policy_logits = self.forward(card_features, env_features)

        # ── Current strategy via regret matching ──
        positive_regret = F.relu(regrets)
        masked_regret = positive_regret * legal_mask
        regret_sum = masked_regret.sum(dim=-1, keepdim=True)
        num_legal = legal_mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
        uniform = legal_mask / num_legal
        # FIX: dividing by an unclamped regret_sum yields 0/0 = NaN in the
        # branch torch.where discards; the NaN still poisons the backward
        # pass (known torch.where gradient pitfall).  Clamp the denominator
        # so the division is always finite; the selected values are
        # unchanged because the clamp only bites when regret_sum == 0.
        matched = masked_regret / regret_sum.clamp(min=1e-12)
        current_strategy = torch.where(regret_sum > 0, matched, uniform)

        # ── Average strategy from policy logits over legal actions ──
        # FIX: masking with -inf makes softmax emit NaN (values *and*
        # gradients) on all-illegal rows.  A large finite negative gives
        # bit-identical probabilities whenever at least one action is
        # legal (exp of the masked entries underflows to exactly 0)
        # without ever producing NaN.
        masked_logits = policy_logits.masked_fill(legal_mask == 0, -1e9)
        avg_strategy = F.softmax(masked_logits, dim=-1)
        all_illegal = legal_mask.sum(dim=-1, keepdim=True) == 0
        avg_strategy = torch.where(all_illegal, uniform, avg_strategy)
        return current_strategy, avg_strategy