Files
TexasPoker-AI/run_tournament.py
2026-05-13 17:48:45 +08:00

214 lines
7.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
run_tournament.py — 全自动 Round Robin 锦标赛
扫描 checkpoints/ 下的波谷模型 (ckpt_valley_iter_*.pt), 两两对战,
调用 eval_elo.py 计算 bb/100, 最终输出排行榜。
用法:
python run_tournament.py
python run_tournament.py --num_games 50000
python run_tournament.py --checkpoints_dir checkpoints/
"""
import glob
import itertools
import os
import re
import subprocess
import sys
from datetime import datetime
from typing import Dict, List, Optional
# ── Defaults ──
# Every default path is anchored at the directory that holds this script, so
# the tool resolves the same files no matter which working directory it is
# launched from.
_POKER_DIR = os.path.abspath(os.path.dirname(__file__))

# Inputs: where checkpoints live and which script plays a single match.
DEFAULT_CKPT_DIR = os.path.join(_POKER_DIR, "checkpoints")
EVAL_SCRIPT = os.path.join(_POKER_DIR, "eval_elo.py")

# Match length and where the final report is appended.
DEFAULT_NUM_GAMES = 50_000
LOG_FILE = os.path.join(_POKER_DIR, "tournament.log")
def find_valley_models(ckpt_dir: str) -> List[str]:
    """Return every ``ckpt_valley_iter_*.pt`` path in *ckpt_dir*, ordered by iteration.

    Ordering is numeric on the iteration number, so ``iter_500`` comes before
    ``iter_1000``. A plain string sort would put ``1000`` first. Files whose
    suffix is not a plain integer (e.g. ``ckpt_valley_iter_best.pt``) sort
    after all numeric ones, by path.

    Args:
        ckpt_dir: Directory to scan (non-recursive). Glob metacharacters in
            the directory name are escaped, so a directory such as
            ``runs/[exp1]`` is matched literally.

    Returns:
        List of matching file paths; empty when none exist.
    """
    # Escape the directory part so "[", "*" or "?" in it are not treated as
    # wildcards; only the file-name part is a glob pattern.
    pattern = os.path.join(glob.escape(ckpt_dir), "ckpt_valley_iter_*.pt")
    iter_re = re.compile(r"^ckpt_valley_iter_(\d+)\.pt$")

    def _iteration_key(path: str):
        # (group, iteration, path): numeric files first in numeric order,
        # anything else afterwards; path breaks ties deterministically.
        m = iter_re.match(os.path.basename(path))
        if m is None:
            return (1, 0, path)
        return (0, int(m.group(1)), path)

    return sorted(glob.glob(pattern), key=_iteration_key)
def extract_model_name(path: str) -> str:
    """Return a short display name for a checkpoint path.

    ``.../ckpt_valley_iter_500.pt`` becomes ``valley_500``. If the file name
    does not contain that pattern, the bare file name (basename) is returned
    unchanged.
    """
    filename = os.path.basename(path)
    found = re.search(r"ckpt_valley_iter_(\d+)\.pt", filename)
    return filename if found is None else "valley_" + found.group(1)
def run_match(
    model_a: str,
    model_b: str,
    num_games: int,
    timeout: Optional[float] = 3600,
) -> Optional[float]:
    """Play one head-to-head match by running ``EVAL_SCRIPT`` in a subprocess.

    Args:
        model_a: Path to the first checkpoint (the player whose win rate is reported).
        model_b: Path to the opposing checkpoint.
        num_games: Hand count passed through as ``--num_games``.
        timeout: Seconds to wait for the subprocess before giving up. The
            default of 3600 is one hour; ``None`` waits indefinitely.

    Returns:
        Model A's win rate against Model B in bb/100, parsed from the
        subprocess's stdout. Returns ``None`` if the subprocess timed out,
        could not be started, exited non-zero, or printed no parsable figure.
        Each failure prints a short diagnostic.
    """
    cmd = [
        sys.executable, EVAL_SCRIPT,
        "--model_a", model_a,
        "--model_b", model_b,
        "--num_games", str(num_games),
    ]
    # Ask the child to write UTF-8 and decode it as such, so the Chinese
    # fallback pattern below does not depend on the platform's locale
    # encoding. Undecodable bytes are replaced rather than raising.
    env = dict(os.environ, PYTHONIOENCODING="utf-8")
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            encoding="utf-8",
            errors="replace",
            env=env,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print(f" [超时] {extract_model_name(model_a)} vs {extract_model_name(model_b)}")
        return None
    except Exception as e:
        # Boundary: one broken match must not abort the whole tournament.
        print(f" [异常] {extract_model_name(model_a)} vs {extract_model_name(model_b)}: {e}")
        return None

    if result.returncode != 0:
        print(f" [失败] 返回码 {result.returncode}")
        if result.stderr:
            # Show the tail: a Python traceback ends with the actual error.
            print(f" stderr: {result.stderr[-500:]}")
        return None

    # Expected summary line, e.g.
    #   "经过 50,000 局对抗, Model A 对 Model B 的百手赢率为: +12.3 bb/100"
    # Earlier lines may also mention bb/100 (e.g. progress updates), so scan
    # from the last line upward and use the first line that yields a number.
    output = result.stdout
    primary_re = re.compile(r"([+-]?\d+\.?\d*)\s*bb/100")
    fallback_re = re.compile(r"百手赢率为:\s*([+-]?\d+\.?\d*)")
    for line in reversed(output.splitlines()):
        match = primary_re.search(line) or fallback_re.search(line)
        if match:
            return float(match.group(1))

    print(f" [解析失败] 未能从输出中提取 bb/100")
    print(f" 输出末尾: {output[-300:]}")
    return None
def _fmt_signed(value: float) -> str:
    """Format *value* with an explicit sign and one decimal, e.g. ``+1.5`` / ``-2.0``.

    Adding ``0.0`` turns IEEE negative zero into positive zero. A 0.0 result
    is stored negated (``-0.0``) in the mirror cell of the matrix; without
    this it would render as ``+-0.0`` with a ``"+" if x >= 0`` prefix.
    """
    return f"{value + 0.0:+.1f}"


def _run_round_robin(
    models: List[str], model_names: List[str], num_games: int
) -> List[List[Optional[float]]]:
    """Play every unordered pair of *models* once and return the result matrix.

    ``result[i][j]`` is model i's bb/100 against model j, and
    ``result[j][i]`` is its negation. Cells on the diagonal and cells of
    failed matches stay ``None``. One progress line is printed per match.
    """
    n = len(models)
    total = n * (n - 1) // 2
    results: List[List[Optional[float]]] = [[None] * n for _ in range(n)]
    for count, (i, j) in enumerate(itertools.combinations(range(n), 2), 1):
        name_a, name_b = model_names[i], model_names[j]
        print(f"[{count}/{total}] {name_a} vs {name_b} ... ", end="", flush=True)
        bb_per_100 = run_match(models[i], models[j], num_games)
        if bb_per_100 is None:
            print("失败 (跳过)")
            continue  # both cells remain None
        print(f"{_fmt_signed(bb_per_100)} bb/100")
        results[i][j] = bb_per_100
        results[j][i] = -bb_per_100
    return results


def _format_report(
    model_names: List[str],
    match_results: List[List[Optional[float]]],
    scores: List[float],
    num_games: int,
) -> str:
    """Render the head-to-head matrix and the leaderboard as one text block.

    The leaderboard sorts by descending net score. Python's sort is stable,
    so tied models keep their discovery order.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    lines: List[str] = [
        f"\n{'='*70}",
        f" Tournament Results — {timestamp}",
        f" 参赛模型: {len(model_names)} | 每场: {num_games:,}",
        f"{'='*70}",
    ]

    # Head-to-head matrix; column width adapts to the longest model name.
    max_name_len = max(len(name) for name in model_names)
    col_w = max(max_name_len + 2, 10)
    lines.append("\n [对战矩阵] (行 vs 列 = bb/100正值=行方占优)")
    lines.append(" " * col_w + "".join(f"{name:>{col_w}}" for name in model_names))
    for i, name_i in enumerate(model_names):
        cells = []
        for j, val in enumerate(match_results[i]):
            if i == j:
                cell = "---"
            elif val is None:
                cell = "N/A"
            else:
                cell = _fmt_signed(val)
            cells.append(f"{cell:>{col_w}}")
        lines.append(f"{name_i:>{col_w}}" + "".join(cells))

    # Leaderboard; score column right-aligned under its 10-wide header.
    ranking = sorted(range(len(model_names)), key=lambda k: scores[k], reverse=True)
    lines.append("\n [排行榜] (按净 bb/100 降序)")
    lines.append(f" {'排名':>4} {'模型':>{max_name_len}} {'净 bb/100':>10}")
    lines.append(f" {'----':>4} {'----':>{max_name_len}} {'--------':>10}")
    for rank, idx in enumerate(ranking, 1):
        lines.append(
            f" {rank:>4} {model_names[idx]:>{max_name_len}} {_fmt_signed(scores[idx]):>10}"
        )
    lines.append(f"{'='*70}\n")
    return "\n".join(lines)


