def train():
    import prelude

    import logging
    import sys
    import os
    import gc
    import gzip
    import json
    import shutil
    import random
    import torch
    from os import path
    from glob import glob
    from datetime import datetime
    from itertools import chain
    from torch import optim, nn
    from torch.amp import GradScaler
    from torch.nn.utils import clip_grad_norm_
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    from common import submit_param, parameter_count, drain, filtered_trimmed_lines, tqdm
    from player import TestPlayer
    from dataloader import FileDatasetsIter, worker_init_fn
    from lr_scheduler import LinearWarmUpCosineAnnealingLR
    from model import Brain, DQN, AuxNet
    from libriichi.consts import obs_shape
    from config import config

    version = config['control']['version']
    online = config['control']['online']
    batch_size = config['control']['batch_size']
    opt_step_every = config['control']['opt_step_every']
    save_every = config['control']['save_every']
    test_every = config['control']['test_every']
    submit_every = config['control']['submit_every']
    test_games = config['test_play']['games']
    min_q_weight = config['cql']['min_q_weight']
    next_rank_weight = config['aux']['next_rank_weight']
    assert save_every % opt_step_every == 0
    assert test_every % save_every == 0

    device = torch.device(config['control']['device'])
    torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']
    enable_amp = config['control']['enable_amp']
    enable_compile = config['control']['enable_compile']

    pts = config['env']['pts']
    gamma = config['env']['gamma']

    file_batch_size = config['dataset']['file_batch_size']
    reserve_ratio = config['dataset']['reserve_ratio']
    num_workers = config['dataset']['num_workers']
    num_epochs = config['dataset']['num_epochs']
    enable_augmentation = config['dataset']['enable_augmentation']
    augmented_first = config['dataset']['augmented_first']

    eps = config['optim']['eps']
    betas = config['optim']['betas']
    weight_decay = config['optim']['weight_decay']
    max_grad_norm = config['optim']['max_grad_norm']

    mortal = Brain(version=version, **config['resnet']).to(device)
    dqn = DQN(version=version).to(device)
    aux_net = AuxNet((4,)).to(device)
    all_models = (mortal, dqn, aux_net)
    if enable_compile:
        for m in all_models:
            m.compile()
    logging.info(f'version: {version}')
    logging.info(f'obs shape: {obs_shape(version)}')
    logging.info(f'mortal params: {parameter_count(mortal):,}')
    logging.info(f'dqn params: {parameter_count(dqn):,}')
    logging.info(f'aux params: {parameter_count(aux_net):,}')
    mortal.freeze_bn(config['freeze_bn']['mortal'])

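    # Note on the optimizer setup below: only the weights of Linear and Conv1d
    # modules are put into the weight-decay group; biases and everything else
    # (norm parameters, etc.) are left undecayed. AdamW is created with lr=1,
    # presumably so that LinearWarmUpCosineAnnealingLR's per-step value acts
    # directly as the effective learning rate.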
    decay_params = []
    no_decay_params = []
    for model in all_models:
        params_dict = {}
        to_decay = set()
        for mod_name, mod in model.named_modules():
            for name, param in mod.named_parameters(prefix=mod_name, recurse=False):
                params_dict[name] = param
                if isinstance(mod, (nn.Linear, nn.Conv1d)) and name.endswith('weight'):
                    to_decay.add(name)
        decay_params.extend(params_dict[name] for name in sorted(to_decay))
        no_decay_params.extend(params_dict[name] for name in sorted(params_dict.keys() - to_decay))
    param_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params},
    ]
    optimizer = optim.AdamW(param_groups, lr=1, weight_decay=0, betas=betas, eps=eps)
    scheduler = LinearWarmUpCosineAnnealingLR(optimizer, **config['optim']['scheduler'])
    scaler = GradScaler(device.type, enabled=enable_amp)

    test_player = TestPlayer()
    best_perf = {
        'avg_rank': 4.,
        'avg_pt': -135.,
    }

    steps = 0
    state_file = config['control']['state_file']
    best_state_file = config['control']['best_state_file']
    if path.exists(state_file):
        state = torch.load(state_file, weights_only=True, map_location=device)
        timestamp = datetime.fromtimestamp(state['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
        logging.info(f'loaded: {timestamp}')
        mortal.load_state_dict(state['mortal'])
        dqn.load_state_dict(state['current_dqn'])
        aux_net.load_state_dict(state['aux_net'])
        if not online or state['config']['control']['online']:
            optimizer.load_state_dict(state['optimizer'])
            scheduler.load_state_dict(state['scheduler'])
            scaler.load_state_dict(state['scaler'])
            best_perf = state['best_perf']
            steps = state['steps']

    optimizer.zero_grad(set_to_none=True)
    mse = nn.MSELoss()
    ce = nn.CrossEntropyLoss()

    if device.type == 'cuda':
        logging.info(f'device: {device} ({torch.cuda.get_device_name(device)})')
    else:
        logging.info(f'device: {device}')

    if online:
        submit_param(mortal, dqn, is_idle=True)
        logging.info('param has been submitted')

    writer = SummaryWriter(config['control']['tensorboard_dir'])
    stats = {
        'dqn_loss': 0,
        'cql_loss': 0,
        'next_rank_loss': 0,
    }
    all_q = torch.zeros((save_every, batch_size), device=device, dtype=torch.float32)
    all_q_target = torch.zeros((save_every, batch_size), device=device, dtype=torch.float32)
    idx = 0

    def train_epoch():
        nonlocal steps
        nonlocal idx

        player_names = []
        if online:
            player_names = ['trainee']
            dirname = drain()
            file_list = list(map(lambda p: path.join(dirname, p), os.listdir(dirname)))
        else:
            player_names_set = set()
            for filename in config['dataset']['player_names_files']:
                with open(filename) as f:
                    player_names_set.update(filtered_trimmed_lines(f))
            player_names = list(player_names_set)
            logging.info(f'loaded {len(player_names):,} players')

            file_index = config['dataset']['file_index']
            if path.exists(file_index):
                index = torch.load(file_index, weights_only=True)
                file_list = index['file_list']
            else:
                logging.info('building file index...')
                file_list = []
                for pat in config['dataset']['globs']:
                    file_list.extend(glob(pat, recursive=True))
                if len(player_names_set) > 0:
                    filtered = []
                    for filename in tqdm(file_list, unit='file'):
                        with gzip.open(filename, 'rt') as f:
                            start = json.loads(next(f))
                            if not set(start['names']).isdisjoint(player_names_set):
                                filtered.append(filename)
                    file_list = filtered
                file_list.sort(reverse=True)
                torch.save({'file_list': file_list}, file_index)
            logging.info(f'file list size: {len(file_list):,}')

        before_next_test_play = (test_every - steps % test_every) % test_every
        logging.info(f'total steps: {steps:,} (~{before_next_test_play:,})')

        if num_workers > 1:
            random.shuffle(file_list)
        file_data = FileDatasetsIter(
            version = version,
            file_list = file_list,
            pts = pts,
            file_batch_size = file_batch_size,
            reserve_ratio = reserve_ratio,
            player_names = player_names,
            num_epochs = num_epochs,
            enable_augmentation = enable_augmentation,
            augmented_first = augmented_first,
        )
        data_loader = iter(DataLoader(
            dataset = file_data,
            batch_size = batch_size,
            drop_last = False,
            num_workers = num_workers,
            pin_memory = True,
            worker_init_fn = worker_init_fn,
        ))

        remaining_obs = []
        remaining_actions = []
        remaining_masks = []
        remaining_steps_to_done = []
        remaining_kyoku_rewards = []
        remaining_player_ranks = []
        remaining_bs = 0

        pb = tqdm(total=save_every, desc='TRAIN', initial=steps % save_every)

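        # train_batch runs one gradient-accumulation step on a full batch:
        # Q(s, a) is regressed against a discounted Monte Carlo return
        # (gamma ** steps_to_done * kyoku_rewards), a CQL-style penalty
        # (logsumexp over all legal Q values minus the taken action's Q) is
        # added in offline mode, and the auxiliary next-rank head is trained
        # with cross entropy. The optimizer only steps every `opt_step_every`
        # batches, with AMP loss scaling and optional gradient clipping.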
        def train_batch(obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks):
            nonlocal steps
            nonlocal idx
            nonlocal pb

            obs = obs.to(dtype=torch.float32, device=device)
            actions = actions.to(dtype=torch.int64, device=device)
            masks = masks.to(dtype=torch.bool, device=device)
            steps_to_done = steps_to_done.to(dtype=torch.int64, device=device)
            kyoku_rewards = kyoku_rewards.to(dtype=torch.float64, device=device)
            player_ranks = player_ranks.to(dtype=torch.int64, device=device)
            assert masks[range(batch_size), actions].all()

            q_target_mc = gamma ** steps_to_done * kyoku_rewards
            q_target_mc = q_target_mc.to(torch.float32)
            with torch.autocast(device.type, enabled=enable_amp):
                phi = mortal(obs)
                q_out = dqn(phi, masks)
                q = q_out[range(batch_size), actions]
                dqn_loss = 0.5 * mse(q, q_target_mc)

                cql_loss = 0
                if not online:
                    cql_loss = q_out.logsumexp(-1).mean() - q.mean()

                next_rank_logits, = aux_net(phi)
                next_rank_loss = ce(next_rank_logits, player_ranks)

                loss = sum((
                    dqn_loss,
                    cql_loss * min_q_weight,
                    next_rank_loss * next_rank_weight,
                ))
            scaler.scale(loss / opt_step_every).backward()

            with torch.inference_mode():
                stats['dqn_loss'] += dqn_loss
                if not online:
                    stats['cql_loss'] += cql_loss
                stats['next_rank_loss'] += next_rank_loss
                all_q[idx] = q
                all_q_target[idx] = q_target_mc

            steps += 1
            idx += 1
            if idx % opt_step_every == 0:
                if max_grad_norm > 0:
                    scaler.unscale_(optimizer)
                    params = chain.from_iterable(g['params'] for g in optimizer.param_groups)
                    clip_grad_norm_(params, max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
            pb.update(1)

            if online and steps % submit_every == 0:
                submit_param(mortal, dqn, is_idle=False)
                logging.info('param has been submitted')

            if steps % save_every == 0:
                pb.close()
                # downsample to reduce tensorboard event size
                all_q_1d = all_q.cpu().numpy().flatten()[::128]
                all_q_target_1d = all_q_target.cpu().numpy().flatten()[::128]
                writer.add_scalar('loss/dqn_loss', stats['dqn_loss'] / save_every, steps)
                if not online:
                    writer.add_scalar('loss/cql_loss', stats['cql_loss'] / save_every, steps)
                writer.add_scalar('loss/next_rank_loss', stats['next_rank_loss'] / save_every, steps)
                writer.add_scalar('hparam/lr', scheduler.get_last_lr()[0], steps)
                writer.add_histogram('q_predicted', all_q_1d, steps)
                writer.add_histogram('q_target', all_q_target_1d, steps)
                writer.flush()
                for k in stats:
                    stats[k] = 0
                idx = 0

                before_next_test_play = (test_every - steps % test_every) % test_every
                logging.info(f'total steps: {steps:,} (~{before_next_test_play:,})')

                state = {
                    'mortal': mortal.state_dict(),
                    'current_dqn': dqn.state_dict(),
                    'aux_net': aux_net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict(),
                    'steps': steps,
                    'timestamp': datetime.now().timestamp(),
                    'best_perf': best_perf,
                    'config': config,
                }
                torch.save(state, state_file)
                if online and steps % submit_every != 0:
                    submit_param(mortal, dqn, is_idle=False)
                    logging.info('param has been submitted')

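                # Every `test_every` steps (asserted above to be a multiple of
                # `save_every`), run a round of test games via TestPlayer, log
                # the statistics to tensorboard, and copy the checkpoint to
                # `best_state_file` when both avg pt and avg rank are at least
                # as good as the best recorded so far.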
                if steps % test_every == 0:
                    stat = test_player.test_play(test_games // 4, mortal, dqn, device)
                    mortal.train()
                    dqn.train()
                    avg_pt = stat.avg_pt([90, 45, 0, -135])  # for display only, never used in training
                    better = avg_pt >= best_perf['avg_pt'] and stat.avg_rank <= best_perf['avg_rank']
                    if better:
                        past_best = best_perf.copy()
                        best_perf['avg_pt'] = avg_pt
                        best_perf['avg_rank'] = stat.avg_rank
                    logging.info(f'avg rank: {stat.avg_rank:.6}')
                    logging.info(f'avg pt: {avg_pt:.6}')

                    writer.add_scalar('test_play/avg_ranking', stat.avg_rank, steps)
                    writer.add_scalar('test_play/avg_pt', avg_pt, steps)
                    writer.add_scalars('test_play/ranking', {
                        '1st': stat.rank_1_rate,
                        '2nd': stat.rank_2_rate,
                        '3rd': stat.rank_3_rate,
                        '4th': stat.rank_4_rate,
                    }, steps)
                    writer.add_scalars('test_play/behavior', {
                        'agari': stat.agari_rate,
                        'houjuu': stat.houjuu_rate,
                        'fuuro': stat.fuuro_rate,
                        'riichi': stat.riichi_rate,
                    }, steps)
                    writer.add_scalars('test_play/agari_point', {
                        'overall': stat.avg_point_per_agari,
                        'riichi': stat.avg_point_per_riichi_agari,
                        'fuuro': stat.avg_point_per_fuuro_agari,
                        'dama': stat.avg_point_per_dama_agari,
                    }, steps)
                    writer.add_scalar('test_play/houjuu_point', stat.avg_point_per_houjuu, steps)
                    writer.add_scalar('test_play/point_per_round', stat.avg_point_per_round, steps)
                    writer.add_scalars('test_play/key_step', {
                        'agari_jun': stat.avg_agari_jun,
                        'houjuu_jun': stat.avg_houjuu_jun,
                        'riichi_jun': stat.avg_riichi_jun,
                    }, steps)
                    writer.add_scalars('test_play/riichi', {
                        'agari_after_riichi': stat.agari_rate_after_riichi,
                        'houjuu_after_riichi': stat.houjuu_rate_after_riichi,
                        'chasing_riichi': stat.chasing_riichi_rate,
                        'riichi_chased': stat.riichi_chased_rate,
                    }, steps)
                    writer.add_scalar('test_play/riichi_point', stat.avg_riichi_point, steps)
                    writer.add_scalars('test_play/fuuro', {
                        'agari_after_fuuro': stat.agari_rate_after_fuuro,
                        'houjuu_after_fuuro': stat.houjuu_rate_after_fuuro,
                    }, steps)
                    writer.add_scalar('test_play/fuuro_num', stat.avg_fuuro_num, steps)
                    writer.add_scalar('test_play/fuuro_point', stat.avg_fuuro_point, steps)
                    writer.flush()

                    if better:
                        torch.save(state, state_file)
                        logging.info(
                            'a new record has been made, '
                            f'pt: {past_best["avg_pt"]:.4} -> {best_perf["avg_pt"]:.4}, '
                            f'rank: {past_best["avg_rank"]:.4} -> {best_perf["avg_rank"]:.4}, '
                            f'saving to {best_state_file}'
                        )
                        shutil.copy(state_file, best_state_file)

                    if online:
                        # BUG: This is a bug with unknown reason. When training
                        # in online mode, the process will get stuck here. This
                        # is the reason why `main` spawns a sub process to train
                        # in online mode instead of going for training directly.
                        sys.exit(0)

                pb = tqdm(total=save_every, desc='TRAIN')

        for obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks in data_loader:
            bs = obs.shape[0]
            if bs != batch_size:
                remaining_obs.append(obs)
                remaining_actions.append(actions)
                remaining_masks.append(masks)
                remaining_steps_to_done.append(steps_to_done)
                remaining_kyoku_rewards.append(kyoku_rewards)
                remaining_player_ranks.append(player_ranks)
                remaining_bs += bs
                continue
            train_batch(obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks)

        remaining_batches = remaining_bs // batch_size
        if remaining_batches > 0:
            obs = torch.cat(remaining_obs, dim=0)
            actions = torch.cat(remaining_actions, dim=0)
            masks = torch.cat(remaining_masks, dim=0)
            steps_to_done = torch.cat(remaining_steps_to_done, dim=0)
            kyoku_rewards = torch.cat(remaining_kyoku_rewards, dim=0)
            player_ranks = torch.cat(remaining_player_ranks, dim=0)
            start = 0
            end = batch_size
            while end <= remaining_bs:
                train_batch(
                    obs[start:end],
                    actions[start:end],
                    masks[start:end],
                    steps_to_done[start:end],
                    kyoku_rewards[start:end],
                    player_ranks[start:end],
                )
                start = end
                end += batch_size
        pb.close()

        if online:
            submit_param(mortal, dqn, is_idle=True)
            logging.info('param has been submitted')

    while True:
        train_epoch()
        gc.collect()
        # torch.cuda.empty_cache()
        # torch.cuda.synchronize()
        if not online:
            # only run one epoch for offline for easier control
            break

def main():
    import os
    import sys
    import time
    from subprocess import Popen
    from config import config

    # do not set this env manually
    is_sub_proc_key = 'MORTAL_IS_SUB_PROC'

    online = config['control']['online']
    if not online or os.environ.get(is_sub_proc_key, '0') == '1':
        train()
        return

    cmd = (sys.executable, __file__)
    env = {
        is_sub_proc_key: '1',
        **os.environ.copy(),
    }

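    # Respawn loop for online mode: each child runs train() until it calls
    # sys.exit(0) after a test play (see the BUG note above), then a fresh
    # process is started so training can presumably resume from the saved
    # state file. A non-zero exit code from the child aborts the parent too.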
    while True:
        child = Popen(
            cmd,
            stdin = sys.stdin,
            stdout = sys.stdout,
            stderr = sys.stderr,
            env = env,
        )
        if (code := child.wait()) != 0:
            sys.exit(code)
        time.sleep(3)

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        pass