Files
new/env_adapter.py
2026-05-06 17:46:46 +08:00

635 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
env_adapter.py — 环境状态提取与下注动作映射 (Env Adapter & Bet Translation)
本模块是 HUNL AI 系统的第二阶段,负责:
1. 将 OpenSpiel 引擎的连续下注动作空间离散化为 5 个 CFR 标准动作
2. 提取当前博弈节点的局势特征,供策略网络使用
3. 提供随机自对弈测试流水线,验证映射的健壮性
=== Botzone 比赛规则对齐 ===
- 玩家数: 2人 (Heads-up)
- 初始筹码: 每局固定 20000
- 盲注: SB=50, BB=100
- 轮次发牌: Preflop(0张), Flop(3张), Turn(1张), River(1张)
- 计分: 赢取筹码 / 100 = 最终得分 (纯 EV 导向)
- 边界规则: 筹码不足跟注/加注时, 只能 Fold 或 All-in
=== OpenSpiel fullgame 动作编码规则 ===
使用 bettingAbstraction=fullgame 时:
- action 0 = Fold弃牌
- action 1 = Call/Check跟注/过牌)
- action N (N>=2) = 总贡献额bet-to amount即下注后该玩家的累计投入
例如action 500 表示当前玩家累计投入 500而非额外加注 500
=== 底池与筹码计算 ===
- 实际底池 = sum(player_contributions) (来自 state.to_dict()
- 注意state.pot_size(multiple) 的语义是 "下注 multiple 倍底池大小后的总贡献额"
并非当前底池金额其公式为round(maxSpent + multiple * pot_after_call)
- 剩余筹码 = starting_stacks[i] - player_contributions[i]
- 归一化因子: STACK_NORMALIZE = 20000.0(与 Botzone 初始筹码对齐)
=== 5 个 CFR 离散动作 ===
0: FOLD — 弃牌
1: CALL — 跟注/过牌
2: HALF_POT — 1/2 底池加注(目标 = 当前最大贡献 + 1/2 × 跟注后底池)
3: FULL_POT — 满池底加注(目标 = 当前最大贡献 + 1.0 × 跟注后底池)
4: ALL_IN — 全押(引擎允许的最大动作)
"""
import re
import random
import logging
from typing import Dict, List, Optional, Tuple
import pyspiel
logger = logging.getLogger("poker.bet_translator")
# ---------------------------------------------------------------------------
# 常量定义
# ---------------------------------------------------------------------------
# 5 个离散化的 CFR 动作
CFR_ACTIONS = ["FOLD", "CALL", "HALF_POT", "FULL_POT", "ALL_IN"]
NUM_CFR_ACTIONS = len(CFR_ACTIONS) # 5
# 每个加注动作对应的底池倍率(用于计算目标贡献额)
# target = max_contribution + RAISE_MULTIPLIERS[idx] * pot_after_call
# 其中 pot_after_call = pot + call_amount跟注后的底池标准扑克定义
RAISE_MULTIPLIERS = {
2: 1 / 2, # HALF_POT: 加注额 = 1/2 跟注后底池
3: 1.0, # FULL_POT: 加注额 = 1.0 跟注后底池
}
# OpenSpiel 引擎动作常量
ENGINE_FOLD = 0 # Fold 固定为 action 0
ENGINE_CALL = 1 # Call/Check 固定为 action 1
# Botzone 比赛规则对齐常量
STACK_NORMALIZE = 20000.0 # 归一化因子 = Botzone 初始筹码 (env_features 的筹码/底池归一化)
# HUNL 游戏配置字符串(必须使用 fullgame 以获得连续下注空间)
# Botzone 规则: SB=50, BB=100, 初始筹码=20000
# OpenSpiel blind 参数顺序: [player0_blind, player1_blind]
# blind=50 100 → player0=SB(50), player1=BB(100)
HUNL_FULLGAME_STRING = (
"universal_poker("
"betting=nolimit,"
"numPlayers=2,"
"numRanks=13,"
"numSuits=4,"
"numHoleCards=2,"
"numRounds=4,"
"numBoardCards=0 3 1 1,"
"stack=20000 20000,"
"blind=50 100," # Botzone: SB=50, BB=100
"bettingAbstraction=fullgame)"
)
# 街道名称映射
STREET_NAMES = ["Preflop", "Flop", "Turn", "River"]
# ---------------------------------------------------------------------------
# 工具函数
# ---------------------------------------------------------------------------
def _get_street(state) -> int:
"""
从 state 的字符串表示中解析当前轮次street
OpenSpiel 的 universal_poker 未在 Python 层暴露 current_round() 方法,
但 str(state) 中包含 "Round: N" 字段,可以通过正则解析。
备选方案:通过 board_cards 数量推断——
0 张公共牌 = Preflop (0)
3 张公共牌 = Flop (1)
4 张公共牌 = Turn (2)
5 张公共牌 = River (3)
Returns:
int: 0=Preflop, 1=Flop, 2=Turn, 3=River
"""
# 优先从字符串解析
state_str = str(state)
match = re.search(r"Round:\s*(\d+)", state_str)
if match:
return int(match.group(1))
# 备选:通过公共牌数量推断
d = state.to_dict()
board_str = d.get("board_cards", "")
num_board = len(board_str) // 2 # 每张牌 2 个字符(如 "Ac", "Js"
if num_board == 0:
return 0 # Preflop
elif num_board == 3:
return 1 # Flop
elif num_board == 4:
return 2 # Turn
elif num_board == 5:
return 3 # River
else:
return 0 # 默认
def _get_pot_and_contributions(state) -> Tuple[int, List[int]]:
"""
获取当前底池金额和各玩家贡献额。
注意state.pot_size() 的返回值并非当前底池金额,
其语义为 "下注 1 倍底池后的总贡献额"(含数学公式)。
真实底池 = 两名玩家贡献额之和。
Returns:
(pot, contributions): 底池金额, [P0贡献, P1贡献]
"""
d = state.to_dict()
contributions = list(d["player_contributions"])
pot = sum(contributions)
return pot, contributions
def _get_stacks(state) -> List[int]:
"""
计算两名玩家的剩余筹码。
剩余筹码 = 起始筹码 - 已贡献金额
Returns:
[P0剩余筹码, P1剩余筹码]
"""
d = state.to_dict()
starting = list(d["starting_stacks"])
contributions = list(d["player_contributions"])
return [starting[i] - contributions[i] for i in range(2)]
def _get_bet_actions(legal_actions: List[int]) -> List[int]:
"""
从合法动作列表中提取所有加注动作action > 1
在 fullgame 模式下action > 1 代表具体的总贡献额。
这些动作按升序排列,最小值 = 最小加注,最大值 = All-in。
Returns:
排序后的加注动作列表
"""
return sorted([a for a in legal_actions if a > ENGINE_CALL])
def _find_nearest_legal_action(
legal_actions: List[int],
target_contribution: int,
) -> Optional[int]:
"""
在合法动作列表中找到最接近目标贡献额的合法动作。
使用二分查找的思想,遍历所有合法的加注动作(>1
找到与 target_contribution 差值绝对值最小的那个。
Args:
legal_actions: 引擎返回的合法动作列表
target_contribution: 目标总贡献额bet-to amount
Returns:
最接近的合法动作 ID如果没有加注动作可用则返回 None
"""
bet_actions = _get_bet_actions(legal_actions)
if not bet_actions:
return None
# 找差值最小的动作
best_action = min(bet_actions, key=lambda a: abs(a - target_contribution))
return best_action
def _count_raises_this_street(state) -> int:
"""
统计当前轮次 (street) 已经发生的加注次数。
使用 state.full_history() 逆向遍历:
从 history 末尾向前遍历,遇到发牌动作 (player < 0) 即停止,
统计该 street 内的所有加注动作 (action > 1)。
原理OpenSpiel full_history() 返回 List[(player_id, action)]
其中 chance 节点的 player_id < 0标志着新 street 的发牌分界线。
从末尾向前遍历,遇到第一个 chance 动作就意味着跨越了 street 边界,
此前统计的加注数即为当前 street 的加注次数。
Returns:
int: 当前 street 的加注次数
"""
raise_count = 0
for pa in reversed(state.full_history()):
# pyspiel.PlayerAction 对象,通过 .player / .action 属性访问
player = pa.player
action = pa.action
# chance 节点 (player < 0) = 发牌,是 street 分界线,停止向前追溯
if player < 0:
break
# fullgame 中 action > 1 代表加注 (bet-to amount)
if action > ENGINE_CALL:
raise_count += 1
return raise_count
# ---------------------------------------------------------------------------
# 1. 动作转换器类 (BetTranslator)
# ---------------------------------------------------------------------------
class BetTranslator:
"""
引擎动作 <-> CFR 动作 的双向映射器。
核心职责:
- get_cfr_legal_mask: 根据引擎当前合法动作,生成 5 维 CFR 动作掩码
- cfr_to_engine_action: 将 CFR 动作索引转换为引擎可直接执行的动作 ID
=== 5 个 CFR 动作 ===
0: FOLD — 弃牌
1: CALL — 跟注/过牌
2: HALF_POT — 1/2 底池加注
3: FULL_POT — 满池底加注
4: ALL_IN — 全押
=== 下注金额映射的数学逻辑 ===
在 No-Limit Hold'em 中,"加注到 X" 意味着当前玩家的总贡献额变为 X。
标准扑克规则:加注额基于"跟注后的底池"pot_after_call计算。
例如:当前玩家已贡献 100对手已贡献 300
- 跟注 (CALL) = 补到 300即额外投入 200
- 跟注后底池 = 100 + 300 + 200(跟注) = 600
- 1/2 底池加注 (HALF_POT):
目标贡献额 = 当前最大贡献 + 1/2 × 跟注后底池 = 300 + 300 = 600
- 满池底加注 (FULL_POT):
目标贡献额 = 当前最大贡献 + 1.0 × 跟注后底池 = 300 + 600 = 900
这与 OpenSpiel 的 state.pot_size(multiple) 公式一致:
pot_size(m) = round(maxSpent + m × pot_after_call)
注意OpenSpiel fullgame 中的 action 值就是"总贡献额"bet-to
因此计算目标后,在 legal_actions 中找最接近的合法 action ID。
=== 安全机制 ===
1. 如果 target < min_raise → 比例加注不可用fallback 到 CALL
2. 如果 target 映射到 ALL-IN → 比例加注不应导致 ALL-INfallback 到 CALL
3. CALL 永远不能映射到 raise
4. FOLD 永远不能映射到 CALL/raise
"""
# 加注次数上限:防止博弈树爆炸
# 当当前 street 的加注次数 >= 此值时,禁止 HALF_POT/FULL_POT
# 仅保留 ALL_IN 作为终结手段
RAISE_CAP = 2
def get_cfr_legal_mask(self, state) -> List[int]:
"""
生成 5 维 CFR 动作合法性掩码。
掩码逻辑:
- FOLD (0): 引擎合法动作包含 0 时为合法
- CALL (1): 引擎合法动作包含 1 时为合法
- HALF_POT (2): 需要加注动作可用 且 target >= min_raise 且映射不碰撞
- FULL_POT (3): 同上
- ALL_IN (4): 只要引擎存在 >1 的加注动作即合法
关键修正: 比例加注只有在映射后语义独立(不等于 ALL-IN时才标记为合法。
如果 target < min_raise 或 target 映射到 ALL-IN则该比例加注不可用。
Args:
state: OpenSpiel 的 State 对象
Returns:
长度为 5 的列表1=合法, 0=非法。例如 [1, 1, 1, 1, 1]
"""
legal = state.legal_actions()
has_fold = ENGINE_FOLD in legal
has_call = ENGINE_CALL in legal
bet_actions = _get_bet_actions(legal)
has_any_raise = len(bet_actions) > 0
# 计算当前 street 的加注次数
raise_count = _count_raises_this_street(state)
raise_capped = raise_count >= self.RAISE_CAP
mask = [0] * NUM_CFR_ACTIONS
mask[0] = 1 if has_fold else 0 # FOLD
mask[1] = 1 if has_call else 0 # CALL
# ALL_IN 始终允许(只要引擎支持加注)
mask[4] = 1 if has_any_raise else 0
if has_any_raise and not raise_capped:
all_in_action = bet_actions[-1]
min_raise = bet_actions[0]
pot, contributions = _get_pot_and_contributions(state)
max_contribution = max(contributions)
current_player = state.current_player()
my_contribution = contributions[current_player]
call_amount = max_contribution - my_contribution
pot_after_call = pot + call_amount
for cfr_idx, multiplier in RAISE_MULTIPLIERS.items():
target = max_contribution + multiplier * pot_after_call
target_int = int(round(target))
# 检查 1: target 低于 min_raise → 该比例加注不可用
if target_int < min_raise:
mask[cfr_idx] = 0
continue
# 检查 2: target 映射后是否等于 ALL-IN → 语义碰撞,标记为非法
# 比例加注不应导致 ALL-IN那是 ALL_IN 动作的职责)
nearest = _find_nearest_legal_action(legal, target_int)
if nearest is not None and nearest >= all_in_action:
mask[cfr_idx] = 0
continue
mask[cfr_idx] = 1
else:
mask[2] = 0
mask[3] = 0
# 如果掩码全 0理论不应发生兜底 Call
if sum(mask) == 0 and has_call:
mask[1] = 1
return mask
def cfr_to_engine_action(self, state, cfr_action_idx: int) -> int:
"""
将 CFR 动作索引 (0-4) 转换为 OpenSpiel 引擎可直接执行的动作整数。
转换逻辑详解:
(0) FOLD → 直接映射到引擎的 action 0
(1) CALL → 直接映射到引擎的 action 1Check/Call
(2) HALF_POT → 目标总贡献额 = max_contribution + 1/2 × pot_after_call
在 legal_actions 中找最接近的合法 action ID
(3) FULL_POT → 目标总贡献额 = max_contribution + 1.0 × pot_after_call
同上
(4) ALL_IN → 引擎合法动作中的最大值(即全部筹码)
=== 动作 Fallback 安全机制 ===
1. 如果 target < min_raise → 比例加注不可用fallback 到 CALL
2. 如果 target 映射到 ALL-IN → 比例加注不应导致 ALL-INfallback 到 CALL
3. 如果完全没有加注动作可用Fallback 到 Call → Fold → legal[0]
4. CALL 永远不能映射到 raise
5. FOLD 永远不能映射到 CALL/raise
Args:
state: OpenSpiel 的 State 对象
cfr_action_idx: CFR 动作索引 (0-4)
Returns:
int: 引擎可直接 state.apply_action() 的合法动作 ID
"""
legal = state.legal_actions()
bet_actions = _get_bet_actions(legal) # 所有 >1 的加注动作,已排序
# ---------------------------------------------------------------
# FOLD (0)
# ---------------------------------------------------------------
if cfr_action_idx == 0:
if ENGINE_FOLD in legal:
return ENGINE_FOLD
# 不能 Fold 时(如只能 checkfallback 到 Call
# 注意: 不应该到达这里,因为 legal_mask 会标记 FOLD=0
if ENGINE_CALL in legal:
return ENGINE_CALL
return legal[0]
# ---------------------------------------------------------------
# CALL (1)
# ---------------------------------------------------------------
if cfr_action_idx == 1:
if ENGINE_CALL in legal:
return ENGINE_CALL
# Fallback: CALL 不可用时不应该映射到 raise
if ENGINE_FOLD in legal:
return ENGINE_FOLD
return legal[0]
# ---------------------------------------------------------------
# 以下 2-4 都是加注动作,需要引擎支持加注
# Fallback: 无加注动作可用时 → Call → Fold → legal[0]
# ---------------------------------------------------------------
if not bet_actions:
if ENGINE_CALL in legal:
return ENGINE_CALL
if ENGINE_FOLD in legal:
return ENGINE_FOLD
return legal[0]
# ---------------------------------------------------------------
# ALL_IN (4) → 最大合法动作(全部筹码)
# ---------------------------------------------------------------
if cfr_action_idx == 4:
return bet_actions[-1] # 已排序,最后一个就是 All-in
# ---------------------------------------------------------------
# 比例加注 (2, 3)
# 使用 pot_after_call 计算目标(与 OpenSpiel pot_size() 对齐)
# ---------------------------------------------------------------
pot, contributions = _get_pot_and_contributions(state)
max_contribution = max(contributions)
current_player = state.current_player()
my_contribution = contributions[current_player]
call_amount = max_contribution - my_contribution
pot_after_call = pot + call_amount
multiplier = RAISE_MULTIPLIERS[cfr_action_idx]
target = max_contribution + multiplier * pot_after_call
target_int = int(round(target))
min_raise = bet_actions[0]
all_in_action = bet_actions[-1]
# 下界保护: target < min_raise → 该比例加注不可用fallback 到 Call
if target_int < min_raise:
logger.debug(
f"CFR→Engine: {CFR_ACTIONS[cfr_action_idx]} target={target_int} "
f"< min_raise={min_raise}, fallback to CALL"
)
if ENGINE_CALL in legal:
return ENGINE_CALL
return bet_actions[0]
# 上界保护: target 映射到 ALL-IN → 比例加注不应导致 ALL-IN
nearest = _find_nearest_legal_action(legal, target_int)
if nearest is not None and nearest >= all_in_action:
logger.debug(
f"CFR→Engine: {CFR_ACTIONS[cfr_action_idx]} target={target_int} "
f"maps to ALL-IN({all_in_action}), fallback to CALL"
)
if ENGINE_CALL in legal:
return ENGINE_CALL
return bet_actions[0]
# 正常映射
if nearest is not None:
logger.debug(
f"CFR→Engine: {CFR_ACTIONS[cfr_action_idx]} "
f"max_contrib={max_contribution} pot={pot} call={call_amount} "
f"pot_after_call={pot_after_call} target={target_int} -> {nearest}"
)
return nearest
return bet_actions[0]
# ---------------------------------------------------------------------------
# 2. 环境状态提取函数 (extract_env_state)
# ---------------------------------------------------------------------------
def extract_env_state(state) -> Dict:
"""
提取当前博弈树节点的局势特征,供 CFR 策略网络使用。
Botzone 规则对齐:
- pot/p0_stack/p1_stack 使用 STACK_NORMALIZE=20000.0 归一化
- 街道使用 STREET_NORMALIZE=3.0 归一化 (0/3 ~ 3/3)
- position 标识当前行动玩家 (0 或 1)
返回的字典包含:
- pot: 当前总底池(两玩家贡献额之和)
- p0_stack: 玩家 0 剩余筹码
- p1_stack: 玩家 1 剩余筹码
- street: 当前轮次 (0=Preflop, 1=Flop, 2=Turn, 3=River)
- position: 当前行动的玩家 (0 或 1)
- legal_mask: 长度为 5 的 CFR 动作合法性掩码
Args:
state: OpenSpiel 的 State 对象(必须是玩家节点,非 Chance/Terminal
Returns:
dict: 包含上述键的字典
"""
pot, contributions = _get_pot_and_contributions(state)
stacks = _get_stacks(state)
street = _get_street(state)
position = state.current_player()
legal_mask = BetTranslator().get_cfr_legal_mask(state)
return {
"pot": pot,
"p0_stack": stacks[0],
"p1_stack": stacks[1],
"street": street,
"position": position,
"legal_mask": legal_mask,
}
# ---------------------------------------------------------------------------
# 3. 自对弈测试流水线 (run_random_cfr_self_play)
# ---------------------------------------------------------------------------
def run_random_cfr_self_play(num_games: int = 5, verbose: bool = True):
"""
随机 CFR 动作自对弈测试。
流程:
1. 初始化 universal_poker 引擎fullgame筹码 20000盲注 SB=50/BB=100
2. while not state.is_terminal(): 循环
- Chance Node: 随机发牌
- Player Node:
a. 打印 extract_env_state 提取的信息
b. 打印当前合法的 CFR Mask
c. 在合法的 CFR 动作中随机挑选一个
d. 调用 cfr_to_engine_action 转换为引擎动作,打印映射过程
e. 执行该引擎动作
3. 打印对局 Returns双方最终收益
验证目标:跑 num_games 局,确保代码绝对不会因"非法动作"崩溃。
Args:
num_games: 对局数量(默认 5 局)
verbose: 是否打印详细信息
"""
game = pyspiel.load_game(HUNL_FULLGAME_STRING)
translator = BetTranslator()
for game_idx in range(num_games):
state = game.new_initial_state()
move_count = 0
if verbose:
print(f"\n{'='*60}")
print(f"{game_idx + 1}")
print(f"{'='*60}")
while not state.is_terminal():
# ---- Chance Node: 随机发牌 ----
if state.is_chance_node():
outcomes = state.chance_outcomes()
action_list, prob_list = zip(*outcomes)
chance_action = random.choices(action_list, weights=prob_list, k=1)[0]
state.apply_action(chance_action)
continue
# ---- Player Node: CFR 动作决策 ----
current_player = state.current_player()
env_info = extract_env_state(state)
legal_mask = env_info["legal_mask"]
if verbose:
street_name = STREET_NAMES[env_info["street"]] if env_info["street"] < 4 else "Unknown"
print(f"\n [Move {move_count}] Player {current_player} | {street_name}")
print(f" 底池: {env_info['pot']}, "
f"P0筹码: {env_info['p0_stack']}, "
f"P1筹码: {env_info['p1_stack']}")
print(f" CFR Mask: {legal_mask} "
f"({' '.join(f'{CFR_ACTIONS[i]}:{v}' for i, v in enumerate(legal_mask))})")
# 在合法的 CFR 动作中随机挑选一个
legal_cfr_actions = [i for i, m in enumerate(legal_mask) if m == 1]
if not legal_cfr_actions:
if verbose:
print(" 警告:没有合法的 CFR 动作!使用 Fallback。")
legal_cfr_actions = [1] # 强制 Call
cfr_action = random.choice(legal_cfr_actions)
# 转换为引擎动作
engine_action = translator.cfr_to_engine_action(state, cfr_action)
if verbose:
cfr_name = CFR_ACTIONS[cfr_action]
action_str = state.action_to_string(current_player, engine_action)
print(f" CFR: {cfr_name} ({cfr_action}) -> Engine Action: {engine_action} ({action_str})")
# 执行引擎动作
state.apply_action(engine_action)
move_count += 1
# ---- 对局结束,打印 Returns ----
returns = state.returns()
if verbose:
print(f"\n === 对局结束 ===")
print(f" Returns: P0 = {returns[0]:+.0f}, P1 = {returns[1]:+.0f}")
if returns[0] > returns[1]:
print(f" 胜者: Player 0")
elif returns[1] > returns[0]:
print(f" 胜者: Player 1")
else:
print(f" 平局")
if verbose:
print(f"\n{'='*60}")
print(f" 全部 {num_games} 局自对弈完成,无非法动作崩溃!")
print(f"{'='*60}")
# ---------------------------------------------------------------------------
# 入口:随机自对弈测试
# ---------------------------------------------------------------------------
if __name__ == "__main__":
run_random_cfr_self_play(num_games=5, verbose=True)