# Mortal-Copied/mortal/train.py
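# Training entry point for the Mortal mahjong AI. Judging from the code below,
# it supports two modes: offline training on logged games with a CQL
# regularizer, and online training where parameters are periodically pushed to
# self-play workers via submit_param() and new game files are pulled in with
# drain().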
def train():
import prelude
import logging
import sys
import os
import gc
import gzip
import json
import shutil
import random
import torch
from os import path
from glob import glob
from datetime import datetime
from itertools import chain
from torch import optim, nn
from torch.amp import GradScaler
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from common import submit_param, parameter_count, drain, filtered_trimmed_lines, tqdm
from player import TestPlayer
from dataloader import FileDatasetsIter, worker_init_fn
from lr_scheduler import LinearWarmUpCosineAnnealingLR
from model import Brain, DQN, AuxNet
from libriichi.consts import obs_shape
from config import config
version = config['control']['version']
online = config['control']['online']
batch_size = config['control']['batch_size']
opt_step_every = config['control']['opt_step_every']
save_every = config['control']['save_every']
test_every = config['control']['test_every']
submit_every = config['control']['submit_every']
test_games = config['test_play']['games']
min_q_weight = config['cql']['min_q_weight']
next_rank_weight = config['aux']['next_rank_weight']
assert save_every % opt_step_every == 0
assert test_every % save_every == 0
device = torch.device(config['control']['device'])
torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']
enable_amp = config['control']['enable_amp']
enable_compile = config['control']['enable_compile']
pts = config['env']['pts']
gamma = config['env']['gamma']
file_batch_size = config['dataset']['file_batch_size']
reserve_ratio = config['dataset']['reserve_ratio']
num_workers = config['dataset']['num_workers']
num_epochs = config['dataset']['num_epochs']
enable_augmentation = config['dataset']['enable_augmentation']
augmented_first = config['dataset']['augmented_first']
eps = config['optim']['eps']
betas = config['optim']['betas']
weight_decay = config['optim']['weight_decay']
max_grad_norm = config['optim']['max_grad_norm']
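# The three trainable components: the shared encoder (Brain), the Q-value head
# (DQN), and a 4-way auxiliary head used for the next-rank prediction loss.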
mortal = Brain(version=version, **config['resnet']).to(device)
dqn = DQN(version=version).to(device)
aux_net = AuxNet((4,)).to(device)
all_models = (mortal, dqn, aux_net)
if enable_compile:
for m in all_models:
m.compile()
logging.info(f'version: {version}')
logging.info(f'obs shape: {obs_shape(version)}')
logging.info(f'mortal params: {parameter_count(mortal):,}')
logging.info(f'dqn params: {parameter_count(dqn):,}')
logging.info(f'aux params: {parameter_count(aux_net):,}')
mortal.freeze_bn(config['freeze_bn']['mortal'])
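# Apply weight decay only to Linear and Conv1d weights; biases, norm
# parameters, and everything else go into the no-decay group.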
decay_params = []
no_decay_params = []
for model in all_models:
params_dict = {}
to_decay = set()
for mod_name, mod in model.named_modules():
for name, param in mod.named_parameters(prefix=mod_name, recurse=False):
params_dict[name] = param
if isinstance(mod, (nn.Linear, nn.Conv1d)) and name.endswith('weight'):
to_decay.add(name)
decay_params.extend(params_dict[name] for name in sorted(to_decay))
no_decay_params.extend(params_dict[name] for name in sorted(params_dict.keys() - to_decay))
param_groups = [
{'params': decay_params, 'weight_decay': weight_decay},
{'params': no_decay_params},
]
optimizer = optim.AdamW(param_groups, lr=1, weight_decay=0, betas=betas, eps=eps)
scheduler = LinearWarmUpCosineAnnealingLR(optimizer, **config['optim']['scheduler'])
scaler = GradScaler(device.type, enabled=enable_amp)
test_player = TestPlayer()
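# Best test-play performance observed so far, initialized to the worst possible
# values (average rank 4.0 and, presumably, the pt of a last-place finish).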
best_perf = {
'avg_rank': 4.,
'avg_pt': -135.,
}
steps = 0
state_file = config['control']['state_file']
best_state_file = config['control']['best_state_file']
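# Resume from an existing checkpoint. Optimizer, scheduler, and scaler state
# are skipped when moving from an offline checkpoint to online training.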
if path.exists(state_file):
state = torch.load(state_file, weights_only=True, map_location=device)
timestamp = datetime.fromtimestamp(state['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
logging.info(f'loaded: {timestamp}')
mortal.load_state_dict(state['mortal'])
dqn.load_state_dict(state['current_dqn'])
aux_net.load_state_dict(state['aux_net'])
if not online or state['config']['control']['online']:
optimizer.load_state_dict(state['optimizer'])
scheduler.load_state_dict(state['scheduler'])
scaler.load_state_dict(state['scaler'])
best_perf = state['best_perf']
steps = state['steps']
optimizer.zero_grad(set_to_none=True)
mse = nn.MSELoss()
ce = nn.CrossEntropyLoss()
if device.type == 'cuda':
logging.info(f'device: {device} ({torch.cuda.get_device_name(device)})')
else:
logging.info(f'device: {device}')
if online:
submit_param(mortal, dqn, is_idle=True)
logging.info('param has been submitted')
writer = SummaryWriter(config['control']['tensorboard_dir'])
stats = {
'dqn_loss': 0,
'cql_loss': 0,
'next_rank_loss': 0,
}
all_q = torch.zeros((save_every, batch_size), device=device, dtype=torch.float32)
all_q_target = torch.zeros((save_every, batch_size), device=device, dtype=torch.float32)
idx = 0
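# One pass over the data: build the file list (from the drain directory in
# online mode, or via globbing plus an on-disk index in offline mode), then
# stream batches through a DataLoader and train on each full-sized batch.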
def train_epoch():
nonlocal steps
nonlocal idx
player_names = []
if online:
player_names = ['trainee']
dirname = drain()
file_list = list(map(lambda p: path.join(dirname, p), os.listdir(dirname)))
else:
player_names_set = set()
for filename in config['dataset']['player_names_files']:
with open(filename) as f:
player_names_set.update(filtered_trimmed_lines(f))
player_names = list(player_names_set)
logging.info(f'loaded {len(player_names):,} players')
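# Cache the filtered file list on disk so the glob and the per-file player
# filtering only have to run once.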
file_index = config['dataset']['file_index']
if path.exists(file_index):
index = torch.load(file_index, weights_only=True)
file_list = index['file_list']
else:
logging.info('building file index...')
file_list = []
for pat in config['dataset']['globs']:
file_list.extend(glob(pat, recursive=True))
if len(player_names_set) > 0:
filtered = []
for filename in tqdm(file_list, unit='file'):
with gzip.open(filename, 'rt') as f:
start = json.loads(next(f))
if not set(start['names']).isdisjoint(player_names_set):
filtered.append(filename)
file_list = filtered
file_list.sort(reverse=True)
torch.save({'file_list': file_list}, file_index)
logging.info(f'file list size: {len(file_list):,}')
before_next_test_play = (test_every - steps % test_every) % test_every
logging.info(f'total steps: {steps:,} (~{before_next_test_play:,})')
if num_workers > 1:
random.shuffle(file_list)
file_data = FileDatasetsIter(
version = version,
file_list = file_list,
pts = pts,
file_batch_size = file_batch_size,
reserve_ratio = reserve_ratio,
player_names = player_names,
num_epochs = num_epochs,
enable_augmentation = enable_augmentation,
augmented_first = augmented_first,
)
data_loader = iter(DataLoader(
dataset = file_data,
batch_size = batch_size,
drop_last = False,
num_workers = num_workers,
pin_memory = True,
worker_init_fn = worker_init_fn,
))
remaining_obs = []
remaining_actions = []
remaining_masks = []
remaining_steps_to_done = []
remaining_kyoku_rewards = []
remaining_player_ranks = []
remaining_bs = 0
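# Batches smaller than batch_size are buffered here and re-chunked into
# full-sized batches after the main loop.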
pb = tqdm(total=save_every, desc='TRAIN', initial=steps % save_every)
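# One full batch: DQN loss against a Monte Carlo return target, plus the
# next-rank auxiliary loss and (offline only) the CQL penalty; gradients are
# accumulated and an optimizer step is taken every opt_step_every batches.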
def train_batch(obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks):
nonlocal steps
nonlocal idx
nonlocal pb
obs = obs.to(dtype=torch.float32, device=device)
actions = actions.to(dtype=torch.int64, device=device)
masks = masks.to(dtype=torch.bool, device=device)
steps_to_done = steps_to_done.to(dtype=torch.int64, device=device)
kyoku_rewards = kyoku_rewards.to(dtype=torch.float64, device=device)
player_ranks = player_ranks.to(dtype=torch.int64, device=device)
assert masks[range(batch_size), actions].all()
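# Monte Carlo target: the end-of-kyoku reward discounted by gamma raised to
# the number of steps remaining until the kyoku ends.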
q_target_mc = gamma ** steps_to_done * kyoku_rewards
q_target_mc = q_target_mc.to(torch.float32)
with torch.autocast(device.type, enabled=enable_amp):
phi = mortal(obs)
q_out = dqn(phi, masks)
q = q_out[range(batch_size), actions]
dqn_loss = 0.5 * mse(q, q_target_mc)
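# Conservative Q-learning penalty (logsumexp over legal actions minus the
# taken action's Q value); only applied in offline training.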
cql_loss = 0
if not online:
cql_loss = q_out.logsumexp(-1).mean() - q.mean()
next_rank_logits, = aux_net(phi)
next_rank_loss = ce(next_rank_logits, player_ranks)
loss = sum((
dqn_loss,
cql_loss * min_q_weight,
next_rank_loss * next_rank_weight,
))
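# Gradient accumulation: scale the loss down by opt_step_every so the
# accumulated gradient matches a single batch of opt_step_every * batch_size
# samples.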
scaler.scale(loss / opt_step_every).backward()
with torch.inference_mode():
stats['dqn_loss'] += dqn_loss
if not online:
stats['cql_loss'] += cql_loss
stats['next_rank_loss'] += next_rank_loss
all_q[idx] = q
all_q_target[idx] = q_target_mc
steps += 1
idx += 1
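# Take an optimizer step every opt_step_every accumulated batches, with
# optional gradient clipping done on unscaled gradients under AMP.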
if idx % opt_step_every == 0:
if max_grad_norm > 0:
scaler.unscale_(optimizer)
params = chain.from_iterable(g['params'] for g in optimizer.param_groups)
clip_grad_norm_(params, max_grad_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
scheduler.step()
pb.update(1)
if online and steps % submit_every == 0:
submit_param(mortal, dqn, is_idle=False)
logging.info('param has been submitted')
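# Every save_every steps: log averaged losses, the learning rate, and
# downsampled Q histograms to TensorBoard, then write a full checkpoint.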
if steps % save_every == 0:
pb.close()
# downsample to reduce tensorboard event size
all_q_1d = all_q.cpu().numpy().flatten()[::128]
all_q_target_1d = all_q_target.cpu().numpy().flatten()[::128]
writer.add_scalar('loss/dqn_loss', stats['dqn_loss'] / save_every, steps)
if not online:
writer.add_scalar('loss/cql_loss', stats['cql_loss'] / save_every, steps)
writer.add_scalar('loss/next_rank_loss', stats['next_rank_loss'] / save_every, steps)
writer.add_scalar('hparam/lr', scheduler.get_last_lr()[0], steps)
writer.add_histogram('q_predicted', all_q_1d, steps)
writer.add_histogram('q_target', all_q_target_1d, steps)
writer.flush()
for k in stats:
stats[k] = 0
idx = 0
before_next_test_play = (test_every - steps % test_every) % test_every
logging.info(f'total steps: {steps:,} (~{before_next_test_play:,})')
state = {
'mortal': mortal.state_dict(),
'current_dqn': dqn.state_dict(),
'aux_net': aux_net.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': scheduler.state_dict(),
'scaler': scaler.state_dict(),
'steps': steps,
'timestamp': datetime.now().timestamp(),
'best_perf': best_perf,
'config': config,
}
torch.save(state, state_file)
if online and steps % submit_every != 0:
submit_param(mortal, dqn, is_idle=False)
logging.info('param has been submitted')
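# Every test_every steps: run test games and log the results; the checkpoint is
# copied to best_state_file whenever both avg pt and avg rank match or improve
# on the best seen so far.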
if steps % test_every == 0:
stat = test_player.test_play(test_games // 4, mortal, dqn, device)
mortal.train()
dqn.train()
avg_pt = stat.avg_pt([90, 45, 0, -135]) # used for logging and best-model tracking only, never fed back into training
better = avg_pt >= best_perf['avg_pt'] and stat.avg_rank <= best_perf['avg_rank']
if better:
past_best = best_perf.copy()
best_perf['avg_pt'] = avg_pt
best_perf['avg_rank'] = stat.avg_rank
logging.info(f'avg rank: {stat.avg_rank:.6}')
logging.info(f'avg pt: {avg_pt:.6}')
writer.add_scalar('test_play/avg_ranking', stat.avg_rank, steps)
writer.add_scalar('test_play/avg_pt', avg_pt, steps)
writer.add_scalars('test_play/ranking', {
'1st': stat.rank_1_rate,
'2nd': stat.rank_2_rate,
'3rd': stat.rank_3_rate,
'4th': stat.rank_4_rate,
}, steps)
writer.add_scalars('test_play/behavior', {
'agari': stat.agari_rate,
'houjuu': stat.houjuu_rate,
'fuuro': stat.fuuro_rate,
'riichi': stat.riichi_rate,
}, steps)
writer.add_scalars('test_play/agari_point', {
'overall': stat.avg_point_per_agari,
'riichi': stat.avg_point_per_riichi_agari,
'fuuro': stat.avg_point_per_fuuro_agari,
'dama': stat.avg_point_per_dama_agari,
}, steps)
writer.add_scalar('test_play/houjuu_point', stat.avg_point_per_houjuu, steps)
writer.add_scalar('test_play/point_per_round', stat.avg_point_per_round, steps)
writer.add_scalars('test_play/key_step', {
'agari_jun': stat.avg_agari_jun,
'houjuu_jun': stat.avg_houjuu_jun,
'riichi_jun': stat.avg_riichi_jun,
}, steps)
writer.add_scalars('test_play/riichi', {
'agari_after_riichi': stat.agari_rate_after_riichi,
'houjuu_after_riichi': stat.houjuu_rate_after_riichi,
'chasing_riichi': stat.chasing_riichi_rate,
'riichi_chased': stat.riichi_chased_rate,
}, steps)
writer.add_scalar('test_play/riichi_point', stat.avg_riichi_point, steps)
writer.add_scalars('test_play/fuuro', {
'agari_after_fuuro': stat.agari_rate_after_fuuro,
'houjuu_after_fuuro': stat.houjuu_rate_after_fuuro,
}, steps)
writer.add_scalar('test_play/fuuro_num', stat.avg_fuuro_num, steps)
writer.add_scalar('test_play/fuuro_point', stat.avg_fuuro_point, steps)
writer.flush()
if better:
torch.save(state, state_file)
logging.info(
'a new record has been made, '
f'pt: {past_best["avg_pt"]:.4} -> {best_perf["avg_pt"]:.4}, '
f'rank: {past_best["avg_rank"]:.4} -> {best_perf["avg_rank"]:.4}, '
f'saving to {best_state_file}'
)
shutil.copy(state_file, best_state_file)
if online:
# BUG: For an unknown reason, the process gets stuck here when training in
# online mode, so the child process exits instead. This is why `main` spawns
# a sub process for online training rather than calling train() directly.
sys.exit(0)
pb = tqdm(total=save_every, desc='TRAIN')
for obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks in data_loader:
bs = obs.shape[0]
if bs != batch_size:
remaining_obs.append(obs)
remaining_actions.append(actions)
remaining_masks.append(masks)
remaining_steps_to_done.append(steps_to_done)
remaining_kyoku_rewards.append(kyoku_rewards)
remaining_player_ranks.append(player_ranks)
remaining_bs += bs
continue
train_batch(obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks)
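# The DataLoader has been exhausted; concatenate the buffered undersized
# batches and train on them in full-sized chunks, dropping any tail shorter
# than batch_size.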
remaining_batches = remaining_bs // batch_size
if remaining_batches > 0:
obs = torch.cat(remaining_obs, dim=0)
actions = torch.cat(remaining_actions, dim=0)
masks = torch.cat(remaining_masks, dim=0)
steps_to_done = torch.cat(remaining_steps_to_done, dim=0)
kyoku_rewards = torch.cat(remaining_kyoku_rewards, dim=0)
player_ranks = torch.cat(remaining_player_ranks, dim=0)
start = 0
end = batch_size
while end <= remaining_bs:
train_batch(
obs[start:end],
actions[start:end],
masks[start:end],
steps_to_done[start:end],
kyoku_rewards[start:end],
player_ranks[start:end],
)
start = end
end += batch_size
pb.close()
if online:
submit_param(mortal, dqn, is_idle=True)
logging.info('param has been submitted')
while True:
train_epoch()
gc.collect()
# torch.cuda.empty_cache()
# torch.cuda.synchronize()
if not online:
# in offline mode, run only a single epoch for easier control
break
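# In online mode, main() runs train() in a child process and restarts it each
# time it exits cleanly, working around the hang described above; offline
# training just calls train() directly.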
def main():
import os
import sys
import time
from subprocess import Popen
from config import config
# do not set this environment variable manually
is_sub_proc_key = 'MORTAL_IS_SUB_PROC'
online = config['control']['online']
if not online or os.environ.get(is_sub_proc_key, '0') == '1':
train()
return
cmd = (sys.executable, __file__)
env = {
is_sub_proc_key: '1',
**os.environ.copy(),
}
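# Respawn the trainer whenever the child exits cleanly; a non-zero exit code
# aborts the parent too.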
while True:
child = Popen(
cmd,
stdin = sys.stdin,
stdout = sys.stdout,
stderr = sys.stderr,
env = env,
)
if (code := child.wait()) != 0:
sys.exit(code)
time.sleep(3)
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
pass