Mortal

commit b7a7d7404a by e2hang
2025-10-07 20:30:03 +08:00
441 changed files with 23367 additions and 0 deletions

mortal/train_grp.py Normal file

@@ -0,0 +1,243 @@
import prelude

import random
import torch
import logging
from os import path
from glob import glob
from datetime import datetime

from torch import optim
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.tensorboard import SummaryWriter

from model import GRP
from libriichi.dataset import Grp
from common import tqdm
from config import config
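
# Streams training samples from gzipped game logs: each sample is a prefix of
# a game's per-round feature sequence paired with the players' final ranks.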
class GrpFileDatasetsIter(IterableDataset):
    def __init__(
        self,
        file_list,
        file_batch_size = 50,
        cycle = False
    ):
        super().__init__()
        self.file_list = file_list
        self.file_batch_size = file_batch_size
        self.cycle = cycle
        self.buffer = []
        self.iterator = None
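
    # Shuffle the file list on every pass, load `file_batch_size` files at a
    # time into the buffer, and yield the buffered samples in random order.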
    def build_iter(self):
        while True:
            random.shuffle(self.file_list)
            for start_idx in range(0, len(self.file_list), self.file_batch_size):
                self.populate_buffer(start_idx)
                buffer_size = len(self.buffer)
                for i in random.sample(range(buffer_size), buffer_size):
                    yield self.buffer[i]
                self.buffer.clear()
            if not self.cycle:
                break
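
    # Decode one batch of gzipped logs; every prefix of a game's feature
    # sequence becomes one sample, all labeled with the same final ranks.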
    def populate_buffer(self, start_idx):
        file_list = self.file_list[start_idx:start_idx + self.file_batch_size]
        data = Grp.load_gz_log_files(file_list)
        for game in data:
            feature = game.take_feature()
            rank_by_player = game.take_rank_by_player()
            for i in range(feature.shape[0]):
                inputs_seq = torch.as_tensor(feature[:i + 1], dtype=torch.float64)
                self.buffer.append((
                    inputs_seq,
                    rank_by_player,
                ))

    def __iter__(self):
        if self.iterator is None:
            self.iterator = self.build_iter()
        return self.iterator
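
# Batch variable-length prefix sequences into a pinned PackedSequence for the
# recurrent network consumed via grp.forward_packed.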
def collate(batch):
    inputs = []
    lengths = []
    rank_by_players = []
    for inputs_seq, rank_by_player in batch:
        inputs.append(inputs_seq)
        lengths.append(len(inputs_seq))
        rank_by_players.append(rank_by_player)
    lengths = torch.tensor(lengths)
    rank_by_players = torch.tensor(rank_by_players, dtype=torch.int64, pin_memory=True)
    padded = pad_sequence(inputs, batch_first=True)
    packed_inputs = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    # pin_memory() is not in-place; keep the pinned copy it returns
    packed_inputs = packed_inputs.pin_memory()
    return packed_inputs, rank_by_players
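
# Train the GRP model: one optimizer step per batch, with validation,
# tensorboard logging and a checkpoint every `save_every` steps. All knobs
# come from the `grp` section of the config.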
def train():
    cfg = config['grp']
    batch_size = cfg['control']['batch_size']
    save_every = cfg['control']['save_every']
    val_steps = cfg['control']['val_steps']
    device = torch.device(cfg['control']['device'])
    torch.backends.cudnn.benchmark = cfg['control']['enable_cudnn_benchmark']
    if device.type == 'cuda':
        logging.info(f'device: {device} ({torch.cuda.get_device_name(device)})')
    else:
        logging.info(f'device: {device}')

    grp = GRP(**cfg['network']).to(device)
    optimizer = optim.AdamW(grp.parameters())
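
    # Resume model, optimizer and step count from the checkpoint, if any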
    state_file = cfg['state_file']
    if path.exists(state_file):
        state = torch.load(state_file, weights_only=True, map_location=device)
        timestamp = datetime.fromtimestamp(state['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
        logging.info(f'loaded: {timestamp}')
        grp.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        steps = state['steps']
    else:
        steps = 0

    lr = cfg['optim']['lr']
    optimizer.param_groups[0]['lr'] = lr
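
    # Globbing the log files can be slow, so the resolved train/val file
    # lists are cached in an index file and reused on later runs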
    file_index = cfg['dataset']['file_index']
    train_globs = cfg['dataset']['train_globs']
    val_globs = cfg['dataset']['val_globs']
    if path.exists(file_index):
        index = torch.load(file_index, weights_only=True)
        train_file_list = index['train_file_list']
        val_file_list = index['val_file_list']
    else:
        logging.info('building file index...')
        train_file_list = []
        val_file_list = []
        for pat in train_globs:
            train_file_list.extend(glob(pat, recursive=True))
        for pat in val_globs:
            val_file_list.extend(glob(pat, recursive=True))
        train_file_list.sort(reverse=True)
        val_file_list.sort(reverse=True)
        torch.save({'train_file_list': train_file_list, 'val_file_list': val_file_list}, file_index)
    writer = SummaryWriter(cfg['control']['tensorboard_dir'])

    train_file_data = GrpFileDatasetsIter(
        file_list = train_file_list,
        file_batch_size = cfg['dataset']['file_batch_size'],
        cycle = True,
    )
    train_data_loader = iter(DataLoader(
        dataset = train_file_data,
        batch_size = batch_size,
        drop_last = True,
        num_workers = 1,
        collate_fn = collate,
    ))

    val_file_data = GrpFileDatasetsIter(
        file_list = val_file_list,
        file_batch_size = cfg['dataset']['file_batch_size'],
        cycle = True,
    )
    val_data_loader = iter(DataLoader(
        dataset = val_file_data,
        batch_size = batch_size,
        drop_last = True,
        num_workers = 1,
        collate_fn = collate,
    ))
    stats = {
        'train_loss': 0,
        'train_acc': 0,
        'val_loss': 0,
        'val_acc': 0,
    }
    logging.info(f'train file list size: {len(train_file_list):,}')
    logging.info(f'val file list size: {len(val_file_list):,}')
    # rough progress estimate, assuming about 10 samples (rounds) per game log
    approx_percent = steps * batch_size / (len(train_file_list) * 10) * 100
    logging.info(f'total steps: {steps:,} est. {approx_percent:6.3f}%')
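
    # Main training loop: the train loader cycles forever, so this loop only
    # ends on interrupt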
    pb = tqdm(total=save_every, desc='TRAIN')
    for inputs, rank_by_players in train_data_loader:
        inputs = inputs.to(dtype=torch.float64, device=device)
        rank_by_players = rank_by_players.to(dtype=torch.int64, device=device)
        logits = grp.forward_packed(inputs)
        labels = grp.get_label(rank_by_players)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        with torch.inference_mode():
            stats['train_loss'] += loss
            stats['train_acc'] += (logits.argmax(-1) == labels).to(torch.float64).mean()

        steps += 1
        pb.update(1)
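
        # Every `save_every` steps: run validation, write tensorboard scalars,
        # reset the running stats and save a checkpoint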
        if steps % save_every == 0:
            pb.close()

            with torch.inference_mode():
                grp.eval()
                pb = tqdm(total=val_steps, desc='VAL')
                for idx, (inputs, rank_by_players) in enumerate(val_data_loader):
                    if idx == val_steps:
                        break
                    inputs = inputs.to(dtype=torch.float64, device=device)
                    rank_by_players = rank_by_players.to(dtype=torch.int64, device=device)
                    logits = grp.forward_packed(inputs)
                    labels = grp.get_label(rank_by_players)
                    loss = F.cross_entropy(logits, labels)
                    stats['val_loss'] += loss
                    stats['val_acc'] += (logits.argmax(-1) == labels).to(torch.float64).mean()
                    pb.update(1)
                pb.close()
                grp.train()

            writer.add_scalars('loss', {
                'train': stats['train_loss'] / save_every,
                'val': stats['val_loss'] / val_steps,
            }, steps)
            writer.add_scalars('acc', {
                'train': stats['train_acc'] / save_every,
                'val': stats['val_acc'] / val_steps,
            }, steps)
            writer.add_scalar('lr', lr, steps)
            writer.flush()
            for k in stats:
                stats[k] = 0

            approx_percent = steps * batch_size / (len(train_file_list) * 10) * 100
            logging.info(f'total steps: {steps:,} est. {approx_percent:6.3f}%')

            state = {
                'model': grp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'steps': steps,
                'timestamp': datetime.now().timestamp(),
            }
            torch.save(state, state_file)

            pb = tqdm(total=save_every, desc='TRAIN')
    pb.close()

if __name__ == '__main__':
    try:
        train()
    except KeyboardInterrupt:
        pass