#!/usr/bin/env python3
"""
run_tournament.py — 全自动 Round Robin 锦标赛

扫描 checkpoints/ 下的波谷模型(ckpt_valley_iter_*.pt),两两对战,
调用 eval_elo.py 计算 bb/100,最终输出排行榜。

用法:
    python run_tournament.py
    python run_tournament.py --num_games 50000
    python run_tournament.py --checkpoints_dir checkpoints/
"""
|
||
|
||
import glob
|
||
import itertools
|
||
import os
|
||
import re
|
||
import subprocess
|
||
import sys
|
||
from datetime import datetime
|
||
from typing import Dict, List, Optional
|
||
|
||
# ── Defaults: every path is resolved relative to the directory of this script ──
_POKER_DIR = os.path.abspath(os.path.dirname(__file__))

# Checkpoint directory, evaluator script and log file all sit next to this file.
DEFAULT_CKPT_DIR, EVAL_SCRIPT, LOG_FILE = (
    os.path.join(_POKER_DIR, leaf)
    for leaf in ("checkpoints", "eval_elo.py", "tournament.log")
)

# Games played per pairing unless --num_games overrides it.
DEFAULT_NUM_GAMES = 50_000
|
||
|
||
|
||
def find_valley_models(ckpt_dir: str) -> List[str]:
    """Return every ``ckpt_valley_iter_*.pt`` in *ckpt_dir*, ordered by iteration.

    A plain ``sorted()`` over paths is lexicographic ("iter_1000" < "iter_500"),
    so the numeric iteration suffix is used as the sort key instead. Files whose
    suffix is not an integer are still returned, after the numeric ones, in
    name order.

    Args:
        ckpt_dir: directory to scan (not searched recursively).

    Returns:
        Matching paths; an empty list if the directory is missing or has none.
    """
    # Escape the directory part so names containing glob metacharacters
    # ("[", "*", "?") are taken literally; only the file part is a pattern.
    pattern = os.path.join(glob.escape(ckpt_dir), "ckpt_valley_iter_*.pt")

    def _iteration_key(path: str):
        m = re.fullmatch(r"ckpt_valley_iter_(\d+)\.pt", os.path.basename(path))
        # Numeric iterations first (by value), then anything else by path.
        return (0, int(m.group(1)), path) if m else (1, 0, path)

    return sorted(glob.glob(pattern), key=_iteration_key)
|
||
|
||
|
||
def extract_model_name(path: str) -> str:
    """Map a checkpoint path to a short display name.

    ``.../ckpt_valley_iter_500.pt`` becomes ``valley_500``; any file name that
    does not contain that pattern is returned unchanged (basename only).
    """
    stem = os.path.basename(path)
    found = re.search(r"ckpt_valley_iter_(\d+)\.pt", stem)
    return f"valley_{found.group(1)}" if found else stem
|
||
|
||
|
||
def _parse_bb_per_100(output: str) -> Optional[float]:
    """Extract the final bb/100 figure from eval_elo.py's stdout, or None.

    The LAST occurrence wins: the summary line (e.g. "经过 50,000 局对抗,
    Model A 对 Model B 的百手赢率为: +12.3 bb/100") is printed at the end, and
    any earlier progress lines that also mention bb/100 must not be picked up.
    """
    # Primary format: a signed number directly followed by "bb/100".
    hits = re.findall(r"([+-]?\d+\.?\d*)\s*bb/100", output)
    if hits:
        return float(hits[-1])

    # Fallback: the number right after the "百手赢率为" label, accepting both
    # the ASCII ":" and the full-width "：" colon.
    hits = re.findall(r"百手赢率为[:：]\s*([+-]?\d+\.?\d*)", output)
    if hits:
        return float(hits[-1])

    return None


