"""Training script for the GRP model: given the per-round feature sequence of
a game so far, it learns to predict the players' final rankings."""

# `prelude` is imported for its side effects (runtime/logging setup).
import prelude

import random
import torch
import logging
from os import path
from glob import glob
from datetime import datetime
from torch import optim
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.tensorboard import SummaryWriter

from model import GRP
from libriichi.dataset import Grp
from common import tqdm
from config import config


class GrpFileDatasetsIter(IterableDataset):
    """Streams samples from gzipped game logs.

    Files are shuffled and loaded `file_batch_size` at a time; every prefix of
    a game's feature sequence becomes one sample, paired with the final
    ranking of that game.
    """

    def __init__(
        self,
        file_list,
        file_batch_size = 50,
        cycle = False,
    ):
        super().__init__()
        self.file_list = file_list
        self.file_batch_size = file_batch_size
        self.cycle = cycle
        self.buffer = []
        self.iterator = None

    def build_iter(self):
        while True:
            random.shuffle(self.file_list)
            for start_idx in range(0, len(self.file_list), self.file_batch_size):
                self.populate_buffer(start_idx)
                # Drain the buffer in a random permutation to decorrelate
                # samples that came from the same game.
                buffer_size = len(self.buffer)
                for i in random.sample(range(buffer_size), buffer_size):
                    yield self.buffer[i]
                self.buffer.clear()
            if not self.cycle:
                break

    def populate_buffer(self, start_idx):
        file_list = self.file_list[start_idx:start_idx + self.file_batch_size]
        data = Grp.load_gz_log_files(file_list)
        for game in data:
            feature = game.take_feature()
            rank_by_player = game.take_rank_by_player()
            for i in range(feature.shape[0]):
                inputs_seq = torch.as_tensor(feature[:i + 1], dtype=torch.float64)
                self.buffer.append((
                    inputs_seq,
                    rank_by_player,
                ))

    def __iter__(self):
        if self.iterator is None:
            self.iterator = self.build_iter()
        return self.iterator


def collate(batch):
    inputs = []
    lengths = []
    rank_by_players = []
    for inputs_seq, rank_by_player in batch:
        inputs.append(inputs_seq)
        lengths.append(len(inputs_seq))
        rank_by_players.append(rank_by_player)
    lengths = torch.tensor(lengths)
    rank_by_players = torch.tensor(rank_by_players, dtype=torch.int64, pin_memory=True)
    # Pad the variable-length sequences, then pack them so the recurrent
    # network can skip the padding.
    padded = pad_sequence(inputs, batch_first=True)
    packed_inputs = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    # pin_memory() returns a pinned copy rather than pinning in place, so the
    # result must be reassigned.
    packed_inputs = packed_inputs.pin_memory()
    return packed_inputs, rank_by_players


def train():
    cfg = config['grp']
    batch_size = cfg['control']['batch_size']
    save_every = cfg['control']['save_every']
    val_steps = cfg['control']['val_steps']
    device = torch.device(cfg['control']['device'])
    torch.backends.cudnn.benchmark = cfg['control']['enable_cudnn_benchmark']

    if device.type == 'cuda':
        logging.info(f'device: {device} ({torch.cuda.get_device_name(device)})')
    else:
        logging.info(f'device: {device}')

    grp = GRP(**cfg['network']).to(device)
    optimizer = optim.AdamW(grp.parameters())

    # Resume from the checkpoint if one exists.
    state_file = cfg['state_file']
    if path.exists(state_file):
        state = torch.load(state_file, weights_only=True, map_location=device)
        timestamp = datetime.fromtimestamp(state['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
        logging.info(f'loaded: {timestamp}')
        grp.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        steps = state['steps']
    else:
        steps = 0

    lr = cfg['optim']['lr']
    optimizer.param_groups[0]['lr'] = lr

    # Build (or load) the cached index of dataset files so the globs only
    # have to be expanded once.
    file_index = cfg['dataset']['file_index']
    train_globs = cfg['dataset']['train_globs']
    val_globs = cfg['dataset']['val_globs']
    if path.exists(file_index):
        index = torch.load(file_index, weights_only=True)
        train_file_list = index['train_file_list']
        val_file_list = index['val_file_list']
    else:
        logging.info('building file index...')
        train_file_list = []
        val_file_list = []
        for pat in train_globs:
            train_file_list.extend(glob(pat, recursive=True))
        for pat in val_globs:
            val_file_list.extend(glob(pat, recursive=True))
        train_file_list.sort(reverse=True)
        val_file_list.sort(reverse=True)
        torch.save({'train_file_list': train_file_list, 'val_file_list': val_file_list}, file_index)

    writer = SummaryWriter(cfg['control']['tensorboard_dir'])
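    # Both datasets cycle indefinitely, so the loaders built below behave as
    # infinite streams; run length is governed by the save_every checkpoint
    # schedule (and val_steps for validation) rather than by epochs.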
    train_file_data = GrpFileDatasetsIter(
        file_list = train_file_list,
        file_batch_size = cfg['dataset']['file_batch_size'],
        cycle = True,
    )
    train_data_loader = iter(DataLoader(
        dataset = train_file_data,
        batch_size = batch_size,
        drop_last = True,
        num_workers = 1,
        collate_fn = collate,
    ))
    val_file_data = GrpFileDatasetsIter(
        file_list = val_file_list,
        file_batch_size = cfg['dataset']['file_batch_size'],
        cycle = True,
    )
    val_data_loader = iter(DataLoader(
        dataset = val_file_data,
        batch_size = batch_size,
        drop_last = True,
        num_workers = 1,
        collate_fn = collate,
    ))

    stats = {
        'train_loss': 0,
        'train_acc': 0,
        'val_loss': 0,
        'val_acc': 0,
    }

    logging.info(f'train file list size: {len(train_file_list):,}')
    logging.info(f'val file list size: {len(val_file_list):,}')
    # Rough progress estimate; assumes about 10 samples per log file.
    approx_percent = steps * batch_size / (len(train_file_list) * 10) * 100
    logging.info(f'total steps: {steps:,} est. {approx_percent:6.3f}%')

    pb = tqdm(total=save_every, desc='TRAIN')
    for inputs, rank_by_players in train_data_loader:
        inputs = inputs.to(dtype=torch.float64, device=device)
        rank_by_players = rank_by_players.to(dtype=torch.int64, device=device)

        logits = grp.forward_packed(inputs)
        labels = grp.get_label(rank_by_players)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        with torch.inference_mode():
            stats['train_loss'] += loss
            stats['train_acc'] += (logits.argmax(-1) == labels).to(torch.float64).mean()

        steps += 1
        pb.update(1)

        # Periodically validate, log to TensorBoard, and checkpoint.
        if steps % save_every == 0:
            pb.close()
            with torch.inference_mode():
                grp.eval()
                pb = tqdm(total=val_steps, desc='VAL')
                for idx, (inputs, rank_by_players) in enumerate(val_data_loader):
                    if idx == val_steps:
                        break
                    inputs = inputs.to(dtype=torch.float64, device=device)
                    rank_by_players = rank_by_players.to(dtype=torch.int64, device=device)

                    logits = grp.forward_packed(inputs)
                    labels = grp.get_label(rank_by_players)
                    loss = F.cross_entropy(logits, labels)

                    stats['val_loss'] += loss
                    stats['val_acc'] += (logits.argmax(-1) == labels).to(torch.float64).mean()
                    pb.update(1)
                pb.close()
                grp.train()

            writer.add_scalars('loss', {
                'train': stats['train_loss'] / save_every,
                'val': stats['val_loss'] / val_steps,
            }, steps)
            writer.add_scalars('acc', {
                'train': stats['train_acc'] / save_every,
                'val': stats['val_acc'] / val_steps,
            }, steps)
            writer.add_scalar('lr', lr, steps)
            writer.flush()
            for k in stats:
                stats[k] = 0

            approx_percent = steps * batch_size / (len(train_file_list) * 10) * 100
            logging.info(f'total steps: {steps:,} est. {approx_percent:6.3f}%')

            state = {
                'model': grp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'steps': steps,
                'timestamp': datetime.now().timestamp(),
            }
            torch.save(state, state_file)

            pb = tqdm(total=save_every, desc='TRAIN')
    pb.close()


if __name__ == '__main__':
    try:
        train()
    except KeyboardInterrupt:
        pass