Files
new/play_tool/play_arena.py
2026-05-06 17:36:51 +08:00

538 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
play_arena.py — 命令行德州扑克对战环境 (Poker Arena)
完全基于 OpenSpiel 原生 Game Loop;Python 层仅做薄 UI 翻译:
- HumanPlayer: 将用户输入 (F/C/数字) 映射到 state.legal_actions() 中的 action ID
- AIPlayer: 将 CFRNetwork 的 0-4 离散动作经 BetTranslator 映射为引擎 action ID
支持模式:
1: 人类 vs 人类
2: 人类 vs AI (人类 P0, AI P1)
3: AI vs 人类 (AI P0, 人类 P1)
4: AI vs AI (观战)
用法:
cd poker/
python play_tool/play_arena.py
"""
import os
import sys
import random
# ── 导入父目录模块 ──
_POKER_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if _POKER_DIR not in sys.path:
sys.path.insert(0, _POKER_DIR)
import torch
import pyspiel
from env_adapter import (
HUNL_FULLGAME_STRING,
BetTranslator,
extract_env_state,
CFR_ACTIONS,
NUM_CFR_ACTIONS,
STACK_NORMALIZE,
STREET_NAMES,
)
from card_model.config import PAD_TOKEN, BOARD_SIZE
from card_model.data_generator import extract_cards_from_state
from card_model.model import CardModel
from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS
# ───────────────────── 常量 ─────────────────────
STREET_NORMALIZE = 3.0
CARD_MODEL_CHECKPOINT = os.path.join(_POKER_DIR, "card_model", "data", "best_card_model.pt")
CFR_NET_CHECKPOINT = os.path.join(_POKER_DIR, "botzone_cfr_net.pt")
_SUIT_SYMBOLS = {'c': '\u2663', 'd': '\u2666', 'h': '\u2665', 's': '\u2660'}
# ───────────────────── 牌面显示工具 ─────────────────────
def card_str_to_display(card_str: str) -> str:
    """Render a 2-char engine card code (e.g. 'Ac', 'Js') with a suit symbol."""
    if len(card_str) == 2:
        rank, suit = card_str
        return f"{rank}{_SUIT_SYMBOLS.get(suit, suit)}"
    # Anything that is not exactly two characters passes through untouched.
    return card_str
def parse_hand_str(hand_str: str) -> str:
    """Split a concatenated card string (e.g. 'AcJs') and prettify each card."""
    if not hand_str:
        return ""
    # Walk the string two characters at a time; a trailing odd char is dropped.
    pretty = [
        card_str_to_display(hand_str[pos:pos + 2])
        for pos in range(0, len(hand_str) - 1, 2)
    ]
    return " ".join(pretty)
# ───────────────────── 特征构建 ─────────────────────
def build_env_features(env_info: dict) -> torch.Tensor:
    """Normalize the dict from extract_env_state into a 5-dim env feature tensor."""
    pot_norm = env_info["pot"] / STACK_NORMALIZE
    stack0_norm = env_info["p0_stack"] / STACK_NORMALIZE
    stack1_norm = env_info["p1_stack"] / STACK_NORMALIZE
    street_norm = env_info["street"] / STREET_NORMALIZE
    position = float(env_info["position"])
    return torch.tensor(
        [pot_norm, stack0_norm, stack1_norm, street_norm, position],
        dtype=torch.float32,
    )
def build_card_features(card_model: CardModel, state) -> torch.Tensor:
    """Run CardModel on the current state and return its 50-dim win-rate histogram."""
    hole, board = extract_cards_from_state(state)
    device = next(card_model.parameters()).device
    # Pad the board up to BOARD_SIZE so the model always sees a fixed-length input.
    board_padded = board + [PAD_TOKEN] * (BOARD_SIZE - len(board))
    hole_batch = torch.tensor([hole], dtype=torch.int64, device=device)
    board_batch = torch.tensor([board_padded], dtype=torch.int64, device=device)
    with torch.no_grad():
        _, histogram = card_model(hole_batch, board_batch)
    return histogram.squeeze(0).cpu()
# ───────────────────── Player 基类 ─────────────────────
class PlayerBase:
    """Common base for all players; choose() yields a native engine action ID."""

    def __init__(self, player_id: int, name: str):
        # Seat index (0 or 1) and the display name used in log/summary output.
        self.player_id = player_id
        self.name = name

    def choose(self, state) -> int:
        """Return the engine-native action ID for the current state.

        Args:
            state: an OpenSpiel State object.

        Returns:
            int: an action ID that can be passed straight to state.apply_action().
        """
        raise NotImplementedError
# ───────────────────── AIPlayer ─────────────────────
class AIPlayer(PlayerBase):
    """AI player: CardModel features + CFRNetwork avg_strategy sampling,
    mapped to engine actions through BetTranslator."""

    def __init__(self, player_id: int, name: str = "AI"):
        super().__init__(player_id, name)
        # Load the CardModel (hand-strength histogram predictor).
        self.card_model = CardModel()
        if os.path.isfile(CARD_MODEL_CHECKPOINT):
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            ckpt = torch.load(CARD_MODEL_CHECKPOINT, map_location="cpu", weights_only=False)
            # Accept either a full training checkpoint or a bare state dict.
            self.card_model.load_state_dict(
                ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
            )
            print(f"[CardModel] 已加载权重: {CARD_MODEL_CHECKPOINT}")
        else:
            # Missing weights: fall back to a randomly initialized model.
            print(f"[CardModel] 警告: 未找到权重 {CARD_MODEL_CHECKPOINT},使用随机初始化")
        self.card_model.eval()
        # Load the CFRNetwork (strategy head).
        self.cfr_net = CFRNetwork(
            card_dim=CARD_DIM, env_dim=ENV_DIM, num_actions=NUM_ACTIONS,
        )
        if os.path.isfile(CFR_NET_CHECKPOINT):
            ckpt = torch.load(CFR_NET_CHECKPOINT, map_location="cpu", weights_only=False)
            self.cfr_net.load_state_dict(
                ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
            )
            print(f"[CFRNetwork] 已加载权重: {CFR_NET_CHECKPOINT}")
        else:
            print(f"[CFRNetwork] 警告: 未找到权重 {CFR_NET_CHECKPOINT},使用随机初始化")
        self.cfr_net.eval()
        # BetTranslator: maps a discrete CFR action index to an engine action ID.
        self.translator = BetTranslator()

    def choose(self, state) -> int:
        """Pick an engine action by sampling the CFR network's average strategy."""
        legal_actions = state.legal_actions()
        # Forced action (blinds etc.): nothing to decide.
        if len(legal_actions) == 1:
            return legal_actions[0]
        # Build features and query the CFR network.
        env_info = extract_env_state(state)
        env_features = build_env_features(env_info)
        card_features = build_card_features(self.card_model, state)
        legal_mask = env_info["legal_mask"]
        card_input = card_features.unsqueeze(0)  # [1, 50]
        env_input = env_features.unsqueeze(0)  # [1, 5]
        legal_mask_tensor = torch.tensor([legal_mask], dtype=torch.float32)
        with torch.no_grad():
            _, avg_strategy = self.cfr_net.get_strategy(
                card_input, env_input, legal_mask_tensor
            )
        strategy = avg_strategy.squeeze(0).cpu().tolist()
        legal_indices = [i for i, m in enumerate(legal_mask) if m == 1]
        if not legal_indices:
            # Degenerate mask: fall back to CFR action 1
            # (presumably call/check — TODO confirm against CFR_ACTIONS).
            legal_indices = [1]
        probs = [strategy[i] for i in legal_indices]
        prob_sum = sum(probs)
        if prob_sum > 0:
            probs = [p / prob_sum for p in probs]
        else:
            # Network put no mass on legal actions: use a uniform distribution.
            probs = [1.0 / len(legal_indices)] * len(legal_indices)
        print(f"\n [AI 大脑读取中...] 当前合法 CFR 动作概率:")
        for idx, p in zip(legal_indices, probs):
            print(f" - {CFR_ACTIONS[idx]:<10}: {p*100:5.1f}%")
        # Noise filter: zero out actions below 3% probability so a stray
        # low-probability network blip cannot trigger a random all-in.
        NOISE_THRESHOLD = 0.03
        filtered_probs = [p if p > NOISE_THRESHOLD else 0.0 for p in probs]
        prob_sum = sum(filtered_probs)
        if prob_sum > 0:
            # Renormalize the surviving probability mass.
            filtered_probs = [p / prob_sum for p in filtered_probs]
        else:
            # Everything fell below the threshold (edge case): greedily keep
            # only the single highest-probability action.
            best_idx = probs.index(max(probs))
            filtered_probs = [1.0 if i == best_idx else 0.0 for i in range(len(probs))]
        # Sample one discrete CFR action from the filtered distribution.
        chosen_cfr_idx = random.choices(legal_indices, weights=filtered_probs, k=1)[0]
        # Map the CFR action index to an engine-native action ID.
        engine_action = self.translator.cfr_to_engine_action(state, chosen_cfr_idx)
        # Safety guard: CFR action 1 must never be mapped to a raise (engine id > 1).
        if chosen_cfr_idx == 1 and engine_action > 1:
            engine_action = 1  # force CALL
        # Safety guard: proportional raises (CFR indices 2/3) must not map to the
        # largest legal bet — treated here as all-in — TODO confirm that assumption.
        if chosen_cfr_idx in (2, 3):
            bet_actions_guard = [a for a in state.legal_actions() if a > 1]
            if bet_actions_guard and engine_action >= bet_actions_guard[-1]:
                engine_action = 1  # fall back to CALL
        cp = state.current_player()
        action_desc = state.action_to_string(cp, engine_action)
        print(f" {self.name} (P{self.player_id}) 选择了 {action_desc}")
        return engine_action
