"""Online training client.

Repeatedly fetches the latest model parameters from the remote server,
runs a training play session with them, logs placement statistics, and
uploads the resulting log files back to the server.
"""
# NOTE(review): `prelude` is imported first in the original file; the order is
# kept in case it performs process-wide setup — confirm before reordering.
import prelude

import logging
import socket
import torch
import numpy as np
import time
import gc
from os import path
from model import Brain, DQN
from player import TrainPlayer
from common import send_msg, recv_msg
from config import config


def _fetch_params(remote, device, param_version, retry_delay=3):
    """Poll the server until it answers a 'get_param' request with status 'ok'.

    Args:
        remote: ``(host, port)`` tuple passed to ``socket.connect``.
        device: forwarded to ``recv_msg`` as ``map_location``.
        param_version: version currently held by this client; sent with the
            request.
        retry_delay: seconds to wait between attempts when the status is not
            'ok'.

    Returns:
        The response dict whose ``'status'`` is ``'ok'``.
    """
    while True:
        with socket.socket() as conn:
            conn.connect(remote)
            send_msg(conn, {
                'type': 'get_param',
                'param_version': param_version,
            })
            rsp = recv_msg(conn, map_location=device)
        if rsp['status'] == 'ok':
            return rsp
        # The connection is already closed here, so we do not hold a server
        # socket open while waiting to retry.
        time.sleep(retry_delay)


def _rank_stats(rankings, pts):
    """Return ``(average rank, average points)`` for placement counts.

    ``rankings`` is used as a length-4 count vector (it is multiplied with
    ``np.arange(1, 5)`` and with ``pts``) — presumably counts of 1st..4th
    places; verify against ``TrainPlayer.train_play``.
    """
    total = rankings.sum()
    avg_rank = rankings @ np.arange(1, 5) / total
    avg_pt = rankings @ pts / total
    return avg_rank, avg_pt


def _read_logs(file_list):
    """Read every file in ``file_list`` as bytes, keyed by its base name."""
    logs = {}
    for filename in file_list:
        with open(filename, 'rb') as f:
            logs[path.basename(filename)] = f.read()
    return logs


def _submit_logs(remote, logs, param_version):
    """Send the collected logs to the server in a 'submit_replay' message."""
    with socket.socket() as conn:
        conn.connect(remote)
        send_msg(conn, {
            'type': 'submit_replay',
            'logs': logs,
            'param_version': param_version,
        })
    logging.info('logs have been submitted')


def _release_memory(device):
    """Run the garbage collector and, on CUDA devices, flush the CUDA cache.

    The CUDA calls are skipped for non-CUDA devices because
    ``torch.cuda.synchronize()`` initializes CUDA and raises when it is not
    available, which would abort a CPU-only run after its first session.
    """
    gc.collect()
    if device.type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


def main():
    """Run the fetch-params / play / report / upload loop forever.

    All settings are read from ``config``. The loop has no exit condition;
    it ends only when an exception (e.g. ``KeyboardInterrupt``) propagates.
    """
    remote = (config['online']['remote']['host'], config['online']['remote']['port'])
    device = torch.device(config['control']['device'])
    version = config['control']['version']
    num_blocks = config['resnet']['num_blocks']
    conv_channels = config['resnet']['conv_channels']

    mortal = Brain(version=version, num_blocks=num_blocks, conv_channels=conv_channels).to(device).eval()
    dqn = DQN(version=version).to(device)
    if config['online']['enable_compile']:
        mortal.compile()
        dqn.compile()

    train_player = TrainPlayer()
    # -1 is sent on the first request, before any parameters were received.
    param_version = -1

    # Weights applied to the four ranking counts to get an average "pt"
    # (presumably points for 1st..4th place — confirm).
    pts = np.array([90, 45, 0, -135])
    history_window = config['online']['history_window']
    # Rankings of the most recent sessions, oldest first.
    history = []

    while True:
        rsp = _fetch_params(remote, device, param_version)
        param_version = rsp['param_version']
        mortal.load_state_dict(rsp['mortal'])
        dqn.load_state_dict(rsp['dqn'])
        logging.info('param has been updated')

        rankings, file_list = train_player.train_play(mortal, dqn, device)
        avg_rank, avg_pt = _rank_stats(rankings, pts)

        history.append(np.array(rankings))
        if len(history) > history_window:
            del history[0]
        # Element-wise sum over the retained sessions gives the moving totals.
        sum_rankings = np.sum(history, axis=0)
        ma_avg_rank, ma_avg_pt = _rank_stats(sum_rankings, pts)

        logging.info(f'trainee rankings: {rankings} ({avg_rank:.6}, {avg_pt:.6}pt)')
        logging.info(f'last {len(history)} sessions: {sum_rankings} ({ma_avg_rank:.6}, {ma_avg_pt:.6}pt)')

        _submit_logs(remote, _read_logs(file_list), param_version)
        _release_memory(device)


if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the endless loop; exit quietly.
        pass