# Source file: Mortal-Copied/mortal/one_vs_three.py
# Latest commit: b7a7d7404a "Mortal" by e2hang, 2025-10-07 20:30:03 +08:00
# CI status at that commit: some checks failed
#   deploy-docs / build (push): cancelled
#   build-libriichi / build (push): cancelled
# File size: 104 lines, 3.7 KiB (Python)

import prelude
import numpy as np
import torch
import secrets
import os
from model import Brain, DQN
from engine import MortalEngine
from libriichi.arena import OneVsThree
from config import config
def _load_engine(side_cfg):
    """Build a ``MortalEngine`` from one player section of ``config['1v3']``.

    ``side_cfg`` is either the ``champion`` or the ``challenger`` section.
    The keys read from it are ``state_file``, ``enable_compile``, ``device``,
    ``enable_amp``, ``enable_rule_based_agari_guard`` and ``name``.

    The checkpoint is loaded onto the CPU (``map_location``) with
    ``weights_only=True``. It must contain ``config``, ``mortal`` and
    ``current_dqn`` entries. The network architecture (``version``,
    ``conv_channels``, ``num_blocks``) is taken from the config stored inside
    the checkpoint, not from ``side_cfg``.

    Returns:
        A ``MortalEngine`` wrapping the loaded ``Brain`` and ``DQN`` in eval
        mode, running on ``side_cfg['device']``.
    """
    state = torch.load(side_cfg['state_file'], weights_only=True, map_location=torch.device('cpu'))
    model_cfg = state['config']
    # Checkpoints that do not record a version are treated as version 1.
    version = model_cfg['control'].get('version', 1)
    mortal = Brain(
        version=version,
        conv_channels=model_cfg['resnet']['conv_channels'],
        num_blocks=model_cfg['resnet']['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])
    if side_cfg['enable_compile']:
        mortal.compile()
        dqn.compile()
    return MortalEngine(
        mortal,
        dqn,
        is_oracle=False,
        version=version,
        device=torch.device(side_cfg['device']),
        enable_amp=side_cfg['enable_amp'],
        enable_rule_based_agari_guard=side_cfg['enable_rule_based_agari_guard'],
        name=side_cfg['name'],
    )


def main():
    """Run the 1-vs-3 evaluation described by ``config['1v3']``.

    The challenger checkpoint is always loaded. Its opponent is either the
    champion checkpoint or akochan, when ``akochan.enabled`` is set. Each of
    ``iters`` iterations runs ``games_per_iter // 4`` seeds through
    ``OneVsThree``. After each iteration it prints the challenger's ranking
    counts, its average rank, and its average points.

    Raises:
        ValueError: if ``games_per_iter`` is smaller than 4, because the
            per-iteration seed count (``games_per_iter // 4``) would be zero.
    """
    cfg = config['1v3']
    games_per_iter = cfg['games_per_iter']
    # Games are converted to seeds at 4 games per seed. A games_per_iter that
    # is not a multiple of 4 is silently rounded down.
    seeds_per_iter = games_per_iter // 4
    iters = cfg['iters']
    log_dir = cfg['log_dir']
    use_akochan = cfg['akochan']['enabled']

    # Fail before the (slow) model loading, not later inside range().
    if seeds_per_iter < 1:
        raise ValueError(
            f"1v3.games_per_iter must be at least 4, got {games_per_iter}"
        )

    # A missing seed_key, or -1, means a fresh random 64-bit key, so the
    # games differ from run to run unless seed_key is pinned in the config.
    if (key := cfg.get('seed_key', -1)) == -1:
        key = secrets.randbits(64)

    if use_akochan:
        # The akochan opponent is configured via environment variables.
        os.environ['AKOCHAN_DIR'] = cfg['akochan']['dir']
        os.environ['AKOCHAN_TACTICS'] = cfg['akochan']['tactics']
    else:
        engine_cham = _load_engine(cfg['champion'])
    engine_chal = _load_engine(cfg['challenger'])

    # Each iteration takes the next consecutive block of seeds_per_iter seeds.
    seed_start = 10000
    seed_stop = seed_start + seeds_per_iter * iters
    for i, seed in enumerate(range(seed_start, seed_stop, seeds_per_iter)):
        print('-' * 50)
        print('#', i)
        env = OneVsThree(
            disable_progress_bar=False,
            log_dir=log_dir,
        )
        if use_akochan:
            rankings = env.ako_vs_py(
                engine=engine_chal,
                seed_start=(seed, key),
                seed_count=seeds_per_iter,
            )
        else:
            rankings = env.py_vs_py(
                challenger=engine_chal,
                champion=engine_cham,
                seed_start=(seed, key),
                seed_count=seeds_per_iter,
            )
        # rankings must have length 4 (enforced by the 4-element weight
        # vectors below); presumably counts of 1st..4th place finishes.
        rankings = np.array(rankings)
        total = rankings.sum()
        avg_rank = rankings @ np.arange(1, 5) / total
        avg_pt = rankings @ np.array([90, 45, 0, -135]) / total
        print(f'challenger rankings: {rankings} ({avg_rank}, {avg_pt}pt)')
if __name__ == '__main__':
    # Ctrl-C stops the evaluation quietly instead of printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        pass