#!/usr/bin/env python3
"""
play_arena.py — command-line Texas Hold'em arena (Poker Arena).

Built entirely on the native OpenSpiel game loop; the Python layer is only a
thin UI translation:
- HumanPlayer: maps user input (F/C/number) to an action ID from
  state.legal_actions()
- AIPlayer: maps the CFRNetwork's 0-4 discrete actions to engine action IDs
  via BetTranslator

Supported modes:
    1: human vs human
    2: human vs AI (human P0, AI P1)
    3: AI vs human (AI P0, human P1)
    4: AI vs AI (spectate)

Usage:
    cd poker/
    python play_tool/play_arena.py
"""
import os
import random
import sys

# ── Make the parent poker/ directory importable before local imports ──
_POKER_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if _POKER_DIR not in sys.path:
    sys.path.insert(0, _POKER_DIR)

import torch
import pyspiel

from env_adapter import (
    HUNL_FULLGAME_STRING,
    BetTranslator,
    extract_env_state,
    CFR_ACTIONS,
    NUM_CFR_ACTIONS,
    STACK_NORMALIZE,
    STREET_NAMES,
)
from card_model.config import PAD_TOKEN, BOARD_SIZE
from card_model.data_generator import extract_cards_from_state
from card_model.model import CardModel
from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS

# ───────────────────── Constants ─────────────────────
# Street index runs 0..3 (preflop..river); divide by 3 to normalize to [0, 1].
STREET_NORMALIZE = 3.0
CARD_MODEL_CHECKPOINT = os.path.join(_POKER_DIR, "card_model", "data", "best_card_model.pt")
CFR_NET_CHECKPOINT = os.path.join(_POKER_DIR, "botzone_cfr_net.pt")

# Engine suit letter -> unicode suit glyph (club, diamond, heart, spade).
_SUIT_SYMBOLS = {'c': '\u2663', 'd': '\u2666', 'h': '\u2665', 's': '\u2660'}


# ───────────────────── Card display helpers ─────────────────────
def card_str_to_display(card_str: str) -> str:
    """Convert the engine's 2-char card (e.g. 'Ac', 'Js') to a display string
    with a suit symbol. Anything that is not exactly 2 chars passes through."""
    if len(card_str) != 2:
        return card_str
    rank_ch, suit_ch = card_str[0], card_str[1]
    return f"{rank_ch}{_SUIT_SYMBOLS.get(suit_ch, suit_ch)}"


def parse_hand_str(hand_str: str) -> str:
    """Split the engine's concatenated card string (e.g. 'AcJs') into
    prettified space-separated cards."""
    if not hand_str:
        return ""
    cards = []
    i = 0
    while i + 1 < len(hand_str):
        cards.append(card_str_to_display(hand_str[i] + hand_str[i + 1]))
        i += 2
    return " ".join(cards)


# ───────────────────── Feature construction ─────────────────────
def build_env_features(env_info: dict) -> torch.Tensor:
    """Normalize the dict returned by extract_env_state into a 5-dim
    env_features tensor: [pot, p0_stack, p1_stack, street, position]."""
    features = [
        env_info["pot"] / STACK_NORMALIZE,
        env_info["p0_stack"] / STACK_NORMALIZE,
        env_info["p1_stack"] / STACK_NORMALIZE,
        env_info["street"] / STREET_NORMALIZE,
        float(env_info["position"]),
    ]
    return torch.tensor(features, dtype=torch.float32)


def build_card_features(card_model: CardModel, state) -> torch.Tensor:
    """Use CardModel to extract the 50-dim win-rate histogram for the current
    state. Returns a CPU tensor regardless of the model's device."""
    hole_cards, board_cards = extract_cards_from_state(state)
    model_device = next(card_model.parameters()).device
    x_hole = torch.tensor([hole_cards], dtype=torch.int64, device=model_device)
    # Pad the board up to BOARD_SIZE so the model always sees a fixed shape.
    padded_board = board_cards + [PAD_TOKEN] * (BOARD_SIZE - len(board_cards))
    x_board = torch.tensor([padded_board], dtype=torch.int64, device=model_device)
    with torch.no_grad():
        _, pred_histogram = card_model(x_hole, x_board)
    return pred_histogram.squeeze(0).cpu()


# ───────────────────── Player base class ─────────────────────
class PlayerBase:
    """Player base class. choose() returns a native engine action ID."""

    def __init__(self, player_id: int, name: str):
        self.player_id = player_id
        self.name = name

    def choose(self, state) -> int:
        """
        Return a native engine action ID for the current state.

        Args:
            state: OpenSpiel State object

        Returns:
            int: action ID directly usable with state.apply_action()
        """
        raise NotImplementedError


