104 lines
3.7 KiB
Python
104 lines
3.7 KiB
Python
import prelude
|
|
|
|
import numpy as np
|
|
import torch
|
|
import secrets
|
|
import os
|
|
from model import Brain, DQN
|
|
from engine import MortalEngine
|
|
from libriichi.arena import OneVsThree
|
|
from config import config
|
|
|
|
def _load_engine(player_cfg):
    """Build a ``MortalEngine`` from one player's section of the ``1v3`` config.

    Args:
        player_cfg: Mapping providing the keys ``state_file``,
            ``enable_compile``, ``device``, ``enable_amp``,
            ``enable_rule_based_agari_guard`` and ``name``.

    Returns:
        A ``MortalEngine`` wrapping the ``Brain`` and ``DQN`` restored from
        the checkpoint at ``player_cfg['state_file']``.
    """
    # The checkpoint is always deserialized onto the CPU; the configured
    # device is passed to the engine separately below.
    state = torch.load(
        player_cfg['state_file'],
        weights_only=True,
        map_location=torch.device('cpu'),
    )
    saved_cfg = state['config']
    # Checkpoints that do not record a version fall back to 1.
    version = saved_cfg['control'].get('version', 1)
    conv_channels = saved_cfg['resnet']['conv_channels']
    num_blocks = saved_cfg['resnet']['num_blocks']

    mortal = Brain(
        version=version,
        conv_channels=conv_channels,
        num_blocks=num_blocks,
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])
    if player_cfg['enable_compile']:
        mortal.compile()
        dqn.compile()

    return MortalEngine(
        mortal,
        dqn,
        is_oracle=False,
        version=version,
        device=torch.device(player_cfg['device']),
        enable_amp=player_cfg['enable_amp'],
        enable_rule_based_agari_guard=player_cfg['enable_rule_based_agari_guard'],
        name=player_cfg['name'],
    )


def main():
    """Run the evaluation configured under ``config['1v3']`` and print results.

    The challenger engine is always loaded from its checkpoint. Its opponent
    is either akochan (when ``cfg['akochan']['enabled']`` is true, configured
    through the ``AKOCHAN_DIR`` and ``AKOCHAN_TACTICS`` environment variables)
    or a champion engine loaded from a second checkpoint.

    For each of ``cfg['iters']`` iterations a fresh ``OneVsThree`` arena is
    created, one batch of seeds is played, and the challenger's ranking
    counts, average rank and average points are printed.

    Raises:
        ValueError: From ``range()`` when ``games_per_iter`` is 0-3, because
            the derived ``seeds_per_iter`` step is then zero.
    """
    cfg = config['1v3']
    games_per_iter = cfg['games_per_iter']
    # NOTE(review): presumably each seed accounts for four games — confirm
    # against the libriichi arena implementation.
    seeds_per_iter = games_per_iter // 4
    iters = cfg['iters']
    log_dir = cfg['log_dir']
    use_akochan = cfg['akochan']['enabled']

    # A missing seed_key, or an explicit -1, selects a fresh random 64-bit key.
    if (key := cfg.get('seed_key', -1)) == -1:
        key = secrets.randbits(64)

    if use_akochan:
        os.environ['AKOCHAN_DIR'] = cfg['akochan']['dir']
        os.environ['AKOCHAN_TACTICS'] = cfg['akochan']['tactics']
        # No champion checkpoint is loaded when playing against akochan.
        engine_cham = None
    else:
        engine_cham = _load_engine(cfg['champion'])

    engine_chal = _load_engine(cfg['challenger'])

    # Loop-invariant weights: the rank numbers 1..4 and the points applied to
    # each of the four ranking counts.
    rank_values = np.arange(1, 5)
    rank_points = np.array([90, 45, 0, -135])

    seed_start = 10000
    seed_stop = seed_start + seeds_per_iter * iters
    # Each iteration consumes its own consecutive window of seeds_per_iter
    # seeds, so iterations never reuse a seed.
    for i, seed in enumerate(range(seed_start, seed_stop, seeds_per_iter)):
        print('-' * 50)
        print('#', i)
        env = OneVsThree(
            disable_progress_bar=False,
            log_dir=log_dir,
        )
        if use_akochan:
            rankings = env.ako_vs_py(
                engine=engine_chal,
                seed_start=(seed, key),
                seed_count=seeds_per_iter,
            )
        else:
            rankings = env.py_vs_py(
                challenger=engine_chal,
                champion=engine_cham,
                seed_start=(seed, key),
                seed_count=seeds_per_iter,
            )
        rankings = np.array(rankings)
        # Weighted means over the four ranking counts.
        avg_rank = rankings @ rank_values / rankings.sum()
        avg_pt = rankings @ rank_points / rankings.sum()
        print(f'challenger rankings: {rankings} ({avg_rank}, {avg_pt}pt)')
|
|
|
|
if __name__ == '__main__':
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported.
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C ends the run quietly instead of printing a traceback.
        pass
|