def run_match(model_a: str, model_b: str, num_games: int, *,
              timeout: Optional[float] = 3600,
              eval_script: Optional[str] = None) -> Optional[float]:
    """Play one pairing through eval_elo.py and return A's bb/100 against B.

    Args:
        model_a: checkpoint path for the first player.
        model_b: checkpoint path for the second player.
        num_games: number of games to play.
        timeout: seconds before the evaluator is killed (default one hour;
            None disables the limit).
        eval_script: evaluator to run; defaults to the module's EVAL_SCRIPT.

    Returns:
        Model A's bb/100 versus model B, or None when the evaluator times out,
        cannot be launched, exits non-zero, or prints no parsable figure
        (a short diagnostic is printed in each case).
    """
    script = EVAL_SCRIPT if eval_script is None else eval_script
    cmd = [
        sys.executable, script,
        "--model_a", model_a,
        "--model_b", model_b,
        "--num_games", str(num_games),
    ]
    # Make the child emit UTF-8 and decode it as UTF-8, so the Chinese summary
    # line survives non-UTF-8 locales instead of raising a decode error.
    env = dict(os.environ, PYTHONIOENCODING="utf-8")

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            encoding="utf-8",
            errors="replace",
            timeout=timeout,
            env=env,
        )
    except subprocess.TimeoutExpired:
        print(f" [超时] {extract_model_name(model_a)} vs {extract_model_name(model_b)}")
        return None
    except Exception as e:  # e.g. OSError when the interpreter/script can't start
        print(f" [异常] {extract_model_name(model_a)} vs {extract_model_name(model_b)}: {e}")
        return None

    if result.returncode != 0:
        print(f" [失败] 返回码 {result.returncode}")
        if result.stderr:
            # A traceback ends with the actual error, so show the tail.
            print(f" stderr: {result.stderr[-500:]}")
        return None

    output = result.stdout
    value = _parse_bb_per_100(output)
    if value is None:
        print(" [解析失败] 未能从输出中提取 bb/100")
        print(f" 输出末尾: {output[-300:]}")
    return value
|
||
|
||
|
||
def _parse_args():
    """Parse the command line: checkpoint directory and games per match."""
    import argparse

    parser = argparse.ArgumentParser(description="Round Robin Tournament for Valley Models")
    parser.add_argument("--checkpoints_dir", default=DEFAULT_CKPT_DIR,
                        help="波谷模型目录 (默认: checkpoints/)")
    parser.add_argument("--num_games", type=int, default=DEFAULT_NUM_GAMES,
                        help="每对对战的局数 (默认: 50000)")
    return parser.parse_args()


def _fmt_bb(value: float) -> str:
    """Format a bb/100 value with an explicit sign and one decimal place.

    Adding 0.0 turns a negative zero (the mirrored cell of an exact 0.0
    result) into +0.0, so it prints "+0.0" instead of "+-0.0".
    """
    return f"{value + 0.0:+.1f}"


def _format_report(model_names: List[str],
                   match_results: List[List[Optional[float]]],
                   num_games: int,
                   failed: int = 0,
                   interrupted: bool = False) -> str:
    """Render the head-to-head matrix and the leaderboard as printable text.

    Args:
        model_names: display names, index-aligned with ``match_results``.
        match_results: ``match_results[i][j]`` is model i's bb/100 against
            model j, or None when that pair has no result.
        num_games: games per match (shown in the header only).
        failed: number of pairs whose match failed.
        interrupted: whether the tournament stopped before all pairs played.

    Returns:
        The full report as one string.
    """
    n = len(model_names)
    # Row i holds model i's signed results, so its sum is model i's net
    # bb/100 and its non-None count is how many matches contributed to it.
    scores = [sum((v for v in row if v is not None), 0.0) for row in match_results]
    played = [sum(v is not None for v in row) for row in match_results]
    # sorted() is stable even with reverse=True, so ties keep scan order.
    ranking = sorted(range(n), key=lambda k: scores[k], reverse=True)

    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    lines: List[str] = [
        f"\n{'='*70}",
        f" Tournament Results — {timestamp}",
        f" 参赛模型: {n} | 每场: {num_games:,} 局",
    ]
    if failed:
        lines.append(f" 失败对战: {failed} 场 (记为 N/A,不计入净 bb/100)")
    if interrupted:
        lines.append(" [中断] 锦标赛未完成,以下为部分结果")
    lines.append(f"{'='*70}")

    # Head-to-head matrix; column width adapts to the longest model name.
    max_name_len = max(len(name) for name in model_names)
    col_w = max(max_name_len + 2, 10)
    lines.append("\n [对战矩阵] (行 vs 列 = bb/100,正值=行方占优)")
    lines.append(" " * col_w + "".join(f"{name:>{col_w}}" for name in model_names))
    for i, name_i in enumerate(model_names):
        cells = []
        for j, val in enumerate(match_results[i]):
            if i == j:
                cells.append(f"{'---':>{col_w}}")
            elif val is None:
                cells.append(f"{'N/A':>{col_w}}")
            else:
                cells.append(_fmt_bb(val).rjust(col_w))
        lines.append(f"{name_i:>{col_w}}" + "".join(cells))

    # Leaderboard: values are right-aligned under their width-10 header, and
    # the match count shows when a total rests on fewer completed matches.
    lines.append("\n [排行榜] (按净 bb/100 降序)")
    lines.append(f" {'排名':>4} {'模型':>{max_name_len}} {'净 bb/100':>10} {'场次':>6}")
    lines.append(f" {'----':>4} {'----':>{max_name_len}} {'--------':>10} {'----':>6}")
    for rank, k in enumerate(ranking, 1):
        lines.append(
            f" {rank:>4} {model_names[k]:>{max_name_len}} {_fmt_bb(scores[k]):>10} {played[k]:>6}"
        )
    lines.append(f"{'='*70}\n")
    return "\n".join(lines)


