import prelude

import os
import sys
import json
import torch
from datetime import datetime, timezone
from model import Brain, DQN, GRP
from engine import MortalEngine
from common import filtered_trimmed_lines
from libriichi.mjai import Bot
from libriichi.dataset import Grp
from config import config

USAGE = '''Usage: python mortal.py <ID>

ARGS:
    <ID>    The player ID, an integer within [0, 3].'''

def main():
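    # Take the player ID from the last command-line argument; anything invalid
    # prints USAGE and exits.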
    try:
        player_id = int(sys.argv[-1])
        assert player_id in range(4)
    except (ValueError, AssertionError):
        print(USAGE, file=sys.stderr)
        sys.exit(1)
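    # MORTAL_REVIEW_MODE=1 enables review mode: the bot keeps the incoming log,
    # answers every event, and appends a final metadata record.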
    review_mode = os.environ.get('MORTAL_REVIEW_MODE', '0') == '1'

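    # Load the trained weights on CPU and rebuild the networks with the
    # hyperparameters stored in the checkpoint.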
    device = torch.device('cpu')
    state = torch.load(config['control']['state_file'], weights_only=True, map_location=torch.device('cpu'))
    cfg = state['config']
    version = cfg['control'].get('version', 1)
    num_blocks = cfg['resnet']['num_blocks']
    conv_channels = cfg['resnet']['conv_channels']
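    # Prefer an explicit tag from the checkpoint; otherwise derive one from the
    # model version, network size and training timestamp.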
    if 'tag' in state:
        tag = state['tag']
    else:
        time = datetime.fromtimestamp(state['timestamp'], tz=timezone.utc).strftime('%y%m%d%H')
        tag = f'mortal{version}-b{num_blocks}c{conv_channels}-t{time}'

    mortal = Brain(version=version, num_blocks=num_blocks, conv_channels=conv_channels).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])

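    # Wrap the networks in the inference engine. Quick evaluation is disabled
    # in review mode, presumably so every decision carries full metadata.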
    engine = MortalEngine(
        mortal,
        dqn,
        version = version,
        is_oracle = False,
        device = device,
        enable_amp = False,
        enable_quick_eval = not review_mode,
        enable_rule_based_agari_guard = True,
        name = 'mortal',
    )
    bot = Bot(engine, player_id)

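    # Main loop: read mjai events from stdin, one JSON object per line, and
    # print the bot's reaction whenever one is required.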
    if review_mode:
        logs = []
    for line in filtered_trimmed_lines(sys.stdin):
        if review_mode:
            logs.append(line)

        if reaction := bot.react(line):
            print(reaction, flush=True)
        elif review_mode:
            print('{"type":"none","meta":{"mask_bits":0}}', flush=True)

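    # In review mode, run the GRP model over the collected log and emit its phi
    # matrix together with the model tag as a final JSON record.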
    if review_mode:
        grp = GRP(**config['grp']['network'])
        grp_state = torch.load(config['grp']['state_file'], weights_only=True, map_location=torch.device('cpu'))
        grp.load_state_dict(grp_state['model'])

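        # GRP is evaluated on every prefix of the per-event feature sequence,
        # so a prediction is available at each point of the game.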
        ins = Grp.load_log('\n'.join(logs))
        feature = ins.take_feature()
        seq = list(map(
            lambda idx: torch.as_tensor(feature[:idx+1], device=device),
            range(len(feature)),
        ))

        with torch.inference_mode():
            logits = grp(seq)
        matrix = grp.calc_matrix(logits)
        extra_data = {
            'model_tag': tag,
            'phi_matrix': matrix.tolist(),
        }
        print(json.dumps(extra_data), flush=True)

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        pass