137 lines
4.6 KiB
Python
137 lines
4.6 KiB
Python
import json
|
|
import traceback
|
|
import torch
|
|
import numpy as np
|
|
from torch.distributions import Normal, Categorical
|
|
from typing import *
|
|
|
|
class MortalEngine:
|
|
def __init__(
|
|
self,
|
|
brain,
|
|
dqn,
|
|
is_oracle,
|
|
version,
|
|
device = None,
|
|
stochastic_latent = False,
|
|
enable_amp = False,
|
|
enable_quick_eval = True,
|
|
enable_rule_based_agari_guard = False,
|
|
name = 'NoName',
|
|
boltzmann_epsilon = 0,
|
|
boltzmann_temp = 1,
|
|
top_p = 1,
|
|
):
|
|
self.engine_type = 'mortal'
|
|
self.device = device or torch.device('cpu')
|
|
assert isinstance(self.device, torch.device)
|
|
self.brain = brain.to(self.device).eval()
|
|
self.dqn = dqn.to(self.device).eval()
|
|
self.is_oracle = is_oracle
|
|
self.version = version
|
|
self.stochastic_latent = stochastic_latent
|
|
|
|
self.enable_amp = enable_amp
|
|
self.enable_quick_eval = enable_quick_eval
|
|
self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
|
|
self.name = name
|
|
|
|
self.boltzmann_epsilon = boltzmann_epsilon
|
|
self.boltzmann_temp = boltzmann_temp
|
|
self.top_p = top_p
|
|
|
|
def react_batch(self, obs, masks, invisible_obs):
|
|
try:
|
|
with (
|
|
torch.autocast(self.device.type, enabled=self.enable_amp),
|
|
torch.inference_mode(),
|
|
):
|
|
return self._react_batch(obs, masks, invisible_obs)
|
|
except Exception as ex:
|
|
raise Exception(f'{ex}\n{traceback.format_exc()}')
|
|
|
|
def _react_batch(self, obs, masks, invisible_obs):
|
|
obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
|
|
masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
|
|
if invisible_obs is not None:
|
|
invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
|
|
batch_size = obs.shape[0]
|
|
|
|
match self.version:
|
|
case 1:
|
|
mu, logsig = self.brain(obs, invisible_obs)
|
|
if self.stochastic_latent:
|
|
latent = Normal(mu, logsig.exp() + 1e-6).sample()
|
|
else:
|
|
latent = mu
|
|
q_out = self.dqn(latent, masks)
|
|
case 2 | 3 | 4:
|
|
phi = self.brain(obs)
|
|
q_out = self.dqn(phi, masks)
|
|
|
|
if self.boltzmann_epsilon > 0:
|
|
is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
|
|
logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
|
|
sampled = sample_top_p(logits, self.top_p)
|
|
actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
|
|
else:
|
|
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
|
|
actions = q_out.argmax(-1)
|
|
|
|
return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()
|
|
|
|
def sample_top_p(logits, p):
    """Draw one index per row from ``softmax(logits)`` with nucleus (top-p) filtering.

    Candidates are ranked by probability. A candidate is kept while the
    probability mass strictly ranked above it is at most ``p``. One index is
    then sampled from the kept candidates, proportional to their probability.

    Args:
        logits: tensor whose last axis holds the per-action logits. The top-p
            path uses ``multinomial``, which accepts 1-D or 2-D inputs.
        p: nucleus threshold. ``p >= 1`` samples from the full distribution;
            ``p <= 0`` returns the argmax deterministically.

    Returns:
        Tensor of sampled indices with the last axis of ``logits`` removed.
    """
    # Degenerate thresholds need no sorting.
    if p >= 1:
        full_dist = Categorical(logits=logits)
        return full_dist.sample()
    if p <= 0:
        return logits.argmax(-1)

    ranked_probs, ranked_idx = logits.softmax(-1).sort(-1, descending=True)
    # Probability mass ranked strictly above each candidate.
    mass_above = ranked_probs.cumsum(-1) - ranked_probs
    ranked_probs[mass_above > p] = 0.
    # multinomial accepts unnormalised non-negative weights, so no renormalisation.
    choice = ranked_probs.multinomial(1)
    return ranked_idx.gather(-1, choice).squeeze(-1)
|
|
|
|
class ExampleMjaiLogEngine:
    """Toy ``mjai-log`` engine illustrating the interface such an engine exposes.

    For each game state it answers with a ``dahai`` of the tile returned by
    ``state.last_self_tsumo()`` (with ``tsumogiri`` set) when
    ``state.last_cans.can_discard`` is truthy. Otherwise it answers
    ``{"type":"none"}``.
    """

    def __init__(self, name: str):
        self.engine_type = 'mjai-log'
        self.name = name
        # Actor id per game index; stays None until set_player_ids() is called.
        self.player_ids = None

    def set_player_ids(self, player_ids: List[int]):
        """Store the actor id to report for each game, indexed by game index."""
        self.player_ids = player_ids

    def react_batch(self, game_states):
        """Return one JSON reaction string per game state, in input order.

        Each game state must expose ``game_index``, ``state`` and
        ``events_json``. ``events_json`` is a JSON array whose first event is
        asserted to be ``start_kyoku``.
        """
        return [self._react_single(game_state) for game_state in game_states]

    def _react_single(self, game_state):
        """Build the JSON reaction for a single game state."""
        game_idx, state, events_json = (
            game_state.game_index,
            game_state.state,
            game_state.events_json,
        )
        events = json.loads(events_json)
        assert events[0]['type'] == 'start_kyoku'

        actor = self.player_ids[game_idx]
        if not state.last_cans.can_discard:
            return '{"type":"none"}'
        # Key order matters for the emitted JSON text; keep it stable.
        return json.dumps({
            'type': 'dahai',
            'actor': actor,
            'pai': state.last_self_tsumo(),
            'tsumogiri': True,
        })

    # Event hooks executed at specific points of a game. They may be no-ops,
    # but every engine must define them.
    def start_game(self, game_idx: int):
        """Hook for the start of game ``game_idx``; no-op here."""

    def end_kyoku(self, game_idx: int):
        """Hook for the end of a kyoku in game ``game_idx``; no-op here."""

    def end_game(self, game_idx: int, scores: List[int]):
        """Hook for the end of game ``game_idx`` with final ``scores``; no-op here."""
|