# ───────────────────── AIPlayer ─────────────────────
class AIPlayer(PlayerBase):
    """AI player: CardModel + CFRNetwork avg_strategy sampling, mapped to
    engine actions via BetTranslator."""

    def __init__(self, player_id: int, name: str = "AI"):
        super().__init__(player_id, name)
        # Load CardModel (falls back to random init when no checkpoint).
        self.card_model = CardModel()
        if os.path.isfile(CARD_MODEL_CHECKPOINT):
            ckpt = torch.load(CARD_MODEL_CHECKPOINT, map_location="cpu", weights_only=False)
            self.card_model.load_state_dict(
                ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
            )
            print(f"[CardModel] 已加载权重: {CARD_MODEL_CHECKPOINT}")
        else:
            print(f"[CardModel] 警告: 未找到权重 {CARD_MODEL_CHECKPOINT},使用随机初始化")
        self.card_model.eval()
        # Load CFRNetwork (falls back to random init when no checkpoint).
        self.cfr_net = CFRNetwork(
            card_dim=CARD_DIM,
            env_dim=ENV_DIM,
            num_actions=NUM_ACTIONS,
        )
        if os.path.isfile(CFR_NET_CHECKPOINT):
            ckpt = torch.load(CFR_NET_CHECKPOINT, map_location="cpu", weights_only=False)
            self.cfr_net.load_state_dict(
                ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
            )
            print(f"[CFRNetwork] 已加载权重: {CFR_NET_CHECKPOINT}")
        else:
            print(f"[CFRNetwork] 警告: 未找到权重 {CFR_NET_CHECKPOINT},使用随机初始化")
        self.cfr_net.eval()
        # BetTranslator: CFR discrete action -> engine action ID.
        self.translator = BetTranslator()

    def choose(self, state) -> int:
        legal_actions = state.legal_actions()
        # Forced action (blinds etc.): return it directly.
        if len(legal_actions) == 1:
            return legal_actions[0]
        # Build features and query the CFRNetwork.
        env_info = extract_env_state(state)
        env_features = build_env_features(env_info)
        card_features = build_card_features(self.card_model, state)
        legal_mask = env_info["legal_mask"]
        card_input = card_features.unsqueeze(0)  # [1, 50]
        env_input = env_features.unsqueeze(0)  # [1, 5]
        legal_mask_tensor = torch.tensor([legal_mask], dtype=torch.float32)
        with torch.no_grad():
            _, avg_strategy = self.cfr_net.get_strategy(
                card_input, env_input, legal_mask_tensor
            )
        strategy = avg_strategy.squeeze(0).cpu().tolist()
        legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
        if not legal_indices:
            # Degenerate mask: fall back to CFR action 1 (call/check).
            legal_indices = [1]
        probs = [strategy[i] for i in legal_indices]
        prob_sum = sum(probs)
        if prob_sum > 0:
            probs = [p / prob_sum for p in probs]
        else:
            # Zero-mass strategy: uniform over legal CFR actions.
            probs = [1.0 / len(legal_indices)] * len(legal_indices)
        print("\n [AI 大脑读取中...] 当前合法 CFR 动作概率:")
        for idx, p in zip(legal_indices, probs):
            print(f" - {CFR_ACTIONS[idx]:<10}: {p*100:5.1f}%")
        # Filter network noise: zero out actions below 3% probability so a
        # 1% tail never triggers a random all-in.
        NOISE_THRESHOLD = 0.03
        filtered_probs = [p if p > NOISE_THRESHOLD else 0.0 for p in probs]
        prob_sum = sum(filtered_probs)
        if prob_sum > 0:
            # Renormalize over the surviving actions.
            filtered_probs = [p / prob_sum for p in filtered_probs]
        else:
            # Everything fell below the threshold (extreme case): greedily
            # pick the originally most likely action.
            best_idx = probs.index(max(probs))
            filtered_probs = [1.0 if i == best_idx else 0.0 for i in range(len(probs))]
        # Sample a CFR discrete action from the cleaned distribution.
        chosen_cfr_idx = random.choices(legal_indices, weights=filtered_probs, k=1)[0]
        # Map the CFR action to a native engine action ID.
        engine_action = self.translator.cfr_to_engine_action(state, chosen_cfr_idx)
        # Guard: never let CALL be translated into a raise.
        if chosen_cfr_idx == 1 and engine_action > 1:
            engine_action = 1  # force CALL
        # Guard: never let a pot-fraction raise be translated into ALL-IN.
        if chosen_cfr_idx in (2, 3):
            bet_actions_guard = [a for a in state.legal_actions() if a > 1]
            if bet_actions_guard and engine_action >= bet_actions_guard[-1]:
                engine_action = 1  # fall back to CALL
        cp = state.current_player()
        action_desc = state.action_to_string(cp, engine_action)
        print(f" {self.name} (P{self.player_id}) 选择了 {action_desc}")
        return engine_action


