"""
env_adapter.py — Environment state extraction and bet-action translation.

This module is stage two of the HUNL AI system. It is responsible for:
    1. Discretizing the OpenSpiel engine's continuous betting action space
       into 5 standard CFR actions.
    2. Extracting the situational features of the current game node for the
       policy network.
    3. Providing a random self-play test pipeline that checks the robustness
       of the mapping.

=== Botzone competition rule alignment ===
    - Players: 2 (heads-up)
    - Starting stack: fixed at 20000 per hand
    - Blinds: SB=50, BB=100
    - Cards dealt per round: Preflop (0), Flop (3), Turn (1), River (1)
    - Scoring: chips won / 100 = final score (purely EV driven)
    - Boundary rule: a player who cannot afford to call/raise may only
      Fold or go All-in

=== OpenSpiel fullgame action encoding ===
With bettingAbstraction=fullgame:
    - action 0 = Fold
    - action 1 = Call/Check
    - action N (N >= 2) = total contribution (bet-to amount), i.e. the
      player's cumulative investment after the bet
    Example: action 500 means the acting player's cumulative investment
    becomes 500 (not an additional raise of 500).

=== Pot and stack computation ===
    - Actual pot = sum(player_contributions) (from state.to_dict())
    - Note: state.pot_size(multiple) means "the total contribution after
      betting `multiple` times the pot", NOT the current pot amount. Its
      formula is: round(maxSpent + multiple * pot_after_call)
    - Remaining stack = starting_stacks[i] - player_contributions[i]
    - Normalization factor: STACK_NORMALIZE = 20000.0 (aligned with the
      Botzone starting stack)

=== The 5 discrete CFR actions ===
    0: FOLD     — fold
    1: CALL     — call/check
    2: HALF_POT — half-pot raise
                  (target = current max contribution + 1/2 x pot after call)
    3: FULL_POT — full-pot raise
                  (target = current max contribution + 1.0 x pot after call)
    4: ALL_IN   — all-in (the largest action the engine allows)
"""

import logging
import random
import re
from typing import Dict, List, Optional, Tuple

import pyspiel

logger = logging.getLogger("poker.bet_translator")

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# The 5 discretized CFR actions; the list index is the CFR action index.
CFR_ACTIONS = ["FOLD", "CALL", "HALF_POT", "FULL_POT", "ALL_IN"]
NUM_CFR_ACTIONS = len(CFR_ACTIONS)  # 5

# Pot multiplier for each proportional raise (keyed by CFR action index).
#   target = max_contribution + RAISE_MULTIPLIERS[idx] * pot_after_call
# where pot_after_call = pot + call_amount (the pot once the call is matched).
RAISE_MULTIPLIERS = {
    2: 1 / 2,  # HALF_POT: raise by 1/2 of the pot after calling
    3: 1.0,    # FULL_POT: raise by 1.0 x the pot after calling
}

# OpenSpiel engine action constants.
ENGINE_FOLD = 0  # Fold is always action 0
ENGINE_CALL = 1  # Call/Check is always action 1

# Botzone rule alignment: normalization factor equal to the starting stack.
STACK_NORMALIZE = 20000.0

# HUNL game configuration string (fullgame is required to get the continuous
# betting space). Botzone rules: SB=50, BB=100, starting stack=20000.
# The blind parameter is ordered [player0_blind, player1_blind], so
# "blind=50 100" makes player 0 the SB (50) and player 1 the BB (100).
HUNL_FULLGAME_STRING = (
    "universal_poker("
    "betting=nolimit,"
    "numPlayers=2,"
    "numRanks=13,"
    "numSuits=4,"
    "numHoleCards=2,"
    "numRounds=4,"
    "numBoardCards=0 3 1 1,"
    "stack=20000 20000,"
    "blind=50 100,"  # Botzone: SB=50, BB=100
    "bettingAbstraction=fullgame)"
)

# Street names, indexed by round number.
STREET_NAMES = ["Preflop", "Flop", "Turn", "River"]