# ───────────────────── HumanPlayer ─────────────────────
class HumanPlayer(PlayerBase):
    """
    Human player: state.legal_actions() is treated as the absolute ground truth.

    Action semantics (fullgame mode):
        - action 0  = Fold
        - action 1  = Call / Check
        - action >1 = raise *to* that value (total contribution, bet-to)

    User input:
        - F      = fold
        - C      = call / check
        - number = raise to that amount (snapped to the nearest legal action ID)
    """

    def __init__(self, player_id: int, name: str = ""):
        super().__init__(player_id, name)

    def choose(self, state) -> int:
        """Render the table, prompt the user, and return a legal engine action ID."""
        legal_actions = state.legal_actions()
        # Forced action (blinds etc.): execute without asking the user.
        if len(legal_actions) == 1:
            cp = state.current_player()
            desc = state.action_to_string(cp, legal_actions[0])
            print(f" {self.name} (P{self.player_id}) 强制执行: {desc}")
            return legal_actions[0]
        # ── Render the current state from engine-native info ──
        d = state.to_dict()
        cp = state.current_player()
        # Infer the street from the board-card count; to_dict()["round"] is
        # None in this engine, so it cannot be used directly.
        board_str = d.get("board_cards", "")
        num_board = len(board_str) // 2 if board_str else 0
        street = {0: 0, 3: 1, 4: 2, 5: 3}.get(num_board, 0)
        street_name = STREET_NAMES[street]
        hands = d.get("player_hands", ["", ""])
        my_hand = parse_hand_str(hands[self.player_id])
        board_display = parse_hand_str(board_str) if board_str else "(无)"
        contributions = d.get("player_contributions", [0, 0])
        pot = d.get("pot_size", 0)
        starting_stacks = d.get("starting_stacks", [20000, 20000])
        stacks = [starting_stacks[i] - contributions[i] for i in range(2)]
        print()
        # FIX: the separator was `print("" * 50)`, which prints nothing (the
        # divider character was lost to a Unicode scrub); restore a visible line.
        print("-" * 50)
        print(f" 轮次: {street_name} | 底池: {pot}")
        print(f" P0 筹码: {stacks[0]} | P1 筹码: {stacks[1]}")
        print(f" 你的底牌: {my_hand}")
        print(f" 公共牌: {board_display}")
        print("-" * 50)
        # ── Classify the legal actions ──
        can_fold = 0 in legal_actions
        can_call = 1 in legal_actions
        raise_actions = [a for a in legal_actions if a > 1]
        # Call amount = opponent contribution - own contribution
        # (already the call cost under total bet-to semantics).
        call_amount = max(contributions[1 - self.player_id] - contributions[self.player_id], 0)
        call_label = f"跟注 {call_amount}" if call_amount > 0 else "过牌 (Check)"
        # ── Build the prompt ──
        options = []
        if can_fold:
            options.append("F = 弃牌 (Fold)")
        if can_call:
            options.append(f"C = {call_label}")
        if raise_actions:
            min_bet = min(raise_actions)
            max_bet = max(raise_actions)
            options.append(f"输入 {min_bet}~{max_bet} 之间的数字 = 加注到该金额")
        print(" 可选操作:")
        for opt in options:
            print(f" {opt}")
        # ── Input loop ──
        while True:
            try:
                raw = input(" 请输入: ").strip()
            except KeyboardInterrupt:
                print("\n 退出游戏。")
                sys.exit(0)
            if not raw:
                print(" 输入不能为空,请重试。")
                continue
            upper = raw.upper()
            # Fold
            if upper == 'F':
                if can_fold:
                    return 0
                print(" 当前不可弃牌。")
                continue
            # Call / check
            if upper == 'C':
                if can_call:
                    return 1
                print(" 当前不可跟注/过牌。")
                continue
            # A number: raise-to amount
            try:
                target = int(raw)
            except ValueError:
                print(" 请输入 F、C 或数字。")
                continue
            if not raise_actions:
                print(" 当前不可加注。")
                continue
            # min_bet/max_bet are guaranteed bound here: raise_actions is non-empty.
            if target < min_bet or target > max_bet:
                print(f" 加注金额须在 {min_bet}~{max_bet} 之间。")
                continue
            # Snap to the nearest legal action ID.
            closest = min(raise_actions, key=lambda a: abs(a - target))
            if closest != target:
                print(f" {target} 不是合法加注额,已自动修正为最接近的 {closest}")
            return closest
