#!/usr/bin/env python3
"""
run_tournament.py — fully automatic round-robin tournament.

Scans ``checkpoints/`` for valley models (``ckpt_valley_iter_*.pt``), plays
every pair once by invoking ``eval_elo.py`` (which reports bb/100), then
prints a leaderboard and appends it to ``tournament.log``.

Usage:
    python run_tournament.py
    python run_tournament.py --num_games 50000
    python run_tournament.py --checkpoints_dir checkpoints/
"""

import argparse
import glob
import itertools
import os
import re
import subprocess
import sys
from datetime import datetime
from typing import Dict, List, Optional, Tuple

# ── Defaults ──
_POKER_DIR = os.path.abspath(os.path.dirname(__file__))
DEFAULT_CKPT_DIR = os.path.join(_POKER_DIR, "checkpoints")
DEFAULT_NUM_GAMES = 50000
EVAL_SCRIPT = os.path.join(_POKER_DIR, "eval_elo.py")
LOG_FILE = os.path.join(_POKER_DIR, "tournament.log")
DEFAULT_MATCH_TIMEOUT = 3600  # seconds: a single match may run for at most 1 hour

# Checkpoint basename -> iteration number, e.g. ckpt_valley_iter_500.pt -> "500".
_VALLEY_RE = re.compile(r"ckpt_valley_iter_(\d+)\.pt")
# A signed decimal directly followed by "bb/100", e.g. "+12.3 bb/100".
_BB_RE = re.compile(r"([+-]?\d+\.?\d*)\s*bb/100")
# Fallback: the number after "百手赢率为:" (accepts ASCII or full-width colon).
_WINRATE_RE = re.compile(r"百手赢率为[:：]\s*([+-]?\d+\.?\d*)")


def find_valley_models(ckpt_dir: str) -> List[str]:
    """Return every ``ckpt_valley_iter_*.pt`` path in *ckpt_dir*, ordered by iteration.

    The ordering is numeric (iter 50 < 500 < 1000), not lexicographic. Files
    matched by the glob whose basename has no parsable iteration number are
    placed last, ordered by path.
    """
    # Escape the directory so characters such as '[' or '*' in it are not
    # interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(ckpt_dir), "ckpt_valley_iter_*.pt")

    def sort_key(path: str) -> Tuple[bool, int, str]:
        m = _VALLEY_RE.search(os.path.basename(path))
        return (m is None, int(m.group(1)) if m else 0, path)

    return sorted(glob.glob(pattern), key=sort_key)


def extract_model_name(path: str) -> str:
    """Derive a short display name from a checkpoint path.

    ``.../ckpt_valley_iter_500.pt`` -> ``'valley_500'``; any other file is
    identified by its basename unchanged.
    """
    basename = os.path.basename(path)
    m = _VALLEY_RE.search(basename)
    if m:
        return f"valley_{m.group(1)}"
    return basename


def _fmt_signed(value: float) -> str:
    """Format *value* with an explicit sign and one decimal, e.g. '+12.3' / '-4.5'."""
    # Adding 0.0 turns IEEE negative zero (e.g. the mirrored cell of a 0.0
    # result) into +0.0, so it prints as '+0.0' rather than '+-0.0'.
    return f"{value + 0.0:+.1f}"


def _parse_bb_per_100(output: str) -> Optional[float]:
    """Extract the bb/100 result from eval_elo.py's stdout, or None if absent.

    Expected summary line, e.g.:
    "经过 50,000 局对抗,Model A 对 Model B 的百手赢率为: +12.3 bb/100".
    The LAST "<number> bb/100" occurrence wins, so if earlier lines also
    mention bb/100 they do not shadow the final result.
    """
    matches = _BB_RE.findall(output)
    if matches:
        return float(matches[-1])
    matches = _WINRATE_RE.findall(output)
    if matches:
        return float(matches[-1])
    return None


