Mortal
This commit is contained in:
103
mortal/one_vs_three.py
Normal file
103
mortal/one_vs_three.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import prelude
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import secrets
|
||||
import os
|
||||
from model import Brain, DQN
|
||||
from engine import MortalEngine
|
||||
from libriichi.arena import OneVsThree
|
||||
from config import config
|
||||
|
||||
def main():
|
||||
cfg = config['1v3']
|
||||
games_per_iter = cfg['games_per_iter']
|
||||
seeds_per_iter = games_per_iter // 4
|
||||
iters = cfg['iters']
|
||||
log_dir = cfg['log_dir']
|
||||
use_akochan = cfg['akochan']['enabled']
|
||||
|
||||
if (key := cfg.get('seed_key', -1)) == -1:
|
||||
key = secrets.randbits(64)
|
||||
|
||||
if use_akochan:
|
||||
os.environ['AKOCHAN_DIR'] = cfg['akochan']['dir']
|
||||
os.environ['AKOCHAN_TACTICS'] = cfg['akochan']['tactics']
|
||||
else:
|
||||
state = torch.load(cfg['champion']['state_file'], weights_only=True, map_location=torch.device('cpu'))
|
||||
cham_cfg = state['config']
|
||||
version = cham_cfg['control'].get('version', 1)
|
||||
conv_channels = cham_cfg['resnet']['conv_channels']
|
||||
num_blocks = cham_cfg['resnet']['num_blocks']
|
||||
mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
|
||||
dqn = DQN(version=version).eval()
|
||||
mortal.load_state_dict(state['mortal'])
|
||||
dqn.load_state_dict(state['current_dqn'])
|
||||
if cfg['champion']['enable_compile']:
|
||||
mortal.compile()
|
||||
dqn.compile()
|
||||
engine_cham = MortalEngine(
|
||||
mortal,
|
||||
dqn,
|
||||
is_oracle = False,
|
||||
version = version,
|
||||
device = torch.device(cfg['champion']['device']),
|
||||
enable_amp = cfg['champion']['enable_amp'],
|
||||
enable_rule_based_agari_guard = cfg['champion']['enable_rule_based_agari_guard'],
|
||||
name = cfg['champion']['name'],
|
||||
)
|
||||
|
||||
state = torch.load(cfg['challenger']['state_file'], weights_only=True, map_location=torch.device('cpu'))
|
||||
chal_cfg = state['config']
|
||||
version = chal_cfg['control'].get('version', 1)
|
||||
conv_channels = chal_cfg['resnet']['conv_channels']
|
||||
num_blocks = chal_cfg['resnet']['num_blocks']
|
||||
mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
|
||||
dqn = DQN(version=version).eval()
|
||||
mortal.load_state_dict(state['mortal'])
|
||||
dqn.load_state_dict(state['current_dqn'])
|
||||
if cfg['challenger']['enable_compile']:
|
||||
mortal.compile()
|
||||
dqn.compile()
|
||||
engine_chal = MortalEngine(
|
||||
mortal,
|
||||
dqn,
|
||||
is_oracle = False,
|
||||
version = version,
|
||||
device = torch.device(cfg['challenger']['device']),
|
||||
enable_amp = cfg['challenger']['enable_amp'],
|
||||
enable_rule_based_agari_guard = cfg['challenger']['enable_rule_based_agari_guard'],
|
||||
name = cfg['challenger']['name'],
|
||||
)
|
||||
|
||||
seed_start = 10000
|
||||
for i, seed in enumerate(range(seed_start, seed_start + seeds_per_iter * iters, seeds_per_iter)):
|
||||
print('-' * 50)
|
||||
print('#', i)
|
||||
env = OneVsThree(
|
||||
disable_progress_bar = False,
|
||||
log_dir = log_dir,
|
||||
)
|
||||
if use_akochan:
|
||||
rankings = env.ako_vs_py(
|
||||
engine = engine_chal,
|
||||
seed_start = (seed, key),
|
||||
seed_count = seeds_per_iter,
|
||||
)
|
||||
else:
|
||||
rankings = env.py_vs_py(
|
||||
challenger = engine_chal,
|
||||
champion = engine_cham,
|
||||
seed_start = (seed, key),
|
||||
seed_count = seeds_per_iter,
|
||||
)
|
||||
rankings = np.array(rankings)
|
||||
avg_rank = rankings @ np.arange(1, 5) / rankings.sum()
|
||||
avg_pt = rankings @ np.array([90, 45, 0, -135]) / rankings.sum()
|
||||
print(f'challenger rankings: {rankings} ({avg_rank}, {avg_pt}pt)')
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
Reference in New Issue
Block a user