104 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			104 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import prelude
 | 
						|
 | 
						|
import numpy as np
 | 
						|
import torch
 | 
						|
import secrets
 | 
						|
import os
 | 
						|
from model import Brain, DQN
 | 
						|
from engine import MortalEngine
 | 
						|
from libriichi.arena import OneVsThree
 | 
						|
from config import config
 | 
						|
 | 
						|
def _load_engine(player_cfg):
    """Build an evaluation-mode ``MortalEngine`` from one player's config section.

    Args:
        player_cfg: Mapping for a single player (``cfg['champion']`` or
            ``cfg['challenger']``). The keys read here are ``state_file``,
            ``enable_compile``, ``device``, ``enable_amp``,
            ``enable_rule_based_agari_guard`` and ``name``.

    Returns:
        A ``MortalEngine`` wrapping the ``Brain`` and ``DQN`` networks restored
        from ``player_cfg['state_file']``.
    """
    # Always deserialize onto the CPU; the inference device is passed to the
    # engine separately below.
    state = torch.load(
        player_cfg['state_file'],
        weights_only = True,
        map_location = torch.device('cpu'),
    )

    # The network shape comes from the config stored inside the checkpoint,
    # not from the 1v3 config. Checkpoints without a 'version' entry are
    # treated as version 1.
    saved_cfg = state['config']
    version = saved_cfg['control'].get('version', 1)
    resnet_cfg = saved_cfg['resnet']

    mortal = Brain(
        version = version,
        conv_channels = resnet_cfg['conv_channels'],
        num_blocks = resnet_cfg['num_blocks'],
    ).eval()
    dqn = DQN(version=version).eval()
    mortal.load_state_dict(state['mortal'])
    dqn.load_state_dict(state['current_dqn'])

    if player_cfg['enable_compile']:
        mortal.compile()
        dqn.compile()

    return MortalEngine(
        mortal,
        dqn,
        is_oracle = False,
        version = version,
        device = torch.device(player_cfg['device']),
        enable_amp = player_cfg['enable_amp'],
        enable_rule_based_agari_guard = player_cfg['enable_rule_based_agari_guard'],
        name = player_cfg['name'],
    )


def main():
    """Run the ``config['1v3']`` evaluation and print the challenger's results.

    For each of ``iters`` iterations a ``OneVsThree`` arena is created and the
    challenger engine plays either akochan (``akochan.enabled``) or the
    champion engine. After every iteration the returned rankings and the
    averages derived from them are printed.

    Raises:
        ValueError: if ``games_per_iter`` is smaller than 4, which would leave
            no seed to play in an iteration.
    """
    cfg = config['1v3']
    games_per_iter = cfg['games_per_iter']
    # NOTE(review): presumably every seed is played four times (once per
    # challenger seat), hence the division by 4 -- confirm against OneVsThree.
    seeds_per_iter = games_per_iter // 4
    iters = cfg['iters']
    log_dir = cfg['log_dir']
    use_akochan = cfg['akochan']['enabled']

    # Fail early with a readable message. Without this check a value below 4
    # only surfaced after all models had been loaded, as
    # "range() arg 3 must not be zero".
    if seeds_per_iter < 1:
        raise ValueError(
            f'games_per_iter must be at least 4, got {games_per_iter}'
        )

    # A seed_key of -1 (or a missing entry) requests a fresh random 64-bit key.
    if (key := cfg.get('seed_key', -1)) == -1:
        key = secrets.randbits(64)

    engine_cham = None
    if use_akochan:
        # Presumably read by the akochan backend behind `ako_vs_py`.
        os.environ['AKOCHAN_DIR'] = cfg['akochan']['dir']
        os.environ['AKOCHAN_TACTICS'] = cfg['akochan']['tactics']
    else:
        engine_cham = _load_engine(cfg['champion'])

    engine_chal = _load_engine(cfg['challenger'])

    # Loop-invariant weights: rank numbers 1..4 and the points per rank.
    rank_values = np.arange(1, 5)
    rank_points = np.array([90, 45, 0, -135])

    seed_start = 10000
    seed_stop = seed_start + seeds_per_iter * iters
    # Each iteration consumes a disjoint window of `seeds_per_iter` seeds.
    for i, seed in enumerate(range(seed_start, seed_stop, seeds_per_iter)):
        print('-' * 50)
        print('#', i)
        env = OneVsThree(
            disable_progress_bar = False,
            log_dir = log_dir,
        )
        if use_akochan:
            rankings = env.ako_vs_py(
                engine = engine_chal,
                seed_start = (seed, key),
                seed_count = seeds_per_iter,
            )
        else:
            rankings = env.py_vs_py(
                challenger = engine_chal,
                champion = engine_cham,
                seed_start = (seed, key),
                seed_count = seeds_per_iter,
            )

        # `rankings` is used as a length-4 vector (presumably the number of
        # 1st..4th place finishes); the averages are weighted by those entries.
        rankings = np.array(rankings)
        total = rankings.sum()
        avg_rank = rankings @ rank_values / total
        avg_pt = rankings @ rank_points / total
        print(f'challenger rankings: {rankings} ({avg_rank}, {avg_pt}pt)')
 | 
						|
 | 
						|
if __name__ == '__main__':
    # Script entry point: run the evaluation only when executed directly.
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop a long run; exit quietly
        # instead of printing a traceback.
        pass
 |