# ---------------------------------------------------------------------------
|
||
# 工具函数
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _get_street(state) -> int:
|
||
"""
|
||
从 state 的字符串表示中解析当前轮次(street)。
|
||
|
||
OpenSpiel 的 universal_poker 未在 Python 层暴露 current_round() 方法,
|
||
但 str(state) 中包含 "Round: N" 字段,可以通过正则解析。
|
||
|
||
备选方案:通过 board_cards 数量推断——
|
||
0 张公共牌 = Preflop (0)
|
||
3 张公共牌 = Flop (1)
|
||
4 张公共牌 = Turn (2)
|
||
5 张公共牌 = River (3)
|
||
|
||
Returns:
|
||
int: 0=Preflop, 1=Flop, 2=Turn, 3=River
|
||
"""
|
||
# 优先从字符串解析
|
||
state_str = str(state)
|
||
match = re.search(r"Round:\s*(\d+)", state_str)
|
||
if match:
|
||
return int(match.group(1))
|
||
|
||
# 备选:通过公共牌数量推断
|
||
d = state.to_dict()
|
||
board_str = d.get("board_cards", "")
|
||
num_board = len(board_str) // 2 # 每张牌 2 个字符(如 "Ac", "Js")
|
||
if num_board == 0:
|
||
return 0 # Preflop
|
||
elif num_board == 3:
|
||
return 1 # Flop
|
||
elif num_board == 4:
|
||
return 2 # Turn
|
||
elif num_board == 5:
|
||
return 3 # River
|
||
else:
|
||
return 0 # 默认
|
||
|
||
|
||
def _get_pot_and_contributions(state) -> Tuple[int, List[int]]:
|
||
"""
|
||
获取当前底池金额和各玩家贡献额。
|
||
|
||
注意:state.pot_size() 的返回值并非当前底池金额,
|
||
其语义为 "下注 1 倍底池后的总贡献额"(含数学公式)。
|
||
真实底池 = 两名玩家贡献额之和。
|
||
|
||
Returns:
|
||
(pot, contributions): 底池金额, [P0贡献, P1贡献]
|
||
"""
|
||
d = state.to_dict()
|
||
contributions = list(d["player_contributions"])
|
||
pot = sum(contributions)
|
||
return pot, contributions
|
||
|
||
|
||
def _get_stacks(state) -> List[int]:
|
||
"""
|
||
计算两名玩家的剩余筹码。
|
||
|
||
剩余筹码 = 起始筹码 - 已贡献金额
|
||
|
||
Returns:
|
||
[P0剩余筹码, P1剩余筹码]
|
||
"""
|
||
d = state.to_dict()
|
||
starting = list(d["starting_stacks"])
|
||
contributions = list(d["player_contributions"])
|
||
return [starting[i] - contributions[i] for i in range(2)]
|
||
|
||
|
||
def _get_bet_actions(legal_actions: List[int]) -> List[int]:
    """Extract the raise actions (action > ``ENGINE_CALL``) in ascending order.

    Under the fullgame abstraction described in this module, an action
    greater than 1 is a concrete total contribution (bet-to amount), so in
    the sorted result the first entry is the smallest raise and the last is
    the largest one (treated as all-in by the callers).

    Args:
        legal_actions: Legal action IDs returned by the engine.

    Returns:
        A new sorted list of raise actions; empty when no raise is available.
    """
    return sorted(action for action in legal_actions if action > ENGINE_CALL)


def _find_nearest_legal_action(
    legal_actions: List[int],
    target_contribution: int,
) -> Optional[int]:
    """Find the legal raise action closest to a target contribution.

    Scans every raise action (> ``ENGINE_CALL``) and picks the one with the
    smallest absolute difference from ``target_contribution``. When two
    actions are equally close, the smaller one wins, because the candidates
    are scanned in ascending order and ``min`` keeps the first minimum.

    Args:
        legal_actions: Legal action IDs returned by the engine.
        target_contribution: Desired total contribution (bet-to amount).

    Returns:
        The closest legal raise action ID, or ``None`` when no raise action
        is available.
    """
    bet_actions = _get_bet_actions(legal_actions)
    if not bet_actions:
        return None
    return min(bet_actions, key=lambda a: abs(a - target_contribution))


def _count_raises_this_street(state) -> int:
    """Count the raises made so far on the current street.

    Walks ``state.full_history()`` backwards. Each entry is read through its
    ``.player`` and ``.action`` attributes. An entry with ``player < 0`` is
    treated as a chance (deal) action marking the street boundary, so the
    walk stops there. Every action greater than ``ENGINE_CALL`` seen before
    that boundary counts as one raise.

    Args:
        state: Object exposing ``full_history()``.

    Returns:
        Number of raises on the current street.
    """
    raise_count = 0
    for entry in reversed(state.full_history()):
        if entry.player < 0:
            # Chance node = cards dealt: everything earlier is a prior street.
            break
        if entry.action > ENGINE_CALL:
            # In fullgame, action > 1 is a raise (bet-to amount).
            raise_count += 1
    return raise_count


