Mortal
Some checks failed
deploy-docs / build (push) Has been cancelled
build-libriichi / build (push) Has been cancelled

This commit is contained in:
e2hang
2025-10-07 20:30:03 +08:00
commit b7a7d7404a
441 changed files with 23367 additions and 0 deletions

103
mortal/one_vs_three.py Normal file
View File

@@ -0,0 +1,103 @@
import prelude
import numpy as np
import torch
import secrets
import os
from model import Brain, DQN
from engine import MortalEngine
from libriichi.arena import OneVsThree
from config import config
def main():
cfg = config['1v3']
games_per_iter = cfg['games_per_iter']
seeds_per_iter = games_per_iter // 4
iters = cfg['iters']
log_dir = cfg['log_dir']
use_akochan = cfg['akochan']['enabled']
if (key := cfg.get('seed_key', -1)) == -1:
key = secrets.randbits(64)
if use_akochan:
os.environ['AKOCHAN_DIR'] = cfg['akochan']['dir']
os.environ['AKOCHAN_TACTICS'] = cfg['akochan']['tactics']
else:
state = torch.load(cfg['champion']['state_file'], weights_only=True, map_location=torch.device('cpu'))
cham_cfg = state['config']
version = cham_cfg['control'].get('version', 1)
conv_channels = cham_cfg['resnet']['conv_channels']
num_blocks = cham_cfg['resnet']['num_blocks']
mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
dqn = DQN(version=version).eval()
mortal.load_state_dict(state['mortal'])
dqn.load_state_dict(state['current_dqn'])
if cfg['champion']['enable_compile']:
mortal.compile()
dqn.compile()
engine_cham = MortalEngine(
mortal,
dqn,
is_oracle = False,
version = version,
device = torch.device(cfg['champion']['device']),
enable_amp = cfg['champion']['enable_amp'],
enable_rule_based_agari_guard = cfg['champion']['enable_rule_based_agari_guard'],
name = cfg['champion']['name'],
)
state = torch.load(cfg['challenger']['state_file'], weights_only=True, map_location=torch.device('cpu'))
chal_cfg = state['config']
version = chal_cfg['control'].get('version', 1)
conv_channels = chal_cfg['resnet']['conv_channels']
num_blocks = chal_cfg['resnet']['num_blocks']
mortal = Brain(version=version, conv_channels=conv_channels, num_blocks=num_blocks).eval()
dqn = DQN(version=version).eval()
mortal.load_state_dict(state['mortal'])
dqn.load_state_dict(state['current_dqn'])
if cfg['challenger']['enable_compile']:
mortal.compile()
dqn.compile()
engine_chal = MortalEngine(
mortal,
dqn,
is_oracle = False,
version = version,
device = torch.device(cfg['challenger']['device']),
enable_amp = cfg['challenger']['enable_amp'],
enable_rule_based_agari_guard = cfg['challenger']['enable_rule_based_agari_guard'],
name = cfg['challenger']['name'],
)
seed_start = 10000
for i, seed in enumerate(range(seed_start, seed_start + seeds_per_iter * iters, seeds_per_iter)):
print('-' * 50)
print('#', i)
env = OneVsThree(
disable_progress_bar = False,
log_dir = log_dir,
)
if use_akochan:
rankings = env.ako_vs_py(
engine = engine_chal,
seed_start = (seed, key),
seed_count = seeds_per_iter,
)
else:
rankings = env.py_vs_py(
challenger = engine_chal,
champion = engine_cham,
seed_start = (seed, key),
seed_count = seeds_per_iter,
)
rankings = np.array(rankings)
avg_rank = rankings @ np.arange(1, 5) / rankings.sum()
avg_pt = rankings @ np.array([90, 45, 0, -135]) / rankings.sum()
print(f'challenger rankings: {rankings} ({avg_rank}, {avg_pt}pt)')
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
pass