# ───────────────────── 对局核心流程 ─────────────────────
def run_game(game, players: list):
    """
    Play one hand using the standard OpenSpiel game loop.

    Args:
        game: the OpenSpiel game instance.
        players: [PlayerBase, PlayerBase] for P0 and P1.

    Returns:
        (terminal_state, history): the terminal state plus the action log.
    """
    state = game.new_initial_state()
    history = []
    while not state.is_terminal():
        if state.is_chance_node():
            # Chance node: deal by sampling the engine's outcome distribution.
            action_list, prob_list = zip(*state.chance_outcomes())
            dealt = random.choices(action_list, weights=prob_list, k=1)[0]
            # Record the deal with the engine's own stringifier.
            history.append(state.action_to_string(-1, dealt))
            state.apply_action(dealt)
        else:
            # Player node: delegate the decision to the seat's player object.
            seat = state.current_player()
            actor = players[seat]
            action = actor.choose(state)
            # Engine-native action description keeps the log exact.
            history.append(f"P{seat} ({actor.name}): {state.action_to_string(seat, action)}")
            state.apply_action(action)
    return state, history
# ───────────────────── 对局结算 ─────────────────────
def show_result(state, players: list, history: list):
    """Print the end-of-hand summary: hole cards, board, hand history, payoffs."""
    returns = state.returns()
    info = state.to_dict()
    print()
    print("=" * 50)
    print(" 牌局结束 (Showdown)")
    print("=" * 50)
    # Both players' hole cards.
    hole_strings = info.get("player_hands", ["", ""])
    for pid in range(2):
        shown = parse_hand_str(hole_strings[pid]) if hole_strings[pid] else "(未显示)"
        print(f" P{pid} ({players[pid].name}) 底牌: {shown}")
    # Community cards.
    board_raw = info.get("board_cards", "")
    board_shown = parse_hand_str(board_raw) if board_raw else "(无)"
    print(f" 公共牌: {board_shown}")
    # Hand history (recorded via engine-native action_to_string, so it is exact).
    print()
    print(" --- 牌谱 (Hand History) ---")
    for line in history:
        print(f" {line}")
    # Payoffs.
    print()
    for pid in range(2):
        payoff = returns[pid]
        prefix = "+" if payoff >= 0 else ""
        print(f" P{pid} ({players[pid].name}) 收益: {prefix}{payoff:.0f}")
    print("=" * 50)
