Mortal
Some checks failed
deploy-docs / build (push) Has been cancelled
build-libriichi / build (push) Has been cancelled

This commit is contained in:
e2hang
2025-10-07 20:30:03 +08:00
commit b7a7d7404a
441 changed files with 23367 additions and 0 deletions

157
mortal/player.py Normal file
View File

@@ -0,0 +1,157 @@
import torch
import numpy as np
import os
import shutil
import secrets
import logging
from os import path
from model import Brain, DQN
from engine import MortalEngine
from libriichi.stat import Stat
from libriichi.arena import OneVsThree
from config import config
class TestPlayer:
    """Evaluate a challenger model against a fixed baseline in `OneVsThree` games.

    The baseline networks are loaded once, in `__init__`, from the checkpoint
    named by ``config['baseline']['test']`` and are reused by every call to
    `test_play`.
    """

    def __init__(self):
        """Load the baseline checkpoint and wrap it in a `MortalEngine`.

        Reads ``device``, ``state_file`` and ``enable_compile`` from
        ``config['baseline']['test']``, the challenger version from
        ``config['control']['version']`` and the log directory from
        ``config['test_play']['log_dir']``.
        """
        baseline_cfg = config['baseline']['test']
        device = torch.device(baseline_cfg['device'])

        # The checkpoint is always deserialized onto the CPU; the configured
        # device is handed to MortalEngine separately below.
        state = torch.load(
            baseline_cfg['state_file'],
            weights_only=True,
            map_location=torch.device('cpu'),
        )

        # The network shape comes from the config stored inside the checkpoint,
        # not from the running config. Checkpoints without a version entry are
        # treated as version 1.
        ckpt_cfg = state['config']
        version = ckpt_cfg['control'].get('version', 1)
        conv_channels = ckpt_cfg['resnet']['conv_channels']
        num_blocks = ckpt_cfg['resnet']['num_blocks']

        stable_mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
        stable_dqn = DQN(version=version).eval()
        stable_mortal.load_state_dict(state['mortal'])
        stable_dqn.load_state_dict(state['current_dqn'])
        if baseline_cfg['enable_compile']:
            stable_mortal.compile()
            stable_dqn.compile()

        self.baseline_engine = MortalEngine(
            stable_mortal,
            stable_dqn,
            is_oracle = False,
            version = version,
            device = device,
            enable_amp = True,
            enable_rule_based_agari_guard = True,
            name = 'baseline',
        )
        self.chal_version = config['control']['version']
        self.log_dir = path.abspath(config['test_play']['log_dir'])

    def test_play(self, seed_count, mortal, dqn, device):
        """Run evaluation games and return the challenger's statistics.

        Any existing ``self.log_dir`` is deleted before the games start, so the
        returned statistics only cover this call.

        Args:
            seed_count: Forwarded to ``OneVsThree.py_vs_py`` as ``seed_count``.
            mortal: Challenger network, passed as the first positional
                argument of `MortalEngine`.
            dqn: Challenger network, passed as the second positional argument
                of `MortalEngine`.
            device: Device handed to the challenger's `MortalEngine`.

        Returns:
            The result of ``Stat.from_dir(self.log_dir, 'mortal')``.
        """
        # cuDNN benchmarking is switched off for the duration of the games.
        # The flag is process-wide, so it is restored in `finally` to make sure
        # a failure in the arena does not leave it disabled for the caller.
        torch.backends.cudnn.benchmark = False
        try:
            engine_chal = MortalEngine(
                mortal,
                dqn,
                is_oracle = False,
                version = self.chal_version,
                device = device,
                enable_amp = True,
                name = 'mortal',
            )

            if path.isdir(self.log_dir):
                shutil.rmtree(self.log_dir)

            env = OneVsThree(
                disable_progress_bar = False,
                log_dir = self.log_dir,
            )
            # The seed start is a fixed constant, unlike TrainPlayer-style
            # random keys, so every evaluation uses the same seeds.
            env.py_vs_py(
                challenger = engine_chal,
                champion = self.baseline_engine,
                seed_start = (10000, 0x2000),
                seed_count = seed_count,
            )

            # 'mortal' matches the `name` given to the challenger engine above.
            return Stat.from_dir(self.log_dir, 'mortal')
        finally:
            torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']