# ───────────────────── HumanPlayer ─────────────────────
class HumanPlayer(PlayerBase):
    """
    Human player: state.legal_actions() is treated as the absolute truth.

    Action semantics (fullgame mode):
    - action 0 = Fold
    - action 1 = Call / Check
    - action > 1 = raise to that value (total contribution, i.e. bet-to)

    User input:
    - F = fold
    - C = call/check
    - number = raise to that amount (auto-mapped to the closest legal action ID)
    """

    def __init__(self, player_id: int, name: str = "你"):
        super().__init__(player_id, name)

    def choose(self, state) -> int:
        legal_actions = state.legal_actions()
        # Forced action (blinds etc.): apply without asking the user.
        if len(legal_actions) == 1:
            cp = state.current_player()
            desc = state.action_to_string(cp, legal_actions[0])
            print(f" {self.name} (P{self.player_id}) 强制执行: {desc}")
            return legal_actions[0]
        # ── Render the current state from native engine info ──
        d = state.to_dict()
        cp = state.current_player()
        # Infer the street from the board card count
        # (to_dict()["round"] is None in this engine).
        board_str = d.get("board_cards", "")
        num_board = len(board_str) // 2 if board_str else 0
        street = {0: 0, 3: 1, 4: 2, 5: 3}.get(num_board, 0)
        street_name = STREET_NAMES[street]
        hands = d.get("player_hands", ["", ""])
        my_hand = parse_hand_str(hands[self.player_id])
        board_display = parse_hand_str(board_str) if board_str else "(无)"
        contributions = d.get("player_contributions", [0, 0])
        pot = d.get("pot_size", 0)
        starting_stacks = d.get("starting_stacks", [20000, 20000])
        stacks = [starting_stacks[i] - contributions[i] for i in range(2)]
        print()
        print("─" * 50)
        print(f" 轮次: {street_name} | 底池: {pot}")
        print(f" P0 筹码: {stacks[0]} | P1 筹码: {stacks[1]}")
        print(f" 你的底牌: {my_hand}")
        print(f" 公共牌: {board_display}")
        print("─" * 50)
        # ── Classify the legal actions ──
        can_fold = 0 in legal_actions
        can_call = 1 in legal_actions
        raise_actions = [a for a in legal_actions if a > 1]
        # Call amount = opponent contribution - own contribution
        # (already the call gap under total bet-to semantics).
        call_amount = max(contributions[1 - self.player_id] - contributions[self.player_id], 0)
        call_label = f"跟注 {call_amount}" if call_amount > 0 else "过牌 (Check)"
        # ── Build the prompt ──
        options = []
        if can_fold:
            options.append("F = 弃牌 (Fold)")
        if can_call:
            options.append(f"C = {call_label}")
        if raise_actions:
            min_bet = min(raise_actions)
            max_bet = max(raise_actions)
            options.append(f"输入 {min_bet}~{max_bet} 之间的数字 = 加注到该金额")
        print(" 可选操作:")
        for opt in options:
            print(f" {opt}")
        # ── Input loop ──
        while True:
            try:
                raw = input(" 请输入: ").strip()
            except KeyboardInterrupt:
                print("\n 退出游戏。")
                sys.exit(0)
            if not raw:
                print(" 输入不能为空,请重试。")
                continue
            upper = raw.upper()
            # Fold.
            if upper == 'F':
                if can_fold:
                    return 0
                print(" 当前不可弃牌。")
                continue
            # Call / check.
            if upper == 'C':
                if can_call:
                    return 1
                print(" 当前不可跟注/过牌。")
                continue
            # Number: raise-to amount.
            try:
                target = int(raw)
            except ValueError:
                print(" 请输入 F、C 或数字。")
                continue
            if not raise_actions:
                print(" 当前不可加注。")
                continue
            if target < min_bet or target > max_bet:
                print(f" 加注金额须在 {min_bet}~{max_bet} 之间。")
                continue
            # Snap to the closest legal action ID.
            closest = min(raise_actions, key=lambda a: abs(a - target))
            if closest != target:
                print(f" {target} 不是合法加注额,已自动修正为最接近的 {closest}")
            return closest


