158 lines
5.3 KiB
Python
158 lines
5.3 KiB
Python
import torch
|
|
import numpy as np
|
|
import os
|
|
import shutil
|
|
import secrets
|
|
import logging
|
|
from os import path
|
|
from model import Brain, DQN
|
|
from engine import MortalEngine
|
|
from libriichi.stat import Stat
|
|
from libriichi.arena import OneVsThree
|
|
from config import config
|
|
|
|
class TestPlayer:
    """Run evaluation games between a challenger model and a fixed baseline.

    The baseline checkpoint path, device and compile flag are read from
    ``config['baseline']['test']``. Game logs are written to (and read back
    from) ``config['test_play']['log_dir']``.
    """

    def __init__(self):
        """Load the baseline checkpoint and wrap it in a ``MortalEngine``."""
        baseline_cfg = config['baseline']['test']
        device = torch.device(baseline_cfg['device'])

        # The checkpoint is deserialized onto the CPU; the target device is
        # handed to the engine separately below.
        state = torch.load(
            baseline_cfg['state_file'],
            weights_only=True,
            map_location=torch.device('cpu'),
        )
        # Network hyper-parameters come from the config stored inside the
        # checkpoint, not from the live `config`.
        saved_cfg = state['config']
        version = saved_cfg['control'].get('version', 1)
        conv_channels = saved_cfg['resnet']['conv_channels']
        num_blocks = saved_cfg['resnet']['num_blocks']

        stable_mortal = Brain(
            version=version,
            conv_channels=conv_channels,
            num_blocks=num_blocks,
        ).eval()
        stable_dqn = DQN(version=version).eval()
        stable_mortal.load_state_dict(state['mortal'])
        stable_dqn.load_state_dict(state['current_dqn'])
        if baseline_cfg['enable_compile']:
            stable_mortal.compile()
            stable_dqn.compile()

        self.baseline_engine = MortalEngine(
            stable_mortal,
            stable_dqn,
            is_oracle=False,
            version=version,
            device=device,
            enable_amp=True,
            enable_rule_based_agari_guard=True,
            name='baseline',
        )
        # The challenger uses the version from the live config, which may
        # differ from the version stored in the baseline checkpoint.
        self.chal_version = config['control']['version']
        self.log_dir = path.abspath(config['test_play']['log_dir'])

    def test_play(self, seed_count, mortal, dqn, device):
        """Play the challenger against the baseline and return the parsed stats.

        Args:
            seed_count: Forwarded to ``OneVsThree.py_vs_py`` as ``seed_count``.
            mortal: Challenger network passed as the first ``MortalEngine``
                argument.
            dqn: Challenger network passed as the second ``MortalEngine``
                argument.
            device: Device handed to the challenger engine.

        Returns:
            The result of ``Stat.from_dir(self.log_dir, 'mortal')``.

        Side effects:
            Deletes ``self.log_dir`` if it already exists before playing.
            ``torch.backends.cudnn.benchmark`` is forced to ``False`` for the
            duration of the call and then set to
            ``config['control']['enable_cudnn_benchmark']``, including when an
            exception is raised.
        """
        torch.backends.cudnn.benchmark = False
        try:
            engine_chal = MortalEngine(
                mortal,
                dqn,
                is_oracle=False,
                version=self.chal_version,
                device=device,
                enable_amp=True,
                name='mortal',
            )

            # Start from an empty log directory so stale logs are not picked
            # up by Stat.from_dir below.
            if path.isdir(self.log_dir):
                shutil.rmtree(self.log_dir)

            env = OneVsThree(
                disable_progress_bar=False,
                log_dir=self.log_dir,
            )
            # Fixed seed start: every evaluation uses the same starting seed.
            env.py_vs_py(
                challenger=engine_chal,
                champion=self.baseline_engine,
                seed_start=(10000, 0x2000),
                seed_count=seed_count,
            )

            # 'mortal' is the `name` given to the challenger engine above.
            return Stat.from_dir(self.log_dir, 'mortal')
        finally:
            # Restore the process-wide flag even if the games or the log
            # parsing failed; otherwise it would stay disabled for the caller.
            torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']
|
|
|
|
class TrainPlayer:
    """Generate training games between a trainee model and a fixed baseline.

    The baseline checkpoint path, device and compile flag are read from
    ``config['baseline']['train']``. Play settings come from
    ``config['train_play'][profile]``, where ``profile`` is taken from the
    ``TRAIN_PLAY_PROFILE`` environment variable (default ``'default'``).
    """

    def __init__(self):
        """Load the baseline engine and read the selected train-play profile."""
        baseline_cfg = config['baseline']['train']
        device = torch.device(baseline_cfg['device'])

        # The checkpoint is deserialized onto the CPU; the target device is
        # handed to the engine separately below.
        state = torch.load(
            baseline_cfg['state_file'],
            weights_only=True,
            map_location=torch.device('cpu'),
        )
        # Network hyper-parameters come from the config stored inside the
        # checkpoint, not from the live `config`.
        saved_cfg = state['config']
        version = saved_cfg['control'].get('version', 1)
        conv_channels = saved_cfg['resnet']['conv_channels']
        num_blocks = saved_cfg['resnet']['num_blocks']

        stable_mortal = Brain(
            version=version,
            conv_channels=conv_channels,
            num_blocks=num_blocks,
        ).eval()
        stable_dqn = DQN(version=version).eval()
        stable_mortal.load_state_dict(state['mortal'])
        stable_dqn.load_state_dict(state['current_dqn'])
        if baseline_cfg['enable_compile']:
            stable_mortal.compile()
            stable_dqn.compile()

        self.baseline_engine = MortalEngine(
            stable_mortal,
            stable_dqn,
            is_oracle=False,
            version=version,
            device=device,
            enable_amp=True,
            enable_rule_based_agari_guard=True,
            name='baseline',
        )

        profile = os.environ.get('TRAIN_PLAY_PROFILE', 'default')
        logging.info(f'using profile {profile}')
        play_cfg = config['train_play'][profile]

        # The trainee uses the version from the live config, which may differ
        # from the version stored in the baseline checkpoint.
        self.chal_version = config['control']['version']
        self.log_dir = path.abspath(play_cfg['log_dir'])
        # Random 64-bit key drawn once per instance; paired with `train_seed`
        # to form the `seed_start` passed to the arena.
        self.train_key = secrets.randbits(64)
        self.train_seed = 10000

        # NOTE(review): presumably each seed produces 4 games, hence the
        # division -- confirm against OneVsThree.
        self.seed_count = play_cfg['games'] // 4
        self.boltzmann_epsilon = play_cfg['boltzmann_epsilon']
        self.boltzmann_temp = play_cfg['boltzmann_temp']
        self.top_p = play_cfg['top_p']

        # The same seed range is reused for `repeats` consecutive calls to
        # train_play before `train_seed` advances.
        self.repeats = play_cfg['repeats']
        self.repeat_counter = 0

    def train_play(self, mortal, dqn, device):
        """Play one batch of games and return the rankings and log file paths.

        Args:
            mortal: Trainee network passed as the first ``MortalEngine``
                argument.
            dqn: Trainee network passed as the second ``MortalEngine``
                argument.
            device: Device handed to the trainee engine.

        Returns:
            A ``(rankings, file_list)`` tuple: ``rankings`` is the value
            returned by ``OneVsThree.py_vs_py`` converted with ``np.array``,
            and ``file_list`` holds the path of every entry found in
            ``self.log_dir`` after the games.

        Side effects:
            Deletes ``self.log_dir`` if it already exists before playing and
            advances the seed bookkeeping on success.
            ``torch.backends.cudnn.benchmark`` is forced to ``False`` for the
            duration of the call and then set to
            ``config['control']['enable_cudnn_benchmark']``, including when an
            exception is raised.
        """
        torch.backends.cudnn.benchmark = False
        try:
            engine_chal = MortalEngine(
                mortal,
                dqn,
                is_oracle=False,
                version=self.chal_version,
                boltzmann_epsilon=self.boltzmann_epsilon,
                boltzmann_temp=self.boltzmann_temp,
                top_p=self.top_p,
                device=device,
                enable_amp=True,
                name='trainee',
            )

            # Start from an empty log directory so the returned file list
            # contains only logs from this call.
            if path.isdir(self.log_dir):
                shutil.rmtree(self.log_dir)

            env = OneVsThree(
                disable_progress_bar=False,
                log_dir=self.log_dir,
            )
            rankings = env.py_vs_py(
                challenger=engine_chal,
                champion=self.baseline_engine,
                seed_start=(self.train_seed, self.train_key),
                seed_count=self.seed_count,
            )

            # Move on to a fresh seed range only after the current one has
            # been played `repeats` times.
            self.repeat_counter += 1
            if self.repeat_counter == self.repeats:
                self.train_seed += self.seed_count
                self.repeat_counter = 0

            rankings = np.array(rankings)
            file_list = [
                path.join(self.log_dir, entry)
                for entry in os.listdir(self.log_dir)
            ]
            return rankings, file_list
        finally:
            # Restore the process-wide flag even if the games failed;
            # otherwise it would stay disabled for the caller.
            torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']
|