# ───────────────────── 主程序 ─────────────────────
def main():
    """Entry point: pick a mode, build the two players, and loop over hands."""
    print()
    print("=" * 50)
    print(" Poker Arena — 德州扑克命令行对战")
    print("=" * 50)
    print()
    print(" 请选择模式:")
    print(" 1: 人类 vs 人类")
    print(" 2: 人类 vs AI (人类 P0, AI P1)")
    print(" 3: AI vs 人类 (AI P0, 人类 P1)")
    print(" 4: AI vs AI (观战)")
    print()
    while True:
        try:
            mode_str = input(" 输入模式编号 (1-4): ").strip()
            mode = int(mode_str)
            if 1 <= mode <= 4:
                break
            print(" 请输入 1-4 之间的数字。")
        except ValueError:
            print(" 请输入数字。")
        except KeyboardInterrupt:
            print("\n 退出。")
            sys.exit(0)
    # Seat assignment per mode; factories defer construction so that any
    # model-loading messages appear in seat order (P0 first, then P1).
    seat_factories = {
        1: (lambda: HumanPlayer(0, "玩家1"), lambda: HumanPlayer(1, "玩家2")),
        2: (lambda: HumanPlayer(0, ""), lambda: AIPlayer(1, "AI")),
        3: (lambda: AIPlayer(0, "AI"), lambda: HumanPlayer(1, "")),
        4: (lambda: AIPlayer(0, "AI-0"), lambda: AIPlayer(1, "AI-1")),
    }
    make_p0, make_p1 = seat_factories[mode]
    players = [make_p0(), make_p1()]
    print()
    # Initialize the game engine.
    game = pyspiel.load_game(HUNL_FULLGAME_STRING)
    # Hand loop.
    hand_no = 0
    while True:
        hand_no += 1
        print()
        print("#" * 50)
        print(f"{hand_no}")
        print("#" * 50)
        final_state, history = run_game(game, players)
        show_result(final_state, players, history)
        # Ask whether to continue.
        print()
        try:
            cont = input(" 按回车继续下一局,输入 q 退出: ").strip()
        except KeyboardInterrupt:
            print("\n 退出。")
            break
        if cont.lower() == 'q':
            print(" 再见!")
            break


if __name__ == "__main__":
    main()