140 lines
5.2 KiB
Python
140 lines
5.2 KiB
Python
import random
|
|
import torch
|
|
import numpy as np
|
|
from torch.utils.data import IterableDataset
|
|
from model import GRP
|
|
from reward_calculator import RewardCalculator
|
|
from libriichi.dataset import GameplayLoader
|
|
from config import config
|
|
|
|
class FileDatasetsIter(IterableDataset):
    """Stream shuffled training entries built from gameplay log files.

    Files are loaded ``file_batch_size`` at a time through
    ``GameplayLoader.load_gz_log_files``; every decision step of every game
    becomes one entry (a plain ``list``) laid out as::

        [obs, action, mask, steps_to_done, kyoku_reward, player_rank]

    When ``oracle`` is true, ``invisible_obs`` is inserted at index 1.

    Args:
        version: Forwarded unchanged to ``GameplayLoader``.
        file_list: Iterable of log file paths. It is copied, so the caller's
            sequence is never reordered by the per-pass shuffle.
        pts: Forwarded to ``RewardCalculator`` together with the GRP model.
        oracle: Forwarded to ``GameplayLoader``; also controls whether
            ``invisible_obs`` is taken from each game and added to entries.
        file_batch_size: Number of files loaded per buffer refill.
        reserve_ratio: Fraction of the entries added by the latest refill
            that stays in the buffer (instead of being yielded) so it can be
            shuffled together with entries from later refills.
        player_names: Forwarded to ``GameplayLoader``
            (presumably a player filter -- confirm against libriichi).
        excludes: Forwarded to ``GameplayLoader``
            (presumably an exclusion filter -- confirm against libriichi).
        num_epochs: Number of epochs produced by ``build_iter``.
        enable_augmentation: When true, every epoch makes a second pass over
            the files with the opposite ``augmented`` flag.
        augmented_first: Value of the ``augmented`` flag for the first pass
            of every epoch.
    """

    def __init__(
        self,
        version,
        file_list,
        pts,
        oracle=False,
        file_batch_size=20,  # hint: around 660 instances per file
        reserve_ratio=0,
        player_names=None,
        excludes=None,
        num_epochs=1,
        enable_augmentation=False,
        augmented_first=False,
    ):
        super().__init__()
        self.version = version
        # Keep a private copy: load_files() shuffles this list in place, and
        # that must not silently reorder the list owned by the caller.
        self.file_list = list(file_list)
        self.pts = pts
        self.oracle = oracle
        self.file_batch_size = file_batch_size
        self.reserve_ratio = reserve_ratio
        self.player_names = player_names
        self.excludes = excludes
        self.num_epochs = num_epochs
        self.enable_augmentation = enable_augmentation
        self.augmented_first = augmented_first
        # Created lazily by __iter__.
        self.iterator = None

    def build_iter(self):
        """Return a generator yielding entries for ``num_epochs`` epochs.

        The GRP model and the reward calculator are created here, on first
        iteration, rather than in ``__init__``.
        """
        # do not put it in __init__, it won't work on Windows
        self.grp = GRP(**config['grp']['network'])
        grp_state = torch.load(
            config['grp']['state_file'],
            weights_only=True,
            map_location=torch.device('cpu'),
        )
        self.grp.load_state_dict(grp_state['model'])
        self.reward_calc = RewardCalculator(self.grp, self.pts)

        for _ in range(self.num_epochs):
            yield from self.load_files(self.augmented_first)
            if self.enable_augmentation:
                # Second pass over the same files with the flag flipped.
                yield from self.load_files(not self.augmented_first)

    def load_files(self, augmented):
        """Yield the entries of one full pass over ``self.file_list``.

        Args:
            augmented: Forwarded to ``GameplayLoader`` as ``augmented``.
        """
        # shuffle the file list for each epoch
        random.shuffle(self.file_list)

        self.loader = GameplayLoader(
            version=self.version,
            oracle=self.oracle,
            player_names=self.player_names,
            excludes=self.excludes,
            augmented=augmented,
        )
        self.buffer = []

        for start_idx in range(0, len(self.file_list), self.file_batch_size):
            old_buffer_size = len(self.buffer)
            self.populate_buffer(
                self.file_list[start_idx:start_idx + self.file_batch_size]
            )
            buffer_size = len(self.buffer)

            # Number of entries to hold back, proportional to what this
            # refill just added.
            reserved_size = int(
                (buffer_size - old_buffer_size) * self.reserve_ratio
            )
            if reserved_size > buffer_size:
                # Not enough entries buffered yet; keep accumulating.
                continue

            random.shuffle(self.buffer)
            yield from self.buffer[reserved_size:]
            del self.buffer[reserved_size:]

        # Flush whatever is still reserved at the end of the pass.
        random.shuffle(self.buffer)
        yield from self.buffer
        self.buffer.clear()

    @staticmethod
    def _steps_to_done(game_size, dones, apply_gamma):
        """Count, per step, the gamma-applying steps left before a done step.

        Walking backwards, a done step gets 0; any other step gets the value
        of the following step plus ``int(apply_gamma[i])``.

        Raises:
            IndexError: If the last step is not flagged as done.
        """
        steps_to_done = np.zeros(game_size, dtype=np.int64)
        for i in reversed(range(game_size)):
            if not dones[i]:
                steps_to_done[i] = steps_to_done[i + 1] + int(apply_gamma[i])
        return steps_to_done

    def populate_buffer(self, file_list):
        """Load ``file_list`` and append one entry per step to ``self.buffer``.

        Args:
            file_list: Paths handed to ``self.loader.load_gz_log_files``.
        """
        data = self.loader.load_gz_log_files(file_list)
        for file in data:
            for game in file:
                # per move
                obs = game.take_obs()
                if self.oracle:
                    invisible_obs = game.take_invisible_obs()
                actions = game.take_actions()
                masks = game.take_masks()
                at_kyoku = game.take_at_kyoku()
                dones = game.take_dones()
                apply_gamma = game.take_apply_gamma()

                # per game
                grp = game.take_grp()
                player_id = game.take_player_id()

                game_size = len(obs)

                grp_feature = grp.take_feature()
                rank_by_player = grp.take_rank_by_player()
                kyoku_rewards = self.reward_calc.calc_delta_pt(
                    player_id, grp_feature, rank_by_player
                )
                # usually they are equal, unless there is no action in the
                # last kyoku
                assert len(kyoku_rewards) >= at_kyoku[-1] + 1

                final_scores = grp.take_final_scores()
                # One row per grp_feature row plus a trailing row holding the
                # final scores. Columns 3+ are presumably the players' scores
                # stored in units of 1e4 -- confirm against the GRP features.
                scores_seq = np.concatenate(
                    (grp_feature[:, 3:] * 1e4, [final_scores])
                )
                # argsort of argsort turns scores into ranks; negating first
                # makes rank 0 the highest score, and the stable sort breaks
                # ties in favour of the lower player index.
                rank_by_player_seq = (
                    (-scores_seq)
                    .argsort(-1, kind='stable')
                    .argsort(-1, kind='stable')
                )
                player_ranks = rank_by_player_seq[:, player_id]

                steps_to_done = self._steps_to_done(
                    game_size, dones, apply_gamma
                )

                for i in range(game_size):
                    entry = [
                        obs[i],
                        actions[i],
                        masks[i],
                        steps_to_done[i],
                        kyoku_rewards[at_kyoku[i]],
                        # Row k + 1 of scores_seq follows kyoku k, so this is
                        # the rank from the row after the step's own kyoku.
                        player_ranks[at_kyoku[i] + 1],
                    ]
                    if self.oracle:
                        entry.insert(1, invisible_obs[i])
                    self.buffer.append(entry)

    def __iter__(self):
        """Return the dataset's single generator, creating it on first use.

        NOTE: the same generator object is returned on every call, so once
        it is exhausted, later ``iter()`` calls return the exhausted
        generator rather than starting over.
        """
        if self.iterator is None:
            self.iterator = self.build_iter()
        return self.iterator
|
|
|
|
def worker_init_fn(*args, **kwargs):
    """Restrict the current DataLoader worker to its own slice of files.

    The worker's dataset copy has its ``file_list`` replaced by a contiguous
    chunk of ``ceil(len(file_list) / num_workers)`` entries selected by the
    worker id; trailing workers may receive a shorter or empty chunk.

    All arguments are accepted and ignored. Must run inside a worker
    process: elsewhere ``get_worker_info()`` returns ``None`` and the
    attribute access below raises ``AttributeError``.
    """
    info = torch.utils.data.get_worker_info()
    ds = info.dataset
    shard_len = int(np.ceil(len(ds.file_list) / info.num_workers))
    first = info.id * shard_len
    ds.file_list = ds.file_list[first:first + shard_len]
|