class TrainPlayer:
    """Produce training games by playing a challenger against a fixed baseline.

    The baseline networks are loaded once, in `__init__`, from the checkpoint
    named by ``config['baseline']['train']``. Play settings come from
    ``config['train_play'][profile]``, where ``profile`` is taken from the
    ``TRAIN_PLAY_PROFILE`` environment variable (``'default'`` when unset).
    """

    def __init__(self):
        """Load the baseline checkpoint and read the selected play profile."""
        baseline_cfg = config['baseline']['train']
        device = torch.device(baseline_cfg['device'])

        # The checkpoint is always deserialized onto the CPU; the configured
        # device is handed to MortalEngine separately below.
        state = torch.load(
            baseline_cfg['state_file'],
            weights_only=True,
            map_location=torch.device('cpu'),
        )

        # The network shape comes from the config stored inside the checkpoint,
        # not from the running config. Checkpoints without a version entry are
        # treated as version 1.
        ckpt_cfg = state['config']
        version = ckpt_cfg['control'].get('version', 1)
        conv_channels = ckpt_cfg['resnet']['conv_channels']
        num_blocks = ckpt_cfg['resnet']['num_blocks']

        stable_mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
        stable_dqn = DQN(version=version).eval()
        stable_mortal.load_state_dict(state['mortal'])
        stable_dqn.load_state_dict(state['current_dqn'])
        if baseline_cfg['enable_compile']:
            stable_mortal.compile()
            stable_dqn.compile()

        self.baseline_engine = MortalEngine(
            stable_mortal,
            stable_dqn,
            is_oracle = False,
            version = version,
            device = device,
            enable_amp = True,
            enable_rule_based_agari_guard = True,
            name = 'baseline',
        )

        profile = os.environ.get('TRAIN_PLAY_PROFILE', 'default')
        logging.info(f'using profile {profile}')
        play_cfg = config['train_play'][profile]

        self.chal_version = config['control']['version']
        self.log_dir = path.abspath(play_cfg['log_dir'])

        # The seed is the pair (train_seed, train_key). The key is a fresh
        # random 64-bit value per TrainPlayer instance; the seed starts at
        # 10000 and is advanced by `_advance_seed_window`.
        self.train_key = secrets.randbits(64)
        self.train_seed = 10000
        # NOTE(review): presumably each seed yields four games, hence `// 4`;
        # confirm against OneVsThree.
        self.seed_count = play_cfg['games'] // 4

        self.boltzmann_epsilon = play_cfg['boltzmann_epsilon']
        self.boltzmann_temp = play_cfg['boltzmann_temp']
        self.top_p = play_cfg['top_p']

        # Number of consecutive `train_play` calls that reuse the same seeds.
        self.repeats = play_cfg['repeats']
        self.repeat_counter = 0

    def train_play(self, mortal, dqn, device):
        """Play one round of training games and return rankings and log paths.

        Any existing ``self.log_dir`` is deleted before the games start, so the
        returned file list only contains entries written by this call.

        Args:
            mortal: Challenger network, passed as the first positional
                argument of `MortalEngine`.
            dqn: Challenger network, passed as the second positional argument
                of `MortalEngine`.
            device: Device handed to the challenger's `MortalEngine`.

        Returns:
            A tuple ``(rankings, file_list)``: the value returned by
            ``OneVsThree.py_vs_py`` converted with ``np.array``, and the list
            of entries of ``self.log_dir`` joined onto that directory.
        """
        # cuDNN benchmarking is switched off for the duration of the games.
        # The flag is process-wide, so it is restored in `finally` to make sure
        # a failure in the arena does not leave it disabled for the caller.
        torch.backends.cudnn.benchmark = False
        try:
            engine_chal = MortalEngine(
                mortal,
                dqn,
                is_oracle = False,
                version = self.chal_version,
                boltzmann_epsilon = self.boltzmann_epsilon,
                boltzmann_temp = self.boltzmann_temp,
                top_p = self.top_p,
                device = device,
                enable_amp = True,
                name = 'trainee',
            )

            if path.isdir(self.log_dir):
                shutil.rmtree(self.log_dir)

            env = OneVsThree(
                disable_progress_bar = False,
                log_dir = self.log_dir,
            )
            rankings = env.py_vs_py(
                challenger = engine_chal,
                champion = self.baseline_engine,
                seed_start = (self.train_seed, self.train_key),
                seed_count = self.seed_count,
            )

            # Only a completed round counts towards the repeat schedule.
            self._advance_seed_window()

            rankings = np.array(rankings)
            file_list = [path.join(self.log_dir, name) for name in os.listdir(self.log_dir)]
            return rankings, file_list
        finally:
            torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']

    def _advance_seed_window(self):
        """Record one finished round; move to new seeds after `repeats` rounds."""
        self.repeat_counter += 1
        if self.repeat_counter == self.repeats:
            # Skip past the seeds just used so the next round starts at the
            # following, non-overlapping seed range.
            self.train_seed += self.seed_count
            self.repeat_counter = 0