# ---------------------------------------------------------------------------
# 1. Action translator (BetTranslator)
# ---------------------------------------------------------------------------

class BetTranslator:
    """Map between engine actions and the 5 discrete CFR actions.

    Public methods:
        - ``get_cfr_legal_mask``: build the 5-element CFR legality mask from
          the engine's current legal actions.
        - ``cfr_to_engine_action``: convert a CFR action index into an engine
          action ID that can be applied directly.

    CFR actions:
        0: FOLD, 1: CALL (call/check), 2: HALF_POT, 3: FULL_POT, 4: ALL_IN

    Bet sizing:
        "Raise to X" means the acting player's total contribution becomes X,
        and the raise size is based on the pot after calling. Example: the
        acting player has contributed 100 and the opponent 300, then
          - CALL tops up to 300 (200 more), so the pot after calling is
            100 + 300 + 200 = 600
          - HALF_POT target = 300 + 1/2 * 600 = 600
          - FULL_POT target = 300 + 1.0 * 600 = 900
        Engine raise actions are bet-to amounts, so the computed target is
        mapped onto the closest legal action ID.

    Safety rules:
        1. target < minimum raise  -> proportional raise unavailable
           (falls back to CALL).
        2. target maps onto the largest raise (all-in) -> proportional raise
           unavailable (falls back to CALL); all-in belongs to ALL_IN only.
        3. CALL never maps to a raise.
        4. FOLD never maps to a raise.
    """

    # Raise cap to keep the game tree from exploding: once the current street
    # has seen at least this many raises, HALF_POT/FULL_POT are masked out and
    # only ALL_IN remains as an aggressive option.
    RAISE_CAP = 2

    @staticmethod
    def _betting_context(state) -> Tuple[int, int, int, int]:
        """Return ``(pot, max_contribution, call_amount, pot_after_call)``
        for the player to act."""
        pot, contributions = _get_pot_and_contributions(state)
        max_contribution = max(contributions)
        call_amount = max_contribution - contributions[state.current_player()]
        return pot, max_contribution, call_amount, pot + call_amount

    @staticmethod
    def _raise_target(
        multiplier: float, max_contribution: int, pot_after_call: int
    ) -> int:
        """Return the rounded bet-to target for a pot-proportional raise."""
        return int(round(max_contribution + multiplier * pot_after_call))

    @staticmethod
    def _first_available(legal: List[int], preferred: Tuple[int, ...]) -> int:
        """Return the first preferred action that is legal, else ``legal[0]``."""
        for action in preferred:
            if action in legal:
                return action
        return legal[0]

    def get_cfr_legal_mask(self, state) -> List[int]:
        """Build the 5-element CFR action legality mask.

        Mask rules:
            - FOLD (0): legal when the engine offers action 0.
            - CALL (1): legal when the engine offers action 1.
            - HALF_POT (2) / FULL_POT (3): legal only when a raise exists,
              the street's raise count is below ``RAISE_CAP``, the target is
              at least the minimum raise, and the closest legal action is
              not the largest raise (which would collide with ALL_IN).
            - ALL_IN (4): legal whenever the engine offers any raise.

        Args:
            state: OpenSpiel state at a player node.

        Returns:
            List of length 5 with 1=legal, 0=illegal, e.g. ``[1, 1, 1, 1, 1]``.
        """
        legal = state.legal_actions()
        bet_actions = _get_bet_actions(legal)

        mask = [0] * NUM_CFR_ACTIONS
        mask[0] = 1 if ENGINE_FOLD in legal else 0  # FOLD
        mask[1] = 1 if ENGINE_CALL in legal else 0  # CALL

        if not bet_actions:
            return mask

        # ALL_IN is always allowed as long as the engine supports a raise.
        mask[4] = 1

        if _count_raises_this_street(state) >= self.RAISE_CAP:
            # Raise cap reached: proportional raises stay masked out.
            return mask

        min_raise = bet_actions[0]
        all_in_action = bet_actions[-1]
        _, max_contribution, _, pot_after_call = self._betting_context(state)

        for cfr_idx, multiplier in RAISE_MULTIPLIERS.items():
            target = self._raise_target(
                multiplier, max_contribution, pot_after_call
            )

            # Check 1: target below the minimum raise -> unavailable.
            if target < min_raise:
                continue

            # Check 2: closest legal action is the all-in action -> semantic
            # collision with ALL_IN, so the proportional raise is unavailable.
            nearest = _find_nearest_legal_action(bet_actions, target)
            if nearest is not None and nearest >= all_in_action:
                continue

            mask[cfr_idx] = 1

        return mask

    def cfr_to_engine_action(self, state, cfr_action_idx: int) -> int:
        """Convert a CFR action index (0-4) into an engine action ID.

        Mapping:
            (0) FOLD     -> engine action 0; if folding is not legal, falls
                            back to Call, then ``legal[0]``.
            (1) CALL     -> engine action 1; if not legal, falls back to
                            Fold, then ``legal[0]`` (never to a raise chosen
                            on purpose).
            (2) HALF_POT -> closest legal action to
                            max_contribution + 1/2 * pot_after_call.
            (3) FULL_POT -> closest legal action to
                            max_contribution + 1.0 * pot_after_call.
            (4) ALL_IN   -> the largest legal raise action.

        Fallbacks for 2-4:
            - No raise action available -> Call, then Fold, then ``legal[0]``.
            - Proportional target below the minimum raise, or mapping onto
              the largest raise -> Call if legal, else the minimum raise.

        This method does not consult ``RAISE_CAP``; callers are expected to
        filter actions with ``get_cfr_legal_mask`` first.

        Args:
            state: OpenSpiel state at a player node.
            cfr_action_idx: CFR action index (0-4).

        Returns:
            An engine action ID taken from ``state.legal_actions()``.

        Raises:
            KeyError: If raises are available and ``cfr_action_idx`` is not
                one of 0-4.
        """
        legal = state.legal_actions()

        # FOLD (0). The Call fallback should not be reached in practice,
        # because the legal mask marks FOLD as illegal when folding is
        # unavailable.
        if cfr_action_idx == 0:
            return self._first_available(legal, (ENGINE_FOLD, ENGINE_CALL))

        # CALL (1). When Call is unavailable it must not turn into a raise.
        if cfr_action_idx == 1:
            return self._first_available(legal, (ENGINE_CALL, ENGINE_FOLD))

        # Actions 2-4 are raises and need the engine to offer one.
        bet_actions = _get_bet_actions(legal)  # sorted ascending
        if not bet_actions:
            return self._first_available(legal, (ENGINE_CALL, ENGINE_FOLD))

        min_raise = bet_actions[0]
        all_in_action = bet_actions[-1]

        # ALL_IN (4): the largest legal action.
        if cfr_action_idx == 4:
            return all_in_action

        # Proportional raises (2, 3), sized from the pot after calling.
        multiplier = RAISE_MULTIPLIERS[cfr_action_idx]
        pot, max_contribution, call_amount, pot_after_call = (
            self._betting_context(state)
        )
        target_int = self._raise_target(
            multiplier, max_contribution, pot_after_call
        )
        passive_fallback = (
            ENGINE_CALL if ENGINE_CALL in legal else min_raise
        )

        # Lower bound: target below the minimum raise -> unavailable.
        if target_int < min_raise:
            logger.debug(
                "CFR→Engine: %s target=%s < min_raise=%s, fallback to CALL",
                CFR_ACTIONS[cfr_action_idx], target_int, min_raise,
            )
            return passive_fallback

        # Upper bound: a proportional raise must not become an all-in.
        nearest = _find_nearest_legal_action(bet_actions, target_int)
        if nearest is not None and nearest >= all_in_action:
            logger.debug(
                "CFR→Engine: %s target=%s maps to ALL-IN(%s), fallback to CALL",
                CFR_ACTIONS[cfr_action_idx], target_int, all_in_action,
            )
            return passive_fallback

        # Normal mapping.
        if nearest is not None:
            logger.debug(
                "CFR→Engine: %s max_contrib=%s pot=%s call=%s "
                "pot_after_call=%s target=%s -> %s",
                CFR_ACTIONS[cfr_action_idx], max_contribution, pot,
                call_amount, pot_after_call, target_int, nearest,
            )
            return nearest

        # Defensive only: bet_actions is non-empty here, so a nearest action
        # is always found.
        return min_raise