def run_match(
    model_a: str,
    model_b: str,
    num_games: int,
    timeout: Optional[float] = DEFAULT_MATCH_TIMEOUT,
) -> Optional[float]:
    """Play *model_a* against *model_b* via eval_elo.py and return A's bb/100.

    Args:
        model_a: checkpoint path for Model A.
        model_b: checkpoint path for Model B.
        num_games: value passed to eval_elo.py ``--num_games``.
        timeout: seconds before the subprocess is killed (None = no limit).

    Returns:
        Model A's bb/100 against Model B (positive = A ahead), or None when
        the subprocess times out, cannot be started, exits non-zero, or
        prints no parsable bb/100. Failures are printed, not raised, so the
        tournament can carry on with the remaining pairs.
    """
    label = f"{extract_model_name(model_a)} vs {extract_model_name(model_b)}"
    cmd = [
        sys.executable, EVAL_SCRIPT,
        "--model_a", model_a,
        "--model_b", model_b,
        "--num_games", str(num_games),
    ]
    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            # Replace undecodable bytes instead of raising, so an encoding
            # mismatch in the child's output cannot discard a long match.
            errors="replace",
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        print(f" [超时] {label}")
        return None
    except Exception as e:  # best-effort boundary: report and keep going
        print(f" [异常] {label}: {e}")
        return None

    if result.returncode != 0:
        print(f" [失败] 返回码 {result.returncode}")
        if result.stderr:
            # Show the tail: a Python traceback ends with the actual error.
            print(f" stderr: {result.stderr[-500:]}")
        return None

    bb_per_100 = _parse_bb_per_100(result.stdout)
    if bb_per_100 is None:
        print(" [解析失败] 未能从输出中提取 bb/100")
        print(f" 输出末尾: {result.stdout[-300:]}")
    return bb_per_100


def _play_all_matches(
    models: List[str], model_names: List[str], num_games: int
) -> Tuple[Dict[str, float], Dict[str, int], List[List[Optional[float]]]]:
    """Play every unordered pair of *models* once and tally the results.

    Returns:
        scores: per-model sum of bb/100 over its completed matches.
        played: per-model number of completed (non-failed) matches.
        match_results: [i][j] = model i's bb/100 against model j (the matrix
            is antisymmetric); None on the diagonal and for failed matches.
    """
    n = len(models)
    scores: Dict[str, float] = {name: 0.0 for name in model_names}
    played: Dict[str, int] = {name: 0 for name in model_names}
    match_results: List[List[Optional[float]]] = [[None] * n for _ in range(n)]
    total_matches = n * (n - 1) // 2

    pairs = itertools.combinations(range(n), 2)
    for match_count, (i, j) in enumerate(pairs, 1):
        name_a, name_b = model_names[i], model_names[j]
        print(f"[{match_count}/{total_matches}] {name_a} vs {name_b} ... ", end="", flush=True)
        bb_per_100 = run_match(models[i], models[j], num_games)
        if bb_per_100 is None:
            print("失败 (跳过)")
            continue
        print(f"{_fmt_signed(bb_per_100)} bb/100")
        scores[name_a] += bb_per_100
        scores[name_b] -= bb_per_100
        played[name_a] += 1
        played[name_b] += 1
        match_results[i][j] = bb_per_100
        match_results[j][i] = -bb_per_100
    return scores, played, match_results