def _append_to_log(text: str) -> None:
    """Append *text* to LOG_FILE; a failure is reported but never aborts the run."""
    try:
        with open(LOG_FILE, "a", encoding="utf-8") as f:
            f.write(text + "\n")
        print(f"[日志] 结果已追加到 {LOG_FILE}")
    except Exception as e:  # best effort: results were already printed
        print(f"[警告] 写入日志文件失败: {e}")


def main():
    """Scan valley checkpoints, play every pair once, then print and log a report.

    Exits with status 1 when fewer than two models are found, and with 130
    (after reporting the partial results) when interrupted with Ctrl+C.
    """
    args = _parse_args()
    ckpt_dir = args.checkpoints_dir
    num_games = args.num_games

    # 1. Discover contestants.
    models = find_valley_models(ckpt_dir)
    if len(models) < 2:
        print(f"[错误] 仅找到 {len(models)} 个波谷模型,至少需要 2 个才能开赛。")
        print(f" 扫描路径: {os.path.join(ckpt_dir, 'ckpt_valley_iter_*.pt')}")
        sys.exit(1)

    model_names = [extract_model_name(p) for p in models]
    n = len(models)
    total_matches = n * (n - 1) // 2

    print(f"\n{'='*70}")
    print(" Round Robin Tournament")
    print(f" 参赛模型: {n} 个")
    print(f" 每场局数: {num_games:,}")
    print(f" 总对战数: {total_matches}")
    print(f"{'='*70}")
    for i, (path, name) in enumerate(zip(models, model_names)):
        print(f" [{i+1}] {name} ({os.path.basename(path)})")
    print()

    # 2. match_results[i][j] = model i's bb/100 vs model j (None = no result).
    match_results: List[List[Optional[float]]] = [[None] * n for _ in range(n)]

    # 3. Play every unordered pair once.
    failed = 0
    interrupted = False
    try:
        for match_no, (i, j) in enumerate(itertools.combinations(range(n), 2), 1):
            print(f"[{match_no}/{total_matches}] {model_names[i]} vs {model_names[j]} ... ",
                  end="", flush=True)
            bb_per_100 = run_match(models[i], models[j], num_games)
            if bb_per_100 is None:
                failed += 1
                print("失败 (跳过)")
                continue
            print(f"{_fmt_bb(bb_per_100)} bb/100")
            # The matrix is antisymmetric: j's result against i is the negation.
            match_results[i][j] = bb_per_100
            match_results[j][i] = -bb_per_100
    except KeyboardInterrupt:
        # A full tournament can take hours; keep what has been played so far.
        interrupted = True
        print("\n[中断] 收到中断信号,输出已完成对战的部分结果。")

    # 4-6. Build the report, print it and append it to the log file.
    output_text = _format_report(model_names, match_results, num_games,
                                 failed=failed, interrupted=interrupted)
    print(output_text)
    _append_to_log(output_text)

    if interrupted:
        sys.exit(130)


if __name__ == "__main__":
    main()
|