# Mortal-Copied/mortal/dataloader.py
import random
import torch
import numpy as np
from torch.utils.data import IterableDataset
from model import GRP
from reward_calculator import RewardCalculator
from libriichi.dataset import GameplayLoader
from config import config


class FileDatasetsIter(IterableDataset):
    def __init__(
        self,
        version,
        file_list,
        pts,
        oracle = False,
        file_batch_size = 20, # hint: around 660 instances per file
        reserve_ratio = 0,
        player_names = None,
        excludes = None,
        num_epochs = 1,
        enable_augmentation = False,
        augmented_first = False,
    ):
        super().__init__()
        self.version = version
        self.file_list = file_list
        self.pts = pts
        self.oracle = oracle
        self.file_batch_size = file_batch_size
        self.reserve_ratio = reserve_ratio
        self.player_names = player_names
        self.excludes = excludes
        self.num_epochs = num_epochs
        self.enable_augmentation = enable_augmentation
        self.augmented_first = augmented_first
        self.iterator = None

    def build_iter(self):
        # do not build these in __init__; it won't work on Windows, where
        # spawned DataLoader workers have to pickle the dataset
        self.grp = GRP(**config['grp']['network'])
        grp_state = torch.load(config['grp']['state_file'], weights_only=True, map_location=torch.device('cpu'))
        self.grp.load_state_dict(grp_state['model'])
        self.reward_calc = RewardCalculator(self.grp, self.pts)
        for _ in range(self.num_epochs):
            yield from self.load_files(self.augmented_first)
            if self.enable_augmentation:
                yield from self.load_files(not self.augmented_first)

    def load_files(self, augmented):
        # shuffle the file list for each epoch
        random.shuffle(self.file_list)
        self.loader = GameplayLoader(
            version = self.version,
            oracle = self.oracle,
            player_names = self.player_names,
            excludes = self.excludes,
            augmented = augmented,
        )

        self.buffer = []
        for start_idx in range(0, len(self.file_list), self.file_batch_size):
            old_buffer_size = len(self.buffer)
            self.populate_buffer(self.file_list[start_idx:start_idx + self.file_batch_size])
            buffer_size = len(self.buffer)
            # hold back a fraction of the freshly loaded entries so they get
            # mixed into later file batches, improving the shuffle
            reserved_size = int((buffer_size - old_buffer_size) * self.reserve_ratio)
            if reserved_size > buffer_size:
                continue
            random.shuffle(self.buffer)
            yield from self.buffer[reserved_size:]
            del self.buffer[reserved_size:]

        # drain whatever is still reserved at the end of the pass
        random.shuffle(self.buffer)
        yield from self.buffer
        self.buffer.clear()

    def populate_buffer(self, file_list):
        data = self.loader.load_gz_log_files(file_list)
        for file in data:
            for game in file:
                # per move
                obs = game.take_obs()
                if self.oracle:
                    invisible_obs = game.take_invisible_obs()
                actions = game.take_actions()
                masks = game.take_masks()
                at_kyoku = game.take_at_kyoku()
                dones = game.take_dones()
                apply_gamma = game.take_apply_gamma()

                # per game
                grp = game.take_grp()
                player_id = game.take_player_id()
                game_size = len(obs)

                grp_feature = grp.take_feature()
                rank_by_player = grp.take_rank_by_player()
                kyoku_rewards = self.reward_calc.calc_delta_pt(player_id, grp_feature, rank_by_player)
                assert len(kyoku_rewards) >= at_kyoku[-1] + 1 # usually they are equal, unless there is no action in the last kyoku

                # per-kyoku scores (stored scaled down by 1e4) followed by the final scores
                final_scores = grp.take_final_scores()
                scores_seq = np.concatenate((grp_feature[:, 3:] * 1e4, [final_scores]))
                # double stable argsort turns scores into ranks (0 = top), breaking ties by seat order
                rank_by_player_seq = (-scores_seq).argsort(-1, kind='stable').argsort(-1, kind='stable')
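                # e.g. scores [250, 310, 220, 220] -> ranks [1, 0, 2, 3]:
                # the first argsort sorts seats by score, the second inverts
                # that permutation into each seat's rank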
                player_ranks = rank_by_player_seq[:, player_id]

                # count the remaining steps where gamma applies until the game
                # ends; relies on the final step being marked done, otherwise
                # steps_to_done[i + 1] would index out of range
                steps_to_done = np.zeros(game_size, dtype=np.int64)
                for i in reversed(range(game_size)):
                    if not dones[i]:
                        steps_to_done[i] = steps_to_done[i + 1] + int(apply_gamma[i])
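                # e.g. dones [F, F, T, F, T] with apply_gamma all True
                # gives steps_to_done [2, 1, 0, 1, 0]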

                for i in range(game_size):
                    entry = [
                        obs[i],
                        actions[i],
                        masks[i],
                        steps_to_done[i],
                        kyoku_rewards[at_kyoku[i]],    # reward for the current kyoku
                        player_ranks[at_kyoku[i] + 1], # rank after the current kyoku
                    ]
                    if self.oracle:
                        entry.insert(1, invisible_obs[i])
                    self.buffer.append(entry)

    def __iter__(self):
        if self.iterator is None:
            self.iterator = self.build_iter()
        return self.iterator


def worker_init_fn(*args, **kwargs):
    # shard the file list so that each DataLoader worker reads a disjoint slice
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset
    per_worker = int(np.ceil(len(dataset.file_list) / worker_info.num_workers))
    start = worker_info.id * per_worker
    end = start + per_worker
    dataset.file_list = dataset.file_list[start:end]
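

# A minimal usage sketch, not part of the original file: it wires
# FileDatasetsIter into torch.utils.data.DataLoader together with the
# worker_init_fn above so each worker reads a disjoint slice of the files.
# The glob pattern, version, pts, and batch size are assumed values for
# illustration only, not taken from the project config.
if __name__ == '__main__':
    from glob import glob
    from torch.utils.data import DataLoader

    file_list = glob('data/logs/**/*.json.gz', recursive=True)  # hypothetical path
    dataset = FileDatasetsIter(
        version = 4,             # assumed obs version
        file_list = file_list,
        pts = [90, 45, 0, -135], # assumed placement points
    )
    loader = DataLoader(
        dataset,
        batch_size = 512,        # illustrative
        num_workers = 2,
        worker_init_fn = worker_init_fn,
        drop_last = True,
    )
    for batch in loader:
        ...  # feed batches to the training loop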