import json import traceback import torch import numpy as np from torch.distributions import Normal, Categorical from typing import * class MortalEngine: def __init__( self, brain, dqn, is_oracle, version, device = None, stochastic_latent = False, enable_amp = False, enable_quick_eval = True, enable_rule_based_agari_guard = False, name = 'NoName', boltzmann_epsilon = 0, boltzmann_temp = 1, top_p = 1, ): self.engine_type = 'mortal' self.device = device or torch.device('cpu') assert isinstance(self.device, torch.device) self.brain = brain.to(self.device).eval() self.dqn = dqn.to(self.device).eval() self.is_oracle = is_oracle self.version = version self.stochastic_latent = stochastic_latent self.enable_amp = enable_amp self.enable_quick_eval = enable_quick_eval self.enable_rule_based_agari_guard = enable_rule_based_agari_guard self.name = name self.boltzmann_epsilon = boltzmann_epsilon self.boltzmann_temp = boltzmann_temp self.top_p = top_p def react_batch(self, obs, masks, invisible_obs): try: with ( torch.autocast(self.device.type, enabled=self.enable_amp), torch.inference_mode(), ): return self._react_batch(obs, masks, invisible_obs) except Exception as ex: raise Exception(f'{ex}\n{traceback.format_exc()}') def _react_batch(self, obs, masks, invisible_obs): obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device) masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device) if invisible_obs is not None: invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device) batch_size = obs.shape[0] match self.version: case 1: mu, logsig = self.brain(obs, invisible_obs) if self.stochastic_latent: latent = Normal(mu, logsig.exp() + 1e-6).sample() else: latent = mu q_out = self.dqn(latent, masks) case 2 | 3 | 4: phi = self.brain(obs) q_out = self.dqn(phi, masks) if self.boltzmann_epsilon > 0: is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool) logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf) sampled = sample_top_p(logits, self.top_p) actions = torch.where(is_greedy, q_out.argmax(-1), sampled) else: is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device) actions = q_out.argmax(-1) return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist() def sample_top_p(logits, p): if p >= 1: return Categorical(logits=logits).sample() if p <= 0: return logits.argmax(-1) probs = logits.softmax(-1) probs_sort, probs_idx = probs.sort(-1, descending=True) probs_sum = probs_sort.cumsum(-1) mask = probs_sum - probs_sort > p probs_sort[mask] = 0. sampled = probs_idx.gather(-1, probs_sort.multinomial(1)).squeeze(-1) return sampled class ExampleMjaiLogEngine: def __init__(self, name: str): self.engine_type = 'mjai-log' self.name = name self.player_ids = None def set_player_ids(self, player_ids: List[int]): self.player_ids = player_ids def react_batch(self, game_states): res = [] for game_state in game_states: game_idx = game_state.game_index state = game_state.state events_json = game_state.events_json events = json.loads(events_json) assert events[0]['type'] == 'start_kyoku' player_id = self.player_ids[game_idx] cans = state.last_cans if cans.can_discard: tile = state.last_self_tsumo() res.append(json.dumps({ 'type': 'dahai', 'actor': player_id, 'pai': tile, 'tsumogiri': True, })) else: res.append('{"type":"none"}') return res # They will be executed at specific events. They can be no-op but must be # defined. def start_game(self, game_idx: int): pass def end_kyoku(self, game_idx: int): pass def end_game(self, game_idx: int, scores: List[int]): pass