def _build_report(
    model_names: List[str],
    num_games: int,
    scores: Dict[str, float],
    played: Dict[str, int],
    match_results: List[List[Optional[float]]],
) -> List[str]:
    """Render the result matrix and the leaderboard as a list of text lines."""
    n = len(model_names)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    lines: List[str] = [
        f"\n{'='*70}",
        f" Tournament Results — {timestamp}",
        f" 参赛模型: {n} | 每场: {num_games:,} 局",
        f"{'='*70}",
    ]

    # Result matrix; column width adapts to the longest model name.
    max_name_len = max(len(name) for name in model_names)
    col_w = max(max_name_len + 2, 10)
    lines.append("\n [对战矩阵] (行 vs 列 = bb/100,正值=行方占优)")
    lines.append(" " * col_w + "".join(f"{name:>{col_w}}" for name in model_names))
    for i, name_i in enumerate(model_names):
        cells = []
        for j in range(n):
            val = match_results[i][j]
            if i == j:
                cells.append("---".rjust(col_w))
            elif val is None:
                cells.append("N/A".rjust(col_w))
            else:
                cells.append(_fmt_signed(val).rjust(col_w))
        lines.append(f"{name_i:>{col_w}}" + "".join(cells))

    # Leaderboard, by net bb/100 descending (ties keep discovery order).
    ranking = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    lines.append("\n [排行榜] (按净 bb/100 降序)")
    lines.append(f" {'排名':>4} {'模型':>{max_name_len}} {'净 bb/100':>10} {'场次':>4}")
    lines.append(f" {'----':>4} {'----':>{max_name_len}} {'--------':>10} {'----':>4}")
    for rank, (name, score) in enumerate(ranking, 1):
        lines.append(
            f" {rank:>4} {name:>{max_name_len}} {_fmt_signed(score):>10} {played[name]:>4}"
        )

    # Failed matches contribute nothing to either side's net score, so
    # flag them: models with fewer completed matches are not directly comparable.
    failed = sum(
        1 for i, j in itertools.combinations(range(n), 2) if match_results[i][j] is None
    )
    if failed:
        lines.append(f"\n [注意] {failed} 场对战失败,净 bb/100 仅统计已完成的场次")
    lines.append(f"{'='*70}\n")
    return lines


def _append_to_log(text: str) -> None:
    """Append *text* to LOG_FILE (UTF-8); a write failure is reported, not raised."""
    try:
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(text + "\n")
        print(f"[日志] 结果已追加到 {LOG_FILE}")
    except OSError as e:
        print(f"[警告] 写入日志文件失败: {e}")


def main():
    """CLI entry point: discover models, run the round robin, report results.

    Exits with status 1 if fewer than two valley models are found, and with
    argparse's usage error if --num_games is not positive.
    """
    parser = argparse.ArgumentParser(description="Round Robin Tournament for Valley Models")
    parser.add_argument("--checkpoints_dir", default=DEFAULT_CKPT_DIR,
                        help="波谷模型目录 (默认: checkpoints/)")
    parser.add_argument("--num_games", type=int, default=DEFAULT_NUM_GAMES,
                        help="每对对战的局数 (默认: 50000)")
    args = parser.parse_args()
    ckpt_dir = args.checkpoints_dir
    num_games = args.num_games
    if num_games <= 0:
        parser.error("--num_games 必须为正整数")

    # 1. Discover models.
    models = find_valley_models(ckpt_dir)
    if len(models) < 2:
        print(f"[错误] 仅找到 {len(models)} 个波谷模型,至少需要 2 个才能开赛。")
        print(f" 扫描路径: {os.path.join(ckpt_dir, 'ckpt_valley_iter_*.pt')}")
        sys.exit(1)

    model_names = [extract_model_name(p) for p in models]
    total_matches = len(models) * (len(models) - 1) // 2
    print(f"\n{'='*70}")
    print(" Round Robin Tournament")
    print(f" 参赛模型: {len(models)} 个")
    print(f" 每场局数: {num_games:,}")
    print(f" 总对战数: {total_matches}")
    print(f"{'='*70}")
    for idx, (path, name) in enumerate(zip(models, model_names), 1):
        print(f" [{idx}] {name} ({os.path.basename(path)})")
    print()

    # 2. Play every pair, 3. build the report, 4. print and log it.
    scores, played, match_results = _play_all_matches(models, model_names, num_games)
    output_text = "\n".join(
        _build_report(model_names, num_games, scores, played, match_results)
    )
    print(output_text)
    _append_to_log(output_text)


if __name__ == "__main__":
    main()