140 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			140 lines
		
	
	
		
			5.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import random
 | 
						|
import torch
 | 
						|
import numpy as np
 | 
						|
from torch.utils.data import IterableDataset
 | 
						|
from model import GRP
 | 
						|
from reward_calculator import RewardCalculator
 | 
						|
from libriichi.dataset import GameplayLoader
 | 
						|
from config import config
 | 
						|
 | 
						|
class FileDatasetsIter(IterableDataset):
    """Iterable dataset that streams per-move training examples from
    gzipped gameplay log files.

    Files are decoded in batches via ``GameplayLoader``, expanded into
    per-move entries, shuffled in an in-memory buffer and yielded one
    entry at a time. A fraction of each batch (``reserve_ratio``) is held
    back in the buffer so it gets shuffled together with later batches,
    improving mixing across file boundaries.
    """

    def __init__(
        self,
        version,
        file_list,
        pts,
        oracle = False,
        file_batch_size = 20, # hint: around 660 instances per file
        reserve_ratio = 0,
        player_names = None,
        excludes = None,
        num_epochs = 1,
        enable_augmentation = False,
        augmented_first = False,
    ):
        """
        Args:
            version: feature/encoding version forwarded to GameplayLoader.
            file_list: list of .gz log file paths to read; sharded in-place
                by ``worker_init_fn`` when used with multiple workers.
            pts: point settings forwarded to RewardCalculator.
            oracle: if True, also emit the invisible (oracle) observation
                for each move, inserted at index 1 of every entry.
            file_batch_size: number of files decoded per loader call.
            reserve_ratio: fraction of each batch's newly added entries to
                keep in the buffer for mixing with subsequent batches.
            player_names: optional player-name filter for the loader.
            excludes: optional exclusion filter for the loader.
            num_epochs: number of passes over the (reshuffled) file list.
            enable_augmentation: if True, each epoch is followed by a
                second pass with the augmentation flag flipped.
            augmented_first: whether the first pass of each epoch uses
                augmented data.
        """
        super().__init__()
        self.version = version
        self.file_list = file_list
        self.pts = pts
        self.oracle = oracle
        self.file_batch_size = file_batch_size
        self.reserve_ratio = reserve_ratio
        self.player_names = player_names
        self.excludes = excludes
        self.num_epochs = num_epochs
        self.enable_augmentation = enable_augmentation
        self.augmented_first = augmented_first
        # Lazily created generator; built on first __iter__ call.
        self.iterator = None

    def build_iter(self):
        """Build the generator that yields all entries for all epochs.

        Heavy state (GRP model, reward calculator) is created here rather
        than in __init__ so it is constructed inside the worker process.
        """
        # do not put it in __init__, it won't work on Windows
        self.grp = GRP(**config['grp']['network'])
        # GRP weights are always loaded onto CPU within the data worker.
        grp_state = torch.load(config['grp']['state_file'], weights_only=True, map_location=torch.device('cpu'))
        self.grp.load_state_dict(grp_state['model'])
        self.reward_calc = RewardCalculator(self.grp, self.pts)

        for _ in range(self.num_epochs):
            yield from self.load_files(self.augmented_first)
            if self.enable_augmentation:
                # Second pass over the same files with augmentation flipped.
                yield from self.load_files(not self.augmented_first)

    def load_files(self, augmented):
        """Yield shuffled per-move entries from all files, in batches.

        Args:
            augmented: augmentation flag forwarded to GameplayLoader.
        """
        # shuffle the file list for each epoch
        random.shuffle(self.file_list)

        self.loader = GameplayLoader(
            version = self.version,
            oracle = self.oracle,
            player_names = self.player_names,
            excludes = self.excludes,
            augmented = augmented,
        )
        self.buffer = []

        for start_idx in range(0, len(self.file_list), self.file_batch_size):
            old_buffer_size = len(self.buffer)
            self.populate_buffer(self.file_list[start_idx:start_idx + self.file_batch_size])
            buffer_size = len(self.buffer)

            # Hold back a share of the entries this batch contributed so
            # they are shuffled together with later batches.
            reserved_size = int((buffer_size - old_buffer_size) * self.reserve_ratio)
            if reserved_size > buffer_size:
                continue

            random.shuffle(self.buffer)
            # Yield everything beyond the reserved prefix, then drop it.
            yield from self.buffer[reserved_size:]
            del self.buffer[reserved_size:]
        # Flush whatever remained reserved at the end of the pass.
        random.shuffle(self.buffer)
        yield from self.buffer
        self.buffer.clear()

    def populate_buffer(self, file_list):
        """Decode a batch of files and append per-move entries to the buffer.

        Each entry is
        ``[obs, (invisible_obs,) actions, masks, steps_to_done,
        kyoku_reward, rank_at_end_of_kyoku]`` — the invisible observation
        is present only when ``self.oracle`` is set.
        """
        data = self.loader.load_gz_log_files(file_list)
        for file in data:
            for game in file:
                # per move
                # NOTE(review): the take_* accessors presumably move data
                # out of the binding object, so each may only be called
                # once per game — confirm against libriichi.
                obs = game.take_obs()
                if self.oracle:
                    invisible_obs = game.take_invisible_obs()
                actions = game.take_actions()
                masks = game.take_masks()
                at_kyoku = game.take_at_kyoku()
                dones = game.take_dones()
                apply_gamma = game.take_apply_gamma()

                # per game
                grp = game.take_grp()
                player_id = game.take_player_id()

                game_size = len(obs)

                grp_feature = grp.take_feature()
                rank_by_player = grp.take_rank_by_player()
                # Per-kyoku reward deltas for this player.
                kyoku_rewards = self.reward_calc.calc_delta_pt(player_id, grp_feature, rank_by_player)
                assert len(kyoku_rewards) >= at_kyoku[-1] + 1 # usually they are equal, unless there is no action in the last kyoku

                final_scores = grp.take_final_scores()
                # Score trajectory per kyoku plus the final scores as the
                # last row. NOTE(review): assumes grp_feature[:, 3:] holds
                # scores scaled down by 1e4 — confirm against GRP features.
                scores_seq = np.concatenate((grp_feature[:, 3:] * 1e4, [final_scores]))
                # Double stable argsort turns scores into 0-based ranks
                # (rank 0 = highest score); stability breaks ties by seat.
                rank_by_player_seq = (-scores_seq).argsort(-1, kind='stable').argsort(-1, kind='stable')
                player_ranks = rank_by_player_seq[:, player_id]

                # steps_to_done[i]: number of gamma-applying steps from
                # move i until the next done. Filled right-to-left.
                # NOTE(review): relies on dones[-1] being truthy, otherwise
                # i + 1 would index out of range on the last move — confirm.
                steps_to_done = np.zeros(game_size, dtype=np.int64)
                for i in reversed(range(game_size)):
                    if not dones[i]:
                        steps_to_done[i] = steps_to_done[i + 1] + int(apply_gamma[i])

                for i in range(game_size):
                    entry = [
                        obs[i],
                        actions[i],
                        masks[i],
                        steps_to_done[i],
                        # Reward for the kyoku this move belongs to.
                        kyoku_rewards[at_kyoku[i]],
                        # Player's rank at the end of that kyoku (+1 row
                        # because scores_seq row k is the state before
                        # kyoku k's delta is applied).
                        player_ranks[at_kyoku[i] + 1],
                    ]
                    if self.oracle:
                        entry.insert(1, invisible_obs[i])
                    self.buffer.append(entry)

    def __iter__(self):
        # Reuse the same generator across repeated __iter__ calls so the
        # stream position is preserved.
        if self.iterator is None:
            self.iterator = self.build_iter()
        return self.iterator
 | 
						|
 | 
						|
def worker_init_fn(*args, **kwargs):
    """Shard the dataset's file list across DataLoader workers.

    Installed as the DataLoader ``worker_init_fn``; gives each worker a
    distinct contiguous slice of ``dataset.file_list`` so no file is read
    by more than one worker. All arguments are accepted and ignored.
    """
    info = torch.utils.data.get_worker_info()
    ds = info.dataset
    # Ceil division so every file is assigned even when the count does
    # not divide evenly; trailing workers may receive fewer (or no) files.
    chunk = int(np.ceil(len(ds.file_list) / info.num_workers))
    lo = info.id * chunk
    ds.file_list = ds.file_list[lo:lo + chunk]
 |