# ---------------------------------------------------------------------------
# 2. Environment state extraction (extract_env_state)
# ---------------------------------------------------------------------------

def extract_env_state(state) -> Dict:
    """Extract the situational features of the current game-tree node.

    All values are returned raw. This function does not normalize them;
    scaling (e.g. by ``STACK_NORMALIZE``) is left to the consumer.

    Returned keys:
        - ``pot``: current total pot (sum of both players' contributions)
        - ``p0_stack``: player 0's remaining stack
        - ``p1_stack``: player 1's remaining stack
        - ``street``: current round (0=Preflop, 1=Flop, 2=Turn, 3=River)
        - ``position``: ``state.current_player()``, the player to act
        - ``legal_mask``: length-5 CFR action legality mask

    Args:
        state: OpenSpiel state; must be a player node (not chance/terminal).

    Returns:
        Dictionary with the keys listed above.
    """
    pot, _ = _get_pot_and_contributions(state)
    stacks = _get_stacks(state)

    return {
        "pot": pot,
        "p0_stack": stacks[0],
        "p1_stack": stacks[1],
        "street": _get_street(state),
        "position": state.current_player(),
        "legal_mask": BetTranslator().get_cfr_legal_mask(state),
    }


# ---------------------------------------------------------------------------
# 3. Self-play test pipeline (run_random_cfr_self_play)
# ---------------------------------------------------------------------------

