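"""Train the GRP rank-prediction model on gzipped game logs.

Every observed prefix of a game becomes one training example: the model sees
the features up to that point and is trained with cross-entropy against a
label derived from the players' final ranks. The script validates and
checkpoints every `save_every` steps and logs metrics to TensorBoard.
"""
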
import prelude  # local module, imported for its side effects

import random
import torch
import logging
from os import path
from glob import glob
from datetime import datetime
from torch import optim
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.tensorboard import SummaryWriter

from model import GRP
from libriichi.dataset import Grp
from common import tqdm
from config import config


class GrpFileDatasetsIter(IterableDataset):
    """Iterable dataset that streams (inputs_seq, rank_by_player) examples,
    loading `file_batch_size` gzipped log files at a time and shuffling the
    decoded samples within each file batch.
    """

    def __init__(
        self,
        file_list,
        file_batch_size = 50,
        cycle = False,
    ):
        super().__init__()
        self.file_list = file_list
        self.file_batch_size = file_batch_size
        self.cycle = cycle
        self.buffer = []
        self.iterator = None

    def build_iter(self):
        while True:
            # Visit the files in a fresh random order on every pass.
            random.shuffle(self.file_list)
            for start_idx in range(0, len(self.file_list), self.file_batch_size):
                self.populate_buffer(start_idx)
                # Drain the buffer in random order, then discard it.
                buffer_size = len(self.buffer)
                for i in random.sample(range(buffer_size), buffer_size):
                    yield self.buffer[i]
                self.buffer.clear()
            if not self.cycle:
                break

    def populate_buffer(self, start_idx):
        file_list = self.file_list[start_idx:start_idx + self.file_batch_size]
        data = Grp.load_gz_log_files(file_list)

        for game in data:
            feature = game.take_feature()
            rank_by_player = game.take_rank_by_player()

            # One example per prefix of the game: predict the final ranks
            # from everything observed up to and including step i.
            for i in range(feature.shape[0]):
                inputs_seq = torch.as_tensor(feature[:i + 1], dtype=torch.float64)
                self.buffer.append((
                    inputs_seq,
                    rank_by_player,
                ))

    def __iter__(self):
        # Create the generator lazily and reuse it across __iter__ calls.
        if self.iterator is None:
            self.iterator = self.build_iter()
        return self.iterator


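# Intended wiring (sketch): the dataset feeds a DataLoader that uses the
# collate function below, mirroring what train() does; `files` and the batch
# size here are placeholders.
#
#   ds = GrpFileDatasetsIter(file_list=files, file_batch_size=50, cycle=True)
#   loader = DataLoader(ds, batch_size=64, drop_last=True, collate_fn=collate)
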
def collate(batch):
    inputs = []
    lengths = []
    rank_by_players = []
    for inputs_seq, rank_by_player in batch:
        inputs.append(inputs_seq)
        lengths.append(len(inputs_seq))
        rank_by_players.append(rank_by_player)

    lengths = torch.tensor(lengths)
    rank_by_players = torch.tensor(rank_by_players, dtype=torch.int64, pin_memory=True)

    # Pad the variable-length sequences to a common length, then pack them so
    # forward_packed can skip the padding.
    padded = pad_sequence(inputs, batch_first=True)
    packed_inputs = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    # pin_memory() is not in-place; keep the pinned copy it returns.
    packed_inputs = packed_inputs.pin_memory()

    return packed_inputs, rank_by_players


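# A minimal sketch of what collate produces, assuming a hypothetical feature
# width of 4 (pinning requires a CUDA-capable build):
#
#   batch = [
#       (torch.zeros(3, 4), [0, 1, 2, 3]),
#       (torch.zeros(5, 4), [3, 2, 1, 0]),
#   ]
#   packed, ranks = collate(batch)
#   # packed: PackedSequence over a (2, 5, 4) padded batch
#   # ranks:  int64 tensor of shape (2, 4)
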
def train():
    cfg = config['grp']
    batch_size = cfg['control']['batch_size']
    save_every = cfg['control']['save_every']
    val_steps = cfg['control']['val_steps']

    device = torch.device(cfg['control']['device'])
    torch.backends.cudnn.benchmark = cfg['control']['enable_cudnn_benchmark']
    if device.type == 'cuda':
        logging.info(f'device: {device} ({torch.cuda.get_device_name(device)})')
    else:
        logging.info(f'device: {device}')

    grp = GRP(**cfg['network']).to(device)
    optimizer = optim.AdamW(grp.parameters())

    # Resume from the checkpoint if one exists.
    state_file = cfg['state_file']
    if path.exists(state_file):
        state = torch.load(state_file, weights_only=True, map_location=device)
        timestamp = datetime.fromtimestamp(state['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
        logging.info(f'loaded: {timestamp}')
        grp.load_state_dict(state['model'])
        optimizer.load_state_dict(state['optimizer'])
        steps = state['steps']
    else:
        steps = 0

    # Apply the configured learning rate, overriding whatever the checkpointed
    # optimizer state carried.
    lr = cfg['optim']['lr']
    optimizer.param_groups[0]['lr'] = lr

    # Expand the dataset globs once and cache the result on disk, so restarts
    # skip the directory walk.
    file_index = cfg['dataset']['file_index']
    train_globs = cfg['dataset']['train_globs']
    val_globs = cfg['dataset']['val_globs']
    if path.exists(file_index):
        index = torch.load(file_index, weights_only=True)
        train_file_list = index['train_file_list']
        val_file_list = index['val_file_list']
    else:
        logging.info('building file index...')
        train_file_list = []
        val_file_list = []
        for pat in train_globs:
            train_file_list.extend(glob(pat, recursive=True))
        for pat in val_globs:
            val_file_list.extend(glob(pat, recursive=True))
        train_file_list.sort(reverse=True)
        val_file_list.sort(reverse=True)
        torch.save({'train_file_list': train_file_list, 'val_file_list': val_file_list}, file_index)

    writer = SummaryWriter(cfg['control']['tensorboard_dir'])

    train_file_data = GrpFileDatasetsIter(
        file_list = train_file_list,
        file_batch_size = cfg['dataset']['file_batch_size'],
        cycle = True,
    )
    train_data_loader = iter(DataLoader(
        dataset = train_file_data,
        batch_size = batch_size,
        drop_last = True,
        num_workers = 1,
        collate_fn = collate,
    ))

    val_file_data = GrpFileDatasetsIter(
        file_list = val_file_list,
        file_batch_size = cfg['dataset']['file_batch_size'],
        cycle = True,
    )
    val_data_loader = iter(DataLoader(
        dataset = val_file_data,
        batch_size = batch_size,
        drop_last = True,
        num_workers = 1,
        collate_fn = collate,
    ))

    stats = {
        'train_loss': 0,
        'train_acc': 0,
        'val_loss': 0,
        'val_acc': 0,
    }
    logging.info(f'train file list size: {len(train_file_list):,}')
    logging.info(f'val file list size: {len(val_file_list):,}')

    # Rough epoch-progress estimate, assuming about 10 samples per file.
    approx_percent = steps * batch_size / (len(train_file_list) * 10) * 100
    logging.info(f'total steps: {steps:,} est. {approx_percent:6.3f}%')

    pb = tqdm(total=save_every, desc='TRAIN')
    for inputs, rank_by_players in train_data_loader:
        inputs = inputs.to(dtype=torch.float64, device=device)
        rank_by_players = rank_by_players.to(dtype=torch.int64, device=device)

        logits = grp.forward_packed(inputs)
        labels = grp.get_label(rank_by_players)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        with torch.inference_mode():
            stats['train_loss'] += loss
            stats['train_acc'] += (logits.argmax(-1) == labels).to(torch.float64).mean()

        steps += 1
        pb.update(1)

        if steps % save_every == 0:
            pb.close()

            # Evaluate on a fixed number of validation batches with gradients
            # disabled, then switch back to training mode.
            with torch.inference_mode():
                grp.eval()
                pb = tqdm(total=val_steps, desc='VAL')
                for idx, (inputs, rank_by_players) in enumerate(val_data_loader):
                    if idx == val_steps:
                        break
                    inputs = inputs.to(dtype=torch.float64, device=device)
                    rank_by_players = rank_by_players.to(dtype=torch.int64, device=device)

                    logits = grp.forward_packed(inputs)
                    labels = grp.get_label(rank_by_players)
                    loss = F.cross_entropy(logits, labels)

                    stats['val_loss'] += loss
                    stats['val_acc'] += (logits.argmax(-1) == labels).to(torch.float64).mean()
                    pb.update(1)
                pb.close()
                grp.train()

            writer.add_scalars('loss', {
                'train': stats['train_loss'] / save_every,
                'val': stats['val_loss'] / val_steps,
            }, steps)
            writer.add_scalars('acc', {
                'train': stats['train_acc'] / save_every,
                'val': stats['val_acc'] / val_steps,
            }, steps)
            writer.add_scalar('lr', lr, steps)
            writer.flush()

            # Reset the running stats for the next interval.
            for k in stats:
                stats[k] = 0
            approx_percent = steps * batch_size / (len(train_file_list) * 10) * 100
            logging.info(f'total steps: {steps:,} est. {approx_percent:6.3f}%')

            state = {
                'model': grp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'steps': steps,
                'timestamp': datetime.now().timestamp(),
            }
            torch.save(state, state_file)
            # Start a fresh TRAIN progress bar for the next interval.
            pb = tqdm(total=save_every, desc='TRAIN')
    pb.close()


if __name__ == '__main__':
    try:
        train()
    except KeyboardInterrupt:
        pass