Mortal
This commit is contained in:
95
mortal/mortal.py
Normal file
95
mortal/mortal.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import prelude
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
from datetime import datetime, timezone
|
||||
from model import Brain, DQN, GRP
|
||||
from engine import MortalEngine
|
||||
from common import filtered_trimmed_lines
|
||||
from libriichi.mjai import Bot
|
||||
from libriichi.dataset import Grp
|
||||
from config import config
|
||||
|
||||
USAGE = '''Usage: python mortal.py <ID>
|
||||
|
||||
ARGS:
|
||||
<ID> The player ID, an integer within [0, 3].'''
|
||||
|
||||
def main():
|
||||
try:
|
||||
player_id = int(sys.argv[-1])
|
||||
assert player_id in range(4)
|
||||
except:
|
||||
print(USAGE, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
review_mode = os.environ.get('MORTAL_REVIEW_MODE', '0') == '1'
|
||||
|
||||
device = torch.device('cpu')
|
||||
state = torch.load(config['control']['state_file'], weights_only=True, map_location=torch.device('cpu'))
|
||||
cfg = state['config']
|
||||
version = cfg['control'].get('version', 1)
|
||||
num_blocks = cfg['resnet']['num_blocks']
|
||||
conv_channels = cfg['resnet']['conv_channels']
|
||||
if 'tag' in state:
|
||||
tag = state['tag']
|
||||
else:
|
||||
time = datetime.fromtimestamp(state['timestamp'], tz=timezone.utc).strftime('%y%m%d%H')
|
||||
tag = f'mortal{version}-b{num_blocks}c{conv_channels}-t{time}'
|
||||
|
||||
mortal = Brain(version=version, num_blocks=num_blocks, conv_channels=conv_channels).eval()
|
||||
dqn = DQN(version=version).eval()
|
||||
mortal.load_state_dict(state['mortal'])
|
||||
dqn.load_state_dict(state['current_dqn'])
|
||||
|
||||
engine = MortalEngine(
|
||||
mortal,
|
||||
dqn,
|
||||
version = version,
|
||||
is_oracle = False,
|
||||
device = device,
|
||||
enable_amp = False,
|
||||
enable_quick_eval = not review_mode,
|
||||
enable_rule_based_agari_guard = True,
|
||||
name = 'mortal',
|
||||
)
|
||||
bot = Bot(engine, player_id)
|
||||
|
||||
if review_mode:
|
||||
logs = []
|
||||
for line in filtered_trimmed_lines(sys.stdin):
|
||||
if review_mode:
|
||||
logs.append(line)
|
||||
|
||||
if reaction := bot.react(line):
|
||||
print(reaction, flush=True)
|
||||
elif review_mode:
|
||||
print('{"type":"none","meta":{"mask_bits":0}}', flush=True)
|
||||
|
||||
if review_mode:
|
||||
grp = GRP(**config['grp']['network'])
|
||||
grp_state = torch.load(config['grp']['state_file'], weights_only=True, map_location=torch.device('cpu'))
|
||||
grp.load_state_dict(grp_state['model'])
|
||||
|
||||
ins = Grp.load_log('\n'.join(logs))
|
||||
feature = ins.take_feature()
|
||||
seq = list(map(
|
||||
lambda idx: torch.as_tensor(feature[:idx+1], device=device),
|
||||
range(len(feature)),
|
||||
))
|
||||
|
||||
with torch.inference_mode():
|
||||
logits = grp(seq)
|
||||
matrix = grp.calc_matrix(logits)
|
||||
extra_data = {
|
||||
'model_tag': tag,
|
||||
'phi_matrix': matrix.tolist(),
|
||||
}
|
||||
print(json.dumps(extra_data), flush=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
Reference in New Issue
Block a user