538 lines
18 KiB
Python
538 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
play_arena.py — command-line Texas Hold'em arena (Poker Arena)

Built entirely on OpenSpiel's native game loop; the Python layer is only a
thin UI translation:
  - HumanPlayer: maps user input (F / C / number) to an action ID taken from
    state.legal_actions()
  - AIPlayer: maps CFRNetwork's discrete actions 0-4 to engine action IDs via
    BetTranslator

Supported modes:
  1: human vs human
  2: human vs AI  (human P0, AI P1)
  3: AI vs human  (AI P0, human P1)
  4: AI vs AI     (spectate)

Usage:
  cd poker/
  python play_tool/play_arena.py
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import random
|
||
|
||
# ── 导入父目录模块 ──
|
||
_POKER_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||
if _POKER_DIR not in sys.path:
|
||
sys.path.insert(0, _POKER_DIR)
|
||
|
||
import torch
|
||
import pyspiel
|
||
|
||
from env_adapter import (
|
||
HUNL_FULLGAME_STRING,
|
||
BetTranslator,
|
||
extract_env_state,
|
||
CFR_ACTIONS,
|
||
NUM_CFR_ACTIONS,
|
||
STACK_NORMALIZE,
|
||
STREET_NAMES,
|
||
)
|
||
from card_model.config import PAD_TOKEN, BOARD_SIZE
|
||
from card_model.data_generator import extract_cards_from_state
|
||
from card_model.model import CardModel
|
||
from cfr_net import CFRNetwork, CARD_DIM, ENV_DIM, NUM_ACTIONS
|
||
|
||
|
||
# ───────────────────── Constants ─────────────────────

# Divisor applied to the street index when build_env_features() scales it.
STREET_NORMALIZE = 3.0

# Checkpoint files, resolved relative to _POKER_DIR (the parent directory of
# the folder containing this script).
CARD_MODEL_CHECKPOINT = os.path.join(_POKER_DIR, "card_model", "data", "best_card_model.pt")
CFR_NET_CHECKPOINT = os.path.join(_POKER_DIR, "botzone_cfr_net.pt")

# Suit letter -> Unicode suit glyph (c: clubs, d: diamonds, h: hearts, s: spades).
_SUIT_SYMBOLS = {'c': '\u2663', 'd': '\u2666', 'h': '\u2665', 's': '\u2660'}
|
||
|
||
|
||
# ───────────────────── 牌面显示工具 ─────────────────────
|
||
|
||
def card_str_to_display(card_str: str) -> str:
    """Render a two-character engine card code (e.g. 'Ac', 'Js') with a suit glyph.

    Args:
        card_str: Card code; the first character is the rank, the second the
            suit letter.

    Returns:
        The rank followed by the glyph from ``_SUIT_SYMBOLS``. A suit letter
        missing from that table is kept as-is, and any string whose length is
        not exactly 2 is returned unchanged.
    """
    if len(card_str) == 2:
        rank, suit = card_str
        return rank + _SUIT_SYMBOLS.get(suit, suit)
    return card_str
|
||
|
||
|
||
def parse_hand_str(hand_str: str) -> str:
    """Split a concatenated card string (e.g. 'AcJs') and prettify each card.

    Args:
        hand_str: Cards written back to back, two characters per card.

    Returns:
        The cards rendered by ``card_str_to_display`` and joined with single
        spaces. An empty/falsy input yields ``""``; a dangling final character
        (odd-length input) is ignored.
    """
    if not hand_str:
        return ""
    # Start offsets of every complete two-character card.
    pair_starts = range(0, len(hand_str) - 1, 2)
    return " ".join(
        card_str_to_display(hand_str[k] + hand_str[k + 1]) for k in pair_starts
    )
|
||
|
||
|
||
# ───────────────────── 特征构建 ─────────────────────
|
||
|
||
def build_env_features(env_info: dict) -> torch.Tensor:
    """Normalise the ``extract_env_state`` dictionary into a 5-element float tensor.

    Args:
        env_info: Mapping with the keys ``pot``, ``p0_stack``, ``p1_stack``,
            ``street`` and ``position``.

    Returns:
        A 1-D ``float32`` tensor laid out as
        ``[pot, p0_stack, p1_stack, street, position]``; the first three are
        divided by ``STACK_NORMALIZE``, the street by ``STREET_NORMALIZE`` and
        the position is cast to ``float``.
    """
    chip_keys = ("pot", "p0_stack", "p1_stack")
    values = [env_info[key] / STACK_NORMALIZE for key in chip_keys]
    values.append(env_info["street"] / STREET_NORMALIZE)
    values.append(float(env_info["position"]))
    return torch.tensor(values, dtype=torch.float32)
|
||
|
||
|
||
def build_card_features(card_model: CardModel, state) -> torch.Tensor:
    """Run ``card_model`` on the cards of ``state`` and return its histogram output.

    Args:
        card_model: Model called as ``card_model(x_hole, x_board)``; the second
            element of its result is used.
        state: Game state handed to ``extract_cards_from_state``.

    Returns:
        The predicted histogram with the batch dimension squeezed out, moved
        to the CPU (the original note describes it as a 50-dimensional
        win-rate histogram).
    """
    hole, board = extract_cards_from_state(state)

    # Build the inputs on whatever device the model's parameters live on.
    device = next(card_model.parameters()).device

    # Pad the board up to BOARD_SIZE entries with PAD_TOKEN.
    missing = BOARD_SIZE - len(board)
    board_padded = board + [PAD_TOKEN] * missing

    hole_batch = torch.tensor([hole], dtype=torch.int64, device=device)
    board_batch = torch.tensor([board_padded], dtype=torch.int64, device=device)

    with torch.no_grad():
        _, histogram = card_model(hole_batch, board_batch)

    return histogram.squeeze(0).cpu()
|
||
|
||
|
||
# ───────────────────── Player 基类 ─────────────────────
|
||
|
||
class PlayerBase:
    """Common base for all players; ``choose()`` yields an engine-native action ID."""

    def __init__(self, player_id: int, name: str) -> None:
        """Store the seat index and the display name of this player."""
        self.name = name
        self.player_id = player_id

    def choose(self, state) -> int:
        """Pick an action for the current decision point.

        Args:
            state: OpenSpiel State object.

        Returns:
            int: Engine action ID that can be passed straight to
            ``state.apply_action()``.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
|
||
|
||
|
||
# ───────────────────── AIPlayer ─────────────────────
|
||
|
||
class AIPlayer(PlayerBase):
    """AI player: CardModel features + CFRNetwork average strategy, mapped by BetTranslator.

    For every decision the strategy over the legal CFR actions is
    renormalised, low-probability "noise" actions are dropped, one CFR action
    is sampled and translated to an engine action ID, and a few guards keep
    the result sane and legal.
    """

    # Legal CFR actions whose renormalised probability is at or below this
    # value are treated as network noise and are never sampled.
    NOISE_THRESHOLD = 0.03

    def __init__(self, player_id: int, name: str = "AI"):
        """Build both networks (loading checkpoints when present) and the translator.

        Args:
            player_id: Seat index of this player.
            name: Display name used in console output.
        """
        super().__init__(player_id, name)

        self.card_model = CardModel()
        self._load_weights(self.card_model, CARD_MODEL_CHECKPOINT, "CardModel")
        self.card_model.eval()

        self.cfr_net = CFRNetwork(
            card_dim=CARD_DIM, env_dim=ENV_DIM, num_actions=NUM_ACTIONS,
        )
        self._load_weights(self.cfr_net, CFR_NET_CHECKPOINT, "CFRNetwork")
        self.cfr_net.eval()

        # BetTranslator: CFR discrete action -> engine action ID.
        self.translator = BetTranslator()

    @staticmethod
    def _load_weights(model, checkpoint_path: str, tag: str) -> None:
        """Load ``checkpoint_path`` into ``model``; warn and keep random weights if absent."""
        if not os.path.isfile(checkpoint_path):
            print(f"[{tag}] 警告: 未找到权重 {checkpoint_path},使用随机初始化")
            return
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only trusted checkpoint files must be placed at this
        # path. Switch to weights_only=True once the checkpoints are confirmed
        # to contain plain tensors only.
        ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        # A checkpoint is either a bare state dict or a dict wrapping it
        # under "model_state_dict".
        model.load_state_dict(
            ckpt["model_state_dict"] if "model_state_dict" in ckpt else ckpt
        )
        print(f"[{tag}] 已加载权重: {checkpoint_path}")

    @staticmethod
    def _normalize(weights: list) -> list:
        """Scale ``weights`` to sum to 1; fall back to uniform when the sum is not positive."""
        total = sum(weights)
        if total > 0:
            return [w / total for w in weights]
        return [1.0 / len(weights)] * len(weights)

    @classmethod
    def _denoise(cls, probs: list) -> list:
        """Zero out probabilities at or below NOISE_THRESHOLD and renormalise.

        If every entry is filtered out, all mass goes to the first entry
        holding the largest original probability (greedy choice).
        """
        kept = [p if p > cls.NOISE_THRESHOLD else 0.0 for p in probs]
        total = sum(kept)
        if total > 0:
            return [p / total for p in kept]
        best = probs.index(max(probs))
        return [1.0 if i == best else 0.0 for i in range(len(probs))]

    @staticmethod
    def _guard_engine_action(cfr_idx: int, engine_action: int, legal_actions) -> int:
        """Apply the safety guards to a translated engine action.

        Engine action 1 is call/check and IDs above 1 are bets/raises. Per the
        original notes, CFR index 1 is CALL and indices 2 and 3 are
        proportional raises.
        """
        # A CALL must never come out of the translator as a raise.
        if cfr_idx == 1 and engine_action > 1:
            engine_action = 1

        # A proportional raise must not silently become the largest legal bet
        # (all-in); fall back to CALL in that case.
        if cfr_idx in (2, 3):
            raise_ids = [a for a in legal_actions if a > 1]
            if raise_ids and engine_action >= max(raise_ids):
                engine_action = 1

        # Last line of defence: never return an ID the engine would reject.
        if engine_action not in legal_actions:
            engine_action = 1 if 1 in legal_actions else legal_actions[0]
        return engine_action

    def choose(self, state) -> int:
        """Sample a CFR action from the network and return the engine action ID.

        Args:
            state: OpenSpiel State object at a decision node of this player.

        Returns:
            int: An action ID contained in ``state.legal_actions()``.
        """
        legal_actions = state.legal_actions()

        # Forced action (blinds etc.): nothing to decide.
        if len(legal_actions) == 1:
            return legal_actions[0]

        # Build the network inputs, each with a leading batch dimension of 1.
        env_info = extract_env_state(state)
        legal_mask = env_info["legal_mask"]
        env_input = build_env_features(env_info).unsqueeze(0)
        card_input = build_card_features(self.card_model, state).unsqueeze(0)
        legal_mask_tensor = torch.tensor([legal_mask], dtype=torch.float32)

        with torch.no_grad():
            _, avg_strategy = self.cfr_net.get_strategy(
                card_input, env_input, legal_mask_tensor
            )
        strategy = avg_strategy.squeeze(0).cpu().tolist()

        # With an empty mask, fall back to CFR action 1 (CALL).
        legal_indices = [i for i, m in enumerate(legal_mask) if m == 1] or [1]
        probs = self._normalize([strategy[i] for i in legal_indices])

        print("\n [AI 大脑读取中...] 当前合法 CFR 动作概率:")
        for idx, p in zip(legal_indices, probs):
            print(f" - {CFR_ACTIONS[idx]:<10}: {p*100:5.1f}%")

        # Sample from the de-noised distribution, then translate to the engine.
        weights = self._denoise(probs)
        chosen_cfr_idx = random.choices(legal_indices, weights=weights, k=1)[0]
        engine_action = self.translator.cfr_to_engine_action(state, chosen_cfr_idx)
        engine_action = self._guard_engine_action(
            chosen_cfr_idx, engine_action, legal_actions
        )

        cp = state.current_player()
        action_desc = state.action_to_string(cp, engine_action)
        print(f" {self.name} (P{self.player_id}) 选择了 {action_desc}")

        return engine_action
|
||
|
||
|
||
# ───────────────────── HumanPlayer ─────────────────────
|
||
|
||
class HumanPlayer(PlayerBase):
    """Human player driven by console input; ``state.legal_actions()`` is the source of truth.

    Action semantics (fullgame mode):
        - action 0   = fold
        - action 1   = call / check
        - action > 1 = raise to that value (total contribution, "bet-to")

    User input:
        - F      = fold
        - C      = call / check
        - number = raise to that amount (mapped to the closest legal action ID)
    """

    def __init__(self, player_id: int, name: str = "你"):
        """Store the seat index and the display name of this player."""
        super().__init__(player_id, name)

    def _render_table(self, info: dict, contributions) -> None:
        """Print street, pot, remaining stacks, own hole cards and the board."""
        # The street is inferred from the number of board cards (the original
        # note says to_dict()["round"] is None in this engine).
        board_str = info.get("board_cards", "")
        num_board = len(board_str) // 2 if board_str else 0
        street = {0: 0, 3: 1, 4: 2, 5: 3}.get(num_board, 0)
        street_name = STREET_NAMES[street]

        hands = info.get("player_hands", ["", ""])
        my_hand = parse_hand_str(hands[self.player_id])
        board_display = parse_hand_str(board_str) if board_str else "(无)"

        pot = info.get("pot_size", 0)
        starting_stacks = info.get("starting_stacks", [20000, 20000])
        stacks = [starting_stacks[i] - contributions[i] for i in range(2)]

        print()
        print("─" * 50)
        print(f" 轮次: {street_name} | 底池: {pot}")
        print(f" P0 筹码: {stacks[0]} | P1 筹码: {stacks[1]}")
        print(f" 你的底牌: {my_hand}")
        print(f" 公共牌: {board_display}")
        print("─" * 50)

    def choose(self, state) -> int:
        """Show the table, prompt until the input is valid, return the engine action ID.

        Args:
            state: OpenSpiel State object at a decision node of this player.

        Returns:
            int: 0 (fold), 1 (call/check) or a legal raise action ID.

        Raises:
            SystemExit: When the user presses Ctrl-C or stdin reaches EOF.
        """
        legal_actions = state.legal_actions()

        # Forced action (blinds etc.): apply it without asking.
        if len(legal_actions) == 1:
            cp = state.current_player()
            desc = state.action_to_string(cp, legal_actions[0])
            print(f" {self.name} (P{self.player_id}) 强制执行: {desc}")
            return legal_actions[0]

        d = state.to_dict()
        contributions = d.get("player_contributions", [0, 0])
        self._render_table(d, contributions)

        # Classify the legal actions.
        can_fold = 0 in legal_actions
        can_call = 1 in legal_actions
        raise_actions = [a for a in legal_actions if a > 1]

        # Amount to call = opponent contribution - own contribution.
        call_amount = max(
            contributions[1 - self.player_id] - contributions[self.player_id], 0
        )
        call_label = f"跟注 {call_amount}" if call_amount > 0 else "过牌 (Check)"

        options = []
        if can_fold:
            options.append("F = 弃牌 (Fold)")
        if can_call:
            options.append(f"C = {call_label}")
        if raise_actions:
            min_bet = min(raise_actions)
            max_bet = max(raise_actions)
            options.append(f"输入 {min_bet}~{max_bet} 之间的数字 = 加注到该金额")

        print(" 可选操作:")
        for opt in options:
            print(f" {opt}")

        while True:
            try:
                raw = input(" 请输入: ").strip()
            except (KeyboardInterrupt, EOFError):
                # EOFError covers a closed or exhausted stdin (Ctrl-D, piped
                # input); without it the game died with a traceback.
                print("\n 退出游戏。")
                sys.exit(0)

            if not raw:
                print(" 输入不能为空,请重试。")
                continue

            upper = raw.upper()

            if upper == 'F':
                if can_fold:
                    return 0
                print(" 当前不可弃牌。")
                continue

            if upper == 'C':
                if can_call:
                    return 1
                print(" 当前不可跟注/过牌。")
                continue

            # Anything else must be a raise amount.
            try:
                target = int(raw)
            except ValueError:
                print(" 请输入 F、C 或数字。")
                continue

            if not raise_actions:
                print(" 当前不可加注。")
                continue

            # min_bet / max_bet are defined here because raise_actions is
            # non-empty.
            if target < min_bet or target > max_bet:
                print(f" 加注金额须在 {min_bet}~{max_bet} 之间。")
                continue

            # Snap to the closest legal action ID.
            closest = min(raise_actions, key=lambda a: abs(a - target))
            if closest != target:
                print(f" {target} 不是合法加注额,已自动修正为最接近的 {closest}")
            return closest
|
||
|
||
|
||
# ───────────────────── 对局核心流程 ─────────────────────
|
||
|
||
def run_game(game, players: list):
    """Play one hand with the standard OpenSpiel game loop.

    Args:
        game: OpenSpiel game instance.
        players: ``[PlayerBase, PlayerBase]`` for P0 and P1.

    Returns:
        ``(terminal_state, history)``: the terminal state and the list of
        action descriptions recorded along the way.
    """
    hand_log = []
    state = game.new_initial_state()

    while not state.is_terminal():
        if not state.is_chance_node():
            # Player node: ask the seat's player for a decision.
            seat = state.current_player()
            actor = players[seat]
            move = actor.choose(state)
            label = state.action_to_string(seat, move)
            hand_log.append(f"P{seat} ({actor.name}): {label}")
        else:
            # Chance node (dealing): sample an outcome by its probability.
            moves, weights = zip(*state.chance_outcomes())
            (move,) = random.choices(moves, weights=weights, k=1)
            hand_log.append(state.action_to_string(-1, move))

        # Descriptions are taken before the action is applied.
        state.apply_action(move)

    return state, hand_log
|
||
|
||
|
||
# ───────────────────── 对局结算 ─────────────────────
|
||
|
||
def show_result(state, players: list, history: list):
    """Print the end-of-hand summary: hole cards, board, hand history and returns.

    Args:
        state: Terminal state providing ``returns()`` and ``to_dict()``.
        players: The two players; only their ``name`` attribute is read.
        history: Action descriptions to print, one per line.
    """
    payoffs = state.returns()
    info = state.to_dict()
    rule = "=" * 50

    print()
    print(rule)
    print(" 牌局结束 (Showdown)")
    print(rule)

    # Hole cards of both seats.
    hole_cards = info.get("player_hands", ["", ""])
    for seat in (0, 1):
        raw_hand = hole_cards[seat]
        shown = parse_hand_str(raw_hand) if raw_hand else "(未显示)"
        print(f" P{seat} ({players[seat].name}) 底牌: {shown}")

    # Community cards.
    raw_board = info.get("board_cards", "")
    if raw_board:
        board_display = parse_hand_str(raw_board)
    else:
        board_display = "(无)"
    print(f" 公共牌: {board_display}")

    # Hand history exactly as recorded during play.
    print()
    print(" --- 牌谱 (Hand History) ---")
    for line in history:
        print(f" {line}")

    # Returns; non-negative values get an explicit plus sign.
    print()
    for seat in (0, 1):
        amount = payoffs[seat]
        prefix = "" if amount < 0 else "+"
        print(f" P{seat} ({players[seat].name}) 收益: {prefix}{amount:.0f}")

    print(rule)
|
||
|
||
|
||
# ───────────────────── 主程序 ─────────────────────
|
||
|
||
def main():
    """Interactive entry point: choose a mode, then play hands until the user quits."""
    print()
    print("=" * 50)
    print(" Poker Arena — 德州扑克命令行对战")
    print("=" * 50)
    print()
    print(" 请选择模式:")
    print(" 1: 人类 vs 人类")
    print(" 2: 人类 vs AI (人类 P0, AI P1)")
    print(" 3: AI vs 人类 (AI P0, 人类 P1)")
    print(" 4: AI vs AI (观战)")
    print()

    mode = _prompt_mode()
    players = _build_players(mode)

    print()

    # Initialise the game engine.
    game = pyspiel.load_game(HUNL_FULLGAME_STRING)

    game_num = 0
    while True:
        game_num += 1
        print()
        print("#" * 50)
        print(f" 第 {game_num} 局")
        print("#" * 50)

        terminal_state, history = run_game(game, players)
        show_result(terminal_state, players, history)

        # Ask whether to play another hand.
        print()
        try:
            cont = input(" 按回车继续下一局,输入 q 退出: ").strip()
        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin closed or exhausted; leave instead of crashing.
            print("\n 退出。")
            break

        if cont.lower() == 'q':
            print(" 再见!")
            break


def _prompt_mode() -> int:
    """Ask for a mode number until 1-4 is entered; exit on Ctrl-C or EOF."""
    while True:
        try:
            mode = int(input(" 输入模式编号 (1-4): ").strip())
        except ValueError:
            print(" 请输入数字。")
            continue
        except (KeyboardInterrupt, EOFError):
            print("\n 退出。")
            sys.exit(0)

        if 1 <= mode <= 4:
            return mode
        print(" 请输入 1-4 之间的数字。")


def _build_players(mode: int) -> list:
    """Create the [P0, P1] player pair for the selected mode (1-4)."""
    if mode == 1:
        return [HumanPlayer(0, "玩家1"), HumanPlayer(1, "玩家2")]
    if mode == 2:
        return [HumanPlayer(0, "你"), AIPlayer(1, "AI")]
    if mode == 3:
        return [AIPlayer(0, "AI"), HumanPlayer(1, "你")]
    if mode == 4:
        return [AIPlayer(0, "AI-0"), AIPlayer(1, "AI-1")]
    # Any other value leaves both seats empty.
    return [None, None]


if __name__ == "__main__":
    main()
|