import random
import torch
import numpy as np
from torch.utils.data import IterableDataset
from model import GRP
from reward_calculator import RewardCalculator
from libriichi.dataset import GameplayLoader
from config import config

class FileDatasetsIter(IterableDataset):
    def __init__(
        self,
        version,
        file_list,
        pts,
        oracle = False,
        file_batch_size = 20, # hint: around 660 instances per file
        reserve_ratio = 0,
        player_names = None,
        excludes = None,
        num_epochs = 1,
        enable_augmentation = False,
        augmented_first = False,
    ):
        super().__init__()
        self.version = version
        self.file_list = file_list
        self.pts = pts
        self.oracle = oracle
        self.file_batch_size = file_batch_size
        self.reserve_ratio = reserve_ratio
        self.player_names = player_names
        self.excludes = excludes
        self.num_epochs = num_epochs
        self.enable_augmentation = enable_augmentation
        self.augmented_first = augmented_first
        self.iterator = None

    def build_iter(self):
        # do not put it in __init__, it won't work on Windows
        self.grp = GRP(**config['grp']['network'])
        grp_state = torch.load(config['grp']['state_file'], weights_only=True, map_location=torch.device('cpu'))
        self.grp.load_state_dict(grp_state['model'])
        self.reward_calc = RewardCalculator(self.grp, self.pts)

        for _ in range(self.num_epochs):
            yield from self.load_files(self.augmented_first)
            if self.enable_augmentation:
                yield from self.load_files(not self.augmented_first)

    def load_files(self, augmented):
        # shuffle the file list for each epoch
        random.shuffle(self.file_list)

        self.loader = GameplayLoader(
            version = self.version,
            oracle = self.oracle,
            player_names = self.player_names,
            excludes = self.excludes,
            augmented = augmented,
        )
        self.buffer = []

        for start_idx in range(0, len(self.file_list), self.file_batch_size):
            old_buffer_size = len(self.buffer)
            self.populate_buffer(self.file_list[start_idx:start_idx + self.file_batch_size])
            buffer_size = len(self.buffer)

            # hold back a portion of the freshly loaded samples so they get
            # shuffled together with later file batches
            reserved_size = int((buffer_size - old_buffer_size) * self.reserve_ratio)
            if reserved_size > buffer_size:
                continue
            random.shuffle(self.buffer)
            yield from self.buffer[reserved_size:]
            del self.buffer[reserved_size:]

        # flush whatever is still reserved at the end of the epoch
        random.shuffle(self.buffer)
        yield from self.buffer
        self.buffer.clear()

    def populate_buffer(self, file_list):
        data = self.loader.load_gz_log_files(file_list)
        for file in data:
            for game in file:
                # per move
                obs = game.take_obs()
                if self.oracle:
                    invisible_obs = game.take_invisible_obs()
                actions = game.take_actions()
                masks = game.take_masks()
                at_kyoku = game.take_at_kyoku()
                dones = game.take_dones()
                apply_gamma = game.take_apply_gamma()

                # per game
                grp = game.take_grp()
                player_id = game.take_player_id()
                game_size = len(obs)

                grp_feature = grp.take_feature()
                rank_by_player = grp.take_rank_by_player()
                kyoku_rewards = self.reward_calc.calc_delta_pt(player_id, grp_feature, rank_by_player)
                assert len(kyoku_rewards) >= at_kyoku[-1] + 1 # usually they are equal, unless there is no action in the last kyoku

                # double argsort turns scores into per-seat ranks (0 = top)
                # after each kyoku and at the end of the game
                final_scores = grp.take_final_scores()
                scores_seq = np.concatenate((grp_feature[:, 3:] * 1e4, [final_scores]))
                rank_by_player_seq = (-scores_seq).argsort(-1, kind='stable').argsort(-1, kind='stable')
                player_ranks = rank_by_player_seq[:, player_id]

                # count the remaining steps (only those flagged by apply_gamma)
                # from each move until the end of the game
                steps_to_done = np.zeros(game_size, dtype=np.int64)
                for i in reversed(range(game_size)):
                    if not dones[i]:
                        steps_to_done[i] = steps_to_done[i + 1] + int(apply_gamma[i])

                for i in range(game_size):
                    entry = [
                        obs[i],
                        actions[i],
                        masks[i],
                        steps_to_done[i],
                        kyoku_rewards[at_kyoku[i]],
                        player_ranks[at_kyoku[i] + 1],
                    ]
                    if self.oracle:
                        entry.insert(1, invisible_obs[i])
                    self.buffer.append(entry)

    def __iter__(self):
        if self.iterator is None:
            self.iterator = self.build_iter()
        return self.iterator

def worker_init_fn(*args, **kwargs):
    # give each DataLoader worker a disjoint, contiguous slice of the file list
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    per_worker = int(np.ceil(len(dataset.file_list) / worker_info.num_workers))
    start = worker_info.id * per_worker
    end = start + per_worker
    dataset.file_list = dataset.file_list[start:end]
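
# Usage sketch (not part of the original module): how this dataset could be
# consumed through a multi-worker DataLoader, with worker_init_fn splitting
# file_list so every worker reads a disjoint shard. The version, pts and
# DataLoader settings below are placeholder values, not ones prescribed by
# this file.
#
# from torch.utils.data import DataLoader
#
# dataset = FileDatasetsIter(
#     version = 4,               # placeholder
#     file_list = file_list,     # list of .gz gameplay log paths
#     pts = [3, 1, -1, -3],      # placeholder rank-to-point mapping
# )
# data_loader = DataLoader(
#     dataset,
#     batch_size = 512,          # placeholder
#     num_workers = 4,
#     worker_init_fn = worker_init_fn,
# )
# for batch in data_loader:
#     # entry fields follow the order built in populate_buffer:
#     # obs, [invisible_obs if oracle], actions, masks, steps_to_done,
#     # kyoku_reward, player_rank
#     ...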