Files
Mortal-Copied/mortal/engine.py
e2hang b7a7d7404a
Some checks failed
deploy-docs / build (push) Has been cancelled
build-libriichi / build (push) Has been cancelled
Mortal
2025-10-07 20:30:03 +08:00

137 lines
4.6 KiB
Python

import json
import traceback
import torch
import numpy as np
from torch.distributions import Normal, Categorical
from typing import *
class MortalEngine:
def __init__(
self,
brain,
dqn,
is_oracle,
version,
device = None,
stochastic_latent = False,
enable_amp = False,
enable_quick_eval = True,
enable_rule_based_agari_guard = False,
name = 'NoName',
boltzmann_epsilon = 0,
boltzmann_temp = 1,
top_p = 1,
):
self.engine_type = 'mortal'
self.device = device or torch.device('cpu')
assert isinstance(self.device, torch.device)
self.brain = brain.to(self.device).eval()
self.dqn = dqn.to(self.device).eval()
self.is_oracle = is_oracle
self.version = version
self.stochastic_latent = stochastic_latent
self.enable_amp = enable_amp
self.enable_quick_eval = enable_quick_eval
self.enable_rule_based_agari_guard = enable_rule_based_agari_guard
self.name = name
self.boltzmann_epsilon = boltzmann_epsilon
self.boltzmann_temp = boltzmann_temp
self.top_p = top_p
def react_batch(self, obs, masks, invisible_obs):
try:
with (
torch.autocast(self.device.type, enabled=self.enable_amp),
torch.inference_mode(),
):
return self._react_batch(obs, masks, invisible_obs)
except Exception as ex:
raise Exception(f'{ex}\n{traceback.format_exc()}')
def _react_batch(self, obs, masks, invisible_obs):
obs = torch.as_tensor(np.stack(obs, axis=0), device=self.device)
masks = torch.as_tensor(np.stack(masks, axis=0), device=self.device)
if invisible_obs is not None:
invisible_obs = torch.as_tensor(np.stack(invisible_obs, axis=0), device=self.device)
batch_size = obs.shape[0]
match self.version:
case 1:
mu, logsig = self.brain(obs, invisible_obs)
if self.stochastic_latent:
latent = Normal(mu, logsig.exp() + 1e-6).sample()
else:
latent = mu
q_out = self.dqn(latent, masks)
case 2 | 3 | 4:
phi = self.brain(obs)
q_out = self.dqn(phi, masks)
if self.boltzmann_epsilon > 0:
is_greedy = torch.full((batch_size,), 1-self.boltzmann_epsilon, device=self.device).bernoulli().to(torch.bool)
logits = (q_out / self.boltzmann_temp).masked_fill(~masks, -torch.inf)
sampled = sample_top_p(logits, self.top_p)
actions = torch.where(is_greedy, q_out.argmax(-1), sampled)
else:
is_greedy = torch.ones(batch_size, dtype=torch.bool, device=self.device)
actions = q_out.argmax(-1)
return actions.tolist(), q_out.tolist(), masks.tolist(), is_greedy.tolist()
def sample_top_p(logits, p):
    """Draw one index per row from ``logits`` using top-p (nucleus) sampling.

    Args:
        logits: Unnormalized log-probabilities; the last axis is the
            category axis.
        p: Cumulative-probability threshold. ``p <= 0`` returns the argmax and
            ``p >= 1`` samples from the full categorical distribution.

    Returns:
        Tensor of sampled category indices with the last axis removed.
    """
    # Degenerate thresholds need no sorting at all.
    if p <= 0:
        return logits.argmax(-1)
    if p >= 1:
        return Categorical(logits=logits).sample()

    ranked_probs, ranked_idx = logits.softmax(-1).sort(-1, descending=True)
    # Probability mass held by strictly higher-ranked entries. An entry is
    # dropped once that preceding mass already exceeds p, so the top entry
    # (preceding mass 0) is always kept.
    mass_before = ranked_probs.cumsum(-1) - ranked_probs
    kept_probs = ranked_probs.masked_fill(mass_before > p, 0.)
    # multinomial accepts unnormalized weights; map the drawn rank back to the
    # original category index.
    drawn_rank = kept_probs.multinomial(1)
    return ranked_idx.gather(-1, drawn_rank).squeeze(-1)
class ExampleMjaiLogEngine:
    """Example engine of type ``'mjai-log'`` that always discards its draw.

    For every game state it answers with a ``dahai`` (tsumogiri) event when a
    discard is possible and with a ``none`` event otherwise.
    """

    def __init__(self, name: str):
        self.name = name
        self.engine_type = 'mjai-log'
        # Filled in later through set_player_ids().
        self.player_ids = None

    def set_player_ids(self, player_ids: List[int]):
        """Remember the player id to act as, indexed by game index."""
        self.player_ids = player_ids

    def react_batch(self, game_states):
        """Return one JSON-encoded reaction string per entry of ``game_states``."""
        return [self._react_one(game_state) for game_state in game_states]

    def _react_one(self, game_state):
        """Build the JSON reaction for a single game state."""
        game_idx = game_state.game_index
        state = game_state.state
        event_log = json.loads(game_state.events_json)
        assert event_log[0]['type'] == 'start_kyoku'

        actor = self.player_ids[game_idx]
        if not state.last_cans.can_discard:
            return '{"type":"none"}'

        reaction = {
            'type': 'dahai',
            'actor': actor,
            'pai': state.last_self_tsumo(),
            'tsumogiri': True,
        }
        return json.dumps(reaction)

    # The hooks below are executed at specific events. They may be no-ops,
    # but they must be defined.
    def start_game(self, game_idx: int):
        pass

    def end_kyoku(self, game_idx: int):
        pass

    def end_game(self, game_idx: int, scores: List[int]):
        pass