import json
import traceback
import torch
import numpy as np
from torch.distributions import Normal, Categorical
from typing import *

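# The two engine classes below expose the same duck-typed interface: an
# `engine_type` tag, a human-readable `name`, and a `react_batch` method.
# (Assumption: the batched game runner that consumes them dispatches on
# `engine_type`; only the interface visible in this file is guaranteed.)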
class MortalEngine:
    def __init__(
        self,
        brain,
        dqn,
        is_oracle,
        version,
        device = None,
        stochastic_latent = False,
        enable_amp = False,
        enable_quick_eval = True,
        enable_rule_based_agari_guard = False,
        name = 'NoName',
        boltzmann_epsilon = 0,
        boltzmann_temp = 1,
        top_p = 1,
    ):
        self.engine_type = 'mortal'
        self.device = device or torch.device('cpu')
        assert isinstance(self.device, torch.device)
        self.brain = brain.to(self.device).eval()
        self.dqn = dqn.to(self.device).eval()
        self.is_oracle = is_oracle
        self.version = version
        self.stochastic_latent = stochastic_latent

        self.enable_amp = enable_amp
        self.enable_quick_eval = enable_quick_eval
        self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
        self.name = name

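        # Exploration knobs: with probability `boltzmann_epsilon`, an action is
        # sampled from a softmax over Q-values at temperature `boltzmann_temp`,
        # restricted to the top-p (nucleus) set; otherwise the argmax action is
        # taken. The defaults keep the engine fully greedy.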
        self.boltzmann_epsilon = boltzmann_epsilon
        self.boltzmann_temp = boltzmann_temp
        self.top_p = top_p

    def react_batch(self, obs, masks, invisible_obs):
        try:
            with (
                torch.autocast(self.device.type, enabled=self.enable_amp),
                torch.inference_mode(),
            ):
                return self._react_batch(obs, masks, invisible_obs)
        except Exception as ex:
            raise Exception(f'{ex}\n{traceback.format_exc()}')

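    # Returns four parallel per-sample lists: chosen action indices, raw
    # Q-values, the action masks echoed back, and flags marking which actions
    # were chosen greedily rather than sampled.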
    def _react_batch(self, obs, masks, invisible_obs):
        obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
        masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
        if invisible_obs is not None:
            invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
        batch_size = obs.shape[0]

        match self.version:
            case 1:
                mu, logsig = self.brain(obs, invisible_obs)
                if self.stochastic_latent:
                    latent = Normal(mu, logsig.exp() + 1e-6).sample()
                else:
                    latent = mu
                q_out = self.dqn(latent, masks)
            case 2 | 3 | 4:
                phi = self.brain(obs)
                q_out = self.dqn(phi, masks)

        if self.boltzmann_epsilon > 0:
            is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
            logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
            sampled = sample_top_p(logits, self.top_p)
            actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
        else:
            is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
            actions = q_out.argmax(-1)

        return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()

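# Nucleus (top-p) sampling: sort actions by descending probability, drop every
# action whose higher-ranked predecessors already cover more than `p` of the
# probability mass, and sample proportionally from the remainder (multinomial
# renormalizes implicitly). E.g. with probabilities [0.5, 0.3, 0.15, 0.05] and
# p=0.7, only the first two actions can be drawn.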
def sample_top_p(logits, p):
    if p >= 1:
        return Categorical(logits=logits).sample()
    if p <= 0:
        return logits.argmax(-1)
    probs = logits.softmax(-1)
    probs_sort, probs_idx = probs.sort(-1, descending=True)
    probs_sum = probs_sort.cumsum(-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.
    sampled = probs_idx.gather(-1, probs_sort.multinomial(1)).squeeze(-1)
    return sampled

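# A minimal 'mjai-log' engine: it reacts to raw mjai game states instead of
# encoded observations, and simply discards its drawn tile (tsumogiri) whenever
# a discard is possible, answering "none" otherwise. (Assumption: it is meant
# as a template for plugging arbitrary mjai-speaking bots into the runner.)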
class ExampleMjaiLogEngine:
    def __init__(self, name: str):
        self.engine_type = 'mjai-log'
        self.name = name
        self.player_ids = None

    def set_player_ids(self, player_ids: List[int]):
        self.player_ids = player_ids

    def react_batch(self, game_states):
        res = []
        for game_state in game_states:
            game_idx = game_state.game_index
            state = game_state.state
            events_json = game_state.events_json

            events = json.loads(events_json)
            assert events[0]['type'] == 'start_kyoku'

            player_id = self.player_ids[game_idx]
            cans = state.last_cans
            if cans.can_discard:
                tile = state.last_self_tsumo()
                res.append(json.dumps({
                    'type': 'dahai',
                    'actor': player_id,
                    'pai': tile,
                    'tsumogiri': True,
                }))
            else:
                res.append('{"type":"none"}')
        return res

    # These callbacks are executed at specific game events. They can be no-ops
    # but must be defined.
    def start_game(self, game_idx: int):
        pass
    def end_kyoku(self, game_idx: int):
        pass
    def end_game(self, game_idx: int, scores: List[int]):
        pass
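
# Note the two engines take different inputs: MortalEngine.react_batch consumes
# pre-encoded observations and action masks, while ExampleMjaiLogEngine.react_batch
# consumes raw game states and returns mjai event strings such as (values
# illustrative):
#   {"type": "dahai", "actor": 0, "pai": "5m", "tsumogiri": true}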