def run_random_cfr_self_play(num_games: int = 5, verbose: bool = True) -> None:
    """Play random CFR-action self-play games as a smoke test.

    Flow:
        1. Load the universal_poker game from ``HUNL_FULLGAME_STRING``.
        2. Until the state is terminal:
           - Chance node: sample an outcome according to its probability.
           - Player node: extract the env state, pick a random CFR action
             among those the mask allows (forcing CALL if none are),
             translate it with ``cfr_to_engine_action`` and apply it.
        3. Print both players' returns.

    The goal is to run ``num_games`` games without ever hitting an illegal
    action.

    Args:
        num_games: Number of games to play (default 5).
        verbose: Whether to print per-move and per-game details.
    """
    game = pyspiel.load_game(HUNL_FULLGAME_STRING)
    translator = BetTranslator()
    separator = f"{'='*60}"

    for game_idx in range(num_games):
        state = game.new_initial_state()
        move_count = 0

        if verbose:
            print(f"\n{separator}")
            print(f" 第 {game_idx + 1} 局")
            print(separator)

        while not state.is_terminal():
            # ---- Chance node: deal randomly ----
            if state.is_chance_node():
                action_list, prob_list = zip(*state.chance_outcomes())
                chance_action = random.choices(
                    action_list, weights=prob_list, k=1
                )[0]
                state.apply_action(chance_action)
                continue

            # ---- Player node: CFR action decision ----
            current_player = state.current_player()
            env_info = extract_env_state(state)
            legal_mask = env_info["legal_mask"]

            if verbose:
                street = env_info["street"]
                street_name = (
                    STREET_NAMES[street]
                    if street < len(STREET_NAMES)
                    else "Unknown"
                )
                mask_desc = ' '.join(
                    f'{CFR_ACTIONS[i]}:{v}' for i, v in enumerate(legal_mask)
                )
                print(f"\n [Move {move_count}] Player {current_player} | {street_name}")
                print(f" 底池: {env_info['pot']}, "
                      f"P0筹码: {env_info['p0_stack']}, "
                      f"P1筹码: {env_info['p1_stack']}")
                print(f" CFR Mask: {legal_mask} "
                      f"({mask_desc})")

            # Pick a random CFR action among the legal ones.
            legal_cfr_actions = [i for i, m in enumerate(legal_mask) if m == 1]
            if not legal_cfr_actions:
                if verbose:
                    print(" 警告:没有合法的 CFR 动作!使用 Fallback。")
                legal_cfr_actions = [1]  # force CALL

            cfr_action = random.choice(legal_cfr_actions)

            # Translate to an engine action.
            engine_action = translator.cfr_to_engine_action(state, cfr_action)

            if verbose:
                cfr_name = CFR_ACTIONS[cfr_action]
                action_str = state.action_to_string(current_player, engine_action)
                print(f" CFR: {cfr_name} ({cfr_action}) -> Engine Action: {engine_action} ({action_str})")

            # Apply the engine action.
            state.apply_action(engine_action)
            move_count += 1

        # ---- Game over: print returns ----
        returns = state.returns()
        if verbose:
            print("\n === 对局结束 ===")
            print(f" Returns: P0 = {returns[0]:+.0f}, P1 = {returns[1]:+.0f}")
            if returns[0] > returns[1]:
                print(" 胜者: Player 0")
            elif returns[1] > returns[0]:
                print(" 胜者: Player 1")
            else:
                print(" 平局")

    if verbose:
        print(f"\n{separator}")
        print(f" 全部 {num_games} 局自对弈完成,无非法动作崩溃!")
        print(separator)


# ---------------------------------------------------------------------------
# Entry point: random self-play smoke test
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    run_random_cfr_self_play(num_games=5, verbose=True)