import torch import numpy as np import os import shutil import secrets import logging from os import path from model import Brain, DQN from engine import MortalEngine from libriichi.stat import Stat from libriichi.arena import OneVsThree from config import config class TestPlayer: def __init__(self): baseline_cfg = config['baseline']['test'] device = torch.device(baseline_cfg['device']) state = torch.load(baseline_cfg['state_file'], weights_only=True, map_location=torch.device('cpu')) cfg = state['config'] version = cfg['control'].get('version', 1) conv_channels = cfg['resnet']['conv_channels'] num_blocks = cfg['resnet']['num_blocks'] stable_mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval() stable_dqn = DQN(version=version).eval() stable_mortal.load_state_dict(state['mortal']) stable_dqn.load_state_dict(state['current_dqn']) if baseline_cfg['enable_compile']: stable_mortal.compile() stable_dqn.compile() self.baseline_engine = MortalEngine( stable_mortal, stable_dqn, is_oracle = False, version = version, device = device, enable_amp = True, enable_rule_based_agari_guard = True, name = 'baseline', ) self.chal_version = config['control']['version'] self.log_dir = path.abspath(config['test_play']['log_dir']) def test_play(self, seed_count, mortal, dqn, device): torch.backends.cudnn.benchmark = False engine_chal = MortalEngine( mortal, dqn, is_oracle = False, version = self.chal_version, device = device, enable_amp = True, name = 'mortal', ) if path.isdir(self.log_dir): shutil.rmtree(self.log_dir) env = OneVsThree( disable_progress_bar = False, log_dir = self.log_dir, ) env.py_vs_py( challenger = engine_chal, champion = self.baseline_engine, seed_start = (10000, 0x2000), seed_count = seed_count, ) stat = Stat.from_dir(self.log_dir, 'mortal') torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark'] return stat class TrainPlayer: def __init__(self): baseline_cfg = config['baseline']['train'] device = torch.device(baseline_cfg['device']) state = torch.load(baseline_cfg['state_file'], weights_only=True, map_location=torch.device('cpu')) cfg = state['config'] version = cfg['control'].get('version', 1) conv_channels = cfg['resnet']['conv_channels'] num_blocks = cfg['resnet']['num_blocks'] stable_mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval() stable_dqn = DQN(version=version).eval() stable_mortal.load_state_dict(state['mortal']) stable_dqn.load_state_dict(state['current_dqn']) if baseline_cfg['enable_compile']: stable_mortal.compile() stable_dqn.compile() self.baseline_engine = MortalEngine( stable_mortal, stable_dqn, is_oracle = False, version = version, device = device, enable_amp = True, enable_rule_based_agari_guard = True, name = 'baseline', ) profile = os.environ.get('TRAIN_PLAY_PROFILE', 'default') logging.info(f'using profile {profile}') cfg = config['train_play'][profile] self.chal_version = config['control']['version'] self.log_dir = path.abspath(cfg['log_dir']) self.train_key = secrets.randbits(64) self.train_seed = 10000 self.seed_count = cfg['games'] // 4 self.boltzmann_epsilon = cfg['boltzmann_epsilon'] self.boltzmann_temp = cfg['boltzmann_temp'] self.top_p = cfg['top_p'] self.repeats = cfg['repeats'] self.repeat_counter = 0 def train_play(self, mortal, dqn, device): torch.backends.cudnn.benchmark = False engine_chal = MortalEngine( mortal, dqn, is_oracle = False, version = self.chal_version, boltzmann_epsilon = self.boltzmann_epsilon, boltzmann_temp = self.boltzmann_temp, top_p = self.top_p, device = device, enable_amp = True, name = 'trainee', ) if path.isdir(self.log_dir): shutil.rmtree(self.log_dir) env = OneVsThree( disable_progress_bar = False, log_dir = self.log_dir, ) rankings = env.py_vs_py( challenger = engine_chal, champion = self.baseline_engine, seed_start = (self.train_seed, self.train_key), seed_count = self.seed_count, ) self.repeat_counter += 1 if self.repeat_counter == self.repeats: self.train_seed += self.seed_count self.repeat_counter = 0 rankings = np.array(rankings) file_list = list(map(lambda p: path.join(self.log_dir, p), os.listdir(self.log_dir))) torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark'] return rankings, file_list