Files
Mortal-Copied/mortal/player.py
e2hang b7a7d7404a
Some checks failed
deploy-docs / build (push) Has been cancelled
build-libriichi / build (push) Has been cancelled
Mortal
2025-10-07 20:30:03 +08:00

158 lines
5.3 KiB
Python

import torch
import numpy as np
import os
import shutil
import secrets
import logging
from os import path
from model import Brain, DQN
from engine import MortalEngine
from libriichi.stat import Stat
from libriichi.arena import OneVsThree
from config import config
class TestPlayer:
    """Evaluate a challenger model against a baseline loaded from a checkpoint.

    The baseline checkpoint path, device and compile flag are read from
    ``config['baseline']['test']``; game logs are written to (and read back
    from) ``config['test_play']['log_dir']``.
    """

    def __init__(self):
        """Load the baseline checkpoint and wrap it in a ``MortalEngine``."""
        baseline_cfg = config['baseline']['test']
        device = torch.device(baseline_cfg['device'])
        # The checkpoint is mapped to CPU on load; the engine is handed the
        # configured device separately below.
        state = torch.load(
            baseline_cfg['state_file'],
            weights_only = True,
            map_location = torch.device('cpu'),
        )

        # The baseline's architecture comes from the config stored inside the
        # checkpoint, not from the live config, so the baseline may differ
        # from the challenger. A checkpoint without a version is treated as 1.
        state_cfg = state['config']
        version = state_cfg['control'].get('version', 1)
        conv_channels = state_cfg['resnet']['conv_channels']
        num_blocks = state_cfg['resnet']['num_blocks']

        stable_mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
        stable_dqn = DQN(version=version).eval()
        stable_mortal.load_state_dict(state['mortal'])
        stable_dqn.load_state_dict(state['current_dqn'])
        if baseline_cfg['enable_compile']:
            stable_mortal.compile()
            stable_dqn.compile()

        self.baseline_engine = MortalEngine(
            stable_mortal,
            stable_dqn,
            is_oracle = False,
            version = version,
            device = device,
            enable_amp = True,
            enable_rule_based_agari_guard = True,
            name = 'baseline',
        )
        # The challenger is built with the version from the live config.
        self.chal_version = config['control']['version']
        self.log_dir = path.abspath(config['test_play']['log_dir'])

    def test_play(self, seed_count, mortal, dqn, device):
        """Play the challenger against the baseline and return the statistics.

        Args:
            seed_count: Forwarded to ``OneVsThree.py_vs_py`` as ``seed_count``.
            mortal: First model handed to the challenger ``MortalEngine``
                (presumably a ``Brain`` -- confirm with the caller).
            dqn: Second model handed to the challenger ``MortalEngine``
                (presumably a ``DQN`` -- confirm with the caller).
            device: Device handed to the challenger ``MortalEngine``.

        Returns:
            The result of ``Stat.from_dir`` over the log directory for the
            player named ``'mortal'``.
        """
        # cudnn autotuning is switched off for the duration of the games.
        # NOTE(review): presumably because inference batch shapes vary -- confirm.
        torch.backends.cudnn.benchmark = False
        try:
            engine_chal = MortalEngine(
                mortal,
                dqn,
                is_oracle = False,
                version = self.chal_version,
                device = device,
                enable_amp = True,
                name = 'mortal',
            )

            # Start from an empty log directory so the statistics read back
            # below are not mixed with logs left over from a previous call.
            if path.isdir(self.log_dir):
                shutil.rmtree(self.log_dir)

            env = OneVsThree(
                disable_progress_bar = False,
                log_dir = self.log_dir,
            )
            # seed_start is a constant, so every call starts from the same
            # seed pair.
            env.py_vs_py(
                challenger = engine_chal,
                champion = self.baseline_engine,
                seed_start = (10000, 0x2000),
                seed_count = seed_count,
            )

            return Stat.from_dir(self.log_dir, 'mortal')
        finally:
            # Restore the configured setting even when a game or the log
            # parsing raises; otherwise the process-wide flag would stay
            # disabled for whatever runs afterwards.
            torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']