# ───────────────────── Core game loop ─────────────────────
def run_game(game, players: list):
    """
    Play one hand using the canonical OpenSpiel game loop.

    Args:
        game: OpenSpiel game instance
        players: [PlayerBase, PlayerBase] for P0 and P1

    Returns:
        (terminal_state, history): the terminal state and the hand history
    """
    state = game.new_initial_state()
    history = []
    while not state.is_terminal():
        # ── Chance node: deal cards ──
        if state.is_chance_node():
            outcomes = state.chance_outcomes()
            action_list, prob_list = zip(*outcomes)
            chance_action = random.choices(action_list, weights=prob_list, k=1)[0]
            # Record the deal via the engine's native description.
            desc = state.action_to_string(-1, chance_action)
            history.append(desc)
            state.apply_action(chance_action)
            continue
        # ── Player node: decide ──
        current_player = state.current_player()
        player = players[current_player]
        action = player.choose(state)
        # Record the move via the engine's native description.
        desc = state.action_to_string(current_player, action)
        history.append(f"P{current_player} ({player.name}): {desc}")
        state.apply_action(action)
    return state, history


# ───────────────────── Settlement ─────────────────────
def show_result(state, players: list, history: list):
    """Print the hand's settlement: hole cards, board, history, and returns."""
    returns = state.returns()
    d = state.to_dict()
    print()
    print("=" * 50)
    print(" 牌局结束 (Showdown)")
    print("=" * 50)
    # Both players' hole cards.
    hands = d.get("player_hands", ["", ""])
    for pid in range(2):
        hand_display = parse_hand_str(hands[pid]) if hands[pid] else "(未显示)"
        player_name = players[pid].name
        print(f" P{pid} ({player_name}) 底牌: {hand_display}")
    # Board.
    board_str = d.get("board_cards", "")
    board_display = parse_hand_str(board_str) if board_str else "(无)"
    print(f" 公共牌: {board_display}")
    # Hand history (recorded with the engine's native action_to_string).
    print()
    print(" --- 牌谱 (Hand History) ---")
    for entry in history:
        print(f" {entry}")
    # Returns.
    print()
    for pid in range(2):
        ret = returns[pid]
        sign = "+" if ret >= 0 else ""
        player_name = players[pid].name
        print(f" P{pid} ({player_name}) 收益: {sign}{ret:.0f}")
    print("=" * 50)


# ───────────────────── Main ─────────────────────
def main():
    print()
    print("=" * 50)
    print(" Poker Arena — 德州扑克命令行对战")
    print("=" * 50)
    print()
    print(" 请选择模式:")
    print(" 1: 人类 vs 人类")
    print(" 2: 人类 vs AI (人类 P0, AI P1)")
    print(" 3: AI vs 人类 (AI P0, 人类 P1)")
    print(" 4: AI vs AI (观战)")
    print()
    while True:
        try:
            mode_str = input(" 输入模式编号 (1-4): ").strip()
            mode = int(mode_str)
            if 1 <= mode <= 4:
                break
            print(" 请输入 1-4 之间的数字。")
        except ValueError:
            print(" 请输入数字。")
        except KeyboardInterrupt:
            print("\n 退出。")
            sys.exit(0)
    # Build the two players from the chosen mode.
    players = [None, None]
    if mode == 1:
        players[0] = HumanPlayer(0, "玩家1")
        players[1] = HumanPlayer(1, "玩家2")
    elif mode == 2:
        players[0] = HumanPlayer(0, "你")
        players[1] = AIPlayer(1, "AI")
    elif mode == 3:
        players[0] = AIPlayer(0, "AI")
        players[1] = HumanPlayer(1, "你")
    elif mode == 4:
        players[0] = AIPlayer(0, "AI-0")
        players[1] = AIPlayer(1, "AI-1")
    print()
    # Initialize the game engine.
    game = pyspiel.load_game(HUNL_FULLGAME_STRING)
    # Hand loop.
    game_num = 0
    while True:
        game_num += 1
        print()
        print("#" * 50)
        print(f" 第 {game_num} 局")
        print("#" * 50)
        terminal_state, history = run_game(game, players)
        show_result(terminal_state, players, history)
        # Continue or quit.
        print()
        try:
            cont = input(" 按回车继续下一局,输入 q 退出: ").strip()
        except KeyboardInterrupt:
            print("\n 退出。")
            break
        if cont.lower() == 'q':
            print(" 再见!")
            break


if __name__ == "__main__":
    main()