def _append_to_log(text: str) -> None:
    """Append *text* to ``LOG_FILE``; on failure, warn instead of crashing."""
    try:
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(text + "\n")
    except (OSError, UnicodeError) as e:
        print(f"[警告] 写入日志文件失败: {e}")
    else:
        print(f"[日志] 结果已追加到 {LOG_FILE}")


def main():
    """Command-line entry point: run the round robin tournament and report the results.

    The steps are:

    1. Scan ``--checkpoints_dir`` for valley checkpoints.
    2. Play every pair once via ``run_match``.
    3. Print the head-to-head matrix and the leaderboard.
    4. Append the same report to ``LOG_FILE``.

    Exits with status 1 when fewer than two checkpoints are found. argparse
    rejects a non-positive ``--num_games`` with exit status 2.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Round Robin Tournament for Valley Models")
    parser.add_argument("--checkpoints_dir", default=DEFAULT_CKPT_DIR,
                        help="波谷模型目录 (默认: checkpoints/)")
    parser.add_argument("--num_games", type=int, default=DEFAULT_NUM_GAMES,
                        help="每对对战的局数 (默认: 50000)")
    args = parser.parse_args()
    if args.num_games <= 0:
        # Reject up front rather than spawning one doomed subprocess per pair.
        parser.error("--num_games 必须是正整数")
    ckpt_dir = args.checkpoints_dir
    num_games = args.num_games

    # 1. Discover the competing checkpoints.
    models = find_valley_models(ckpt_dir)
    if len(models) < 2:
        print(f"[错误] 仅找到 {len(models)} 个波谷模型,至少需要 2 个才能开赛。")
        print(f" 扫描路径: {os.path.join(ckpt_dir, 'ckpt_valley_iter_*.pt')}")
        sys.exit(1)
    model_names = [extract_model_name(p) for p in models]
    total_matches = len(models) * (len(models) - 1) // 2

    print(f"\n{'='*70}")
    print(" Round Robin Tournament")
    print(f" 参赛模型: {len(models)}")
    print(f" 每场局数: {num_games:,}")
    print(f" 总对战数: {total_matches}")
    print(f"{'='*70}")
    for i, (path, name) in enumerate(zip(models, model_names), 1):
        print(f" [{i}] {name} ({os.path.basename(path)})")
    print()

    # 2. Play every pair once.
    match_results = _run_round_robin(models, model_names, num_games)

    # 3. Net score = sum of the model's row in the matrix; failed matches add 0.
    #    Scores are indexed by position, not by display name, so two
    #    checkpoints that map to the same name cannot merge into one entry.
    scores = [sum([v for v in row if v is not None], 0.0) for row in match_results]

    # 4. Show the report and append it to the log.
    report = _format_report(model_names, match_results, scores, num_games)
    print(report)
    _append_to_log(report)
# Run the tournament only when executed as a script, not when imported.
if __name__ == "__main__":
    main()