class TrainPlayer:
    """Generate self-play games of a trainee model against a baseline.

    The baseline checkpoint path, device and compile flag are read from
    ``config['baseline']['train']``. Play settings come from
    ``config['train_play'][profile]``, where ``profile`` is the value of the
    ``TRAIN_PLAY_PROFILE`` environment variable (``'default'`` when unset).
    """

    def __init__(self):
        """Load the baseline checkpoint and read the train-play profile."""
        baseline_cfg = config['baseline']['train']
        device = torch.device(baseline_cfg['device'])
        # The checkpoint is mapped to CPU on load; the engine is handed the
        # configured device separately below.
        state = torch.load(
            baseline_cfg['state_file'],
            weights_only = True,
            map_location = torch.device('cpu'),
        )

        # The baseline's architecture comes from the config stored inside the
        # checkpoint, not from the live config. A checkpoint without a version
        # is treated as 1.
        state_cfg = state['config']
        version = state_cfg['control'].get('version', 1)
        conv_channels = state_cfg['resnet']['conv_channels']
        num_blocks = state_cfg['resnet']['num_blocks']

        stable_mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
        stable_dqn = DQN(version=version).eval()
        stable_mortal.load_state_dict(state['mortal'])
        stable_dqn.load_state_dict(state['current_dqn'])
        if baseline_cfg['enable_compile']:
            stable_mortal.compile()
            stable_dqn.compile()

        self.baseline_engine = MortalEngine(
            stable_mortal,
            stable_dqn,
            is_oracle = False,
            version = version,
            device = device,
            enable_amp = True,
            enable_rule_based_agari_guard = True,
            name = 'baseline',
        )

        profile = os.environ.get('TRAIN_PLAY_PROFILE', 'default')
        logging.info(f'using profile {profile}')
        # Kept under its own name so it cannot be confused with the
        # checkpoint's embedded config read above.
        play_cfg = config['train_play'][profile]

        # The trainee is built with the version from the live config.
        self.chal_version = config['control']['version']
        self.log_dir = path.abspath(play_cfg['log_dir'])

        # seed_start for py_vs_py is (train_seed, train_key); the key is drawn
        # fresh at random for every TrainPlayer instance.
        self.train_key = secrets.randbits(64)
        self.train_seed = 10000
        # NOTE(review): presumably one seed yields 4 games -- confirm against
        # OneVsThree.
        self.seed_count = play_cfg['games'] // 4

        self.boltzmann_epsilon = play_cfg['boltzmann_epsilon']
        self.boltzmann_temp = play_cfg['boltzmann_temp']
        self.top_p = play_cfg['top_p']

        # The same seed range is replayed `repeats` times before advancing.
        self.repeats = play_cfg['repeats']
        self.repeat_counter = 0

    def train_play(self, mortal, dqn, device):
        """Play one batch of games and return the rankings and log files.

        Args:
            mortal: First model handed to the trainee ``MortalEngine``
                (presumably a ``Brain`` -- confirm with the caller).
            dqn: Second model handed to the trainee ``MortalEngine``
                (presumably a ``DQN`` -- confirm with the caller).
            device: Device handed to the trainee ``MortalEngine``.

        Returns:
            A ``(rankings, file_list)`` tuple: ``rankings`` is the return value
            of ``OneVsThree.py_vs_py`` converted to a NumPy array, and
            ``file_list`` holds the full path of every entry in the log
            directory.
        """
        # cudnn autotuning is switched off for the duration of the games.
        # NOTE(review): presumably because inference batch shapes vary -- confirm.
        torch.backends.cudnn.benchmark = False
        try:
            engine_chal = MortalEngine(
                mortal,
                dqn,
                is_oracle = False,
                version = self.chal_version,
                boltzmann_epsilon = self.boltzmann_epsilon,
                boltzmann_temp = self.boltzmann_temp,
                top_p = self.top_p,
                device = device,
                enable_amp = True,
                name = 'trainee',
            )

            # Start from an empty log directory so the returned file list
            # only contains logs from this call.
            if path.isdir(self.log_dir):
                shutil.rmtree(self.log_dir)

            env = OneVsThree(
                disable_progress_bar = False,
                log_dir = self.log_dir,
            )
            rankings = env.py_vs_py(
                challenger = engine_chal,
                champion = self.baseline_engine,
                seed_start = (self.train_seed, self.train_key),
                seed_count = self.seed_count,
            )

            # Move on to the next seed range only after the current one has
            # been played `repeats` times.
            self.repeat_counter += 1
            if self.repeat_counter == self.repeats:
                self.train_seed += self.seed_count
                self.repeat_counter = 0

            rankings = np.array(rankings)
            file_list = [
                path.join(self.log_dir, name)
                for name in os.listdir(self.log_dir)
            ]
            return rankings, file_list
        finally:
            # Restore the configured setting even when a game or the directory
            # listing raises; otherwise the process-wide flag would stay
            # disabled for the training code that runs afterwards.
            torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']