Files
Mortal-Copied/mortal/client.py
e2hang b7a7d7404a
Some checks failed
deploy-docs / build (push) Has been cancelled
build-libriichi / build (push) Has been cancelled
Mortal
2025-10-07 20:30:03 +08:00

89 lines
2.8 KiB
Python

import prelude
import logging
import socket
import torch
import numpy as np
import time
import gc
from os import path
from model import Brain, DQN
from player import TrainPlayer
from common import send_msg, recv_msg
from config import config
def _rank_stats(counts, pts):
    """Return ``(average rank, average points)`` for a 4-slot placement tally.

    ``counts`` is weighted by ranks 1..4 and by ``pts`` slot-for-slot, so it is
    presumably the number of games finished in 1st..4th place (TODO confirm
    against ``TrainPlayer.train_play``). Both averages divide by ``counts.sum()``.
    """
    total = counts.sum()
    avg_rank = counts @ np.arange(1, 5) / total
    avg_pt = counts @ pts / total
    return avg_rank, avg_pt


def main():
    """Run the online self-play client loop; never returns normally.

    Each iteration:
      1. polls the server with ``get_param`` (sleeping 3 s between non-'ok'
         replies), then loads the returned ``mortal`` / ``dqn`` state dicts;
      2. plays one training session via ``TrainPlayer.train_play``;
      3. logs the session's ranking stats and a moving window over the last
         ``config['online']['history_window']`` sessions;
      4. uploads the produced log files with ``submit_replay``, tagged with the
         parameter version that generated them.

    Exceptions (e.g. a refused connection, KeyboardInterrupt) propagate.
    """
    remote = (config['online']['remote']['host'], config['online']['remote']['port'])
    device = torch.device(config['control']['device'])
    version = config['control']['version']
    num_blocks = config['resnet']['num_blocks']
    conv_channels = config['resnet']['conv_channels']

    mortal = Brain(version=version, num_blocks=num_blocks, conv_channels=conv_channels).to(device).eval()
    dqn = DQN(version=version).to(device)
    if config['online']['enable_compile']:
        mortal.compile()
        dqn.compile()

    train_player = TrainPlayer()
    # Sent to the server with every get_param request; -1 presumably means
    # "no parameters held yet" (TODO confirm against the server).
    param_version = -1

    # Points for finishing 1st..4th, used to convert placement counts into an average score.
    pts = np.array([90, 45, 0, -135])
    history_window = config['online']['history_window']
    history = []

    while True:
        # Poll until the server hands out parameters; any non-'ok' status means "retry later".
        while True:
            with socket.socket() as conn:
                conn.connect(remote)
                msg = {
                    'type': 'get_param',
                    'param_version': param_version,
                }
                send_msg(conn, msg)
                # NOTE(review): map_location suggests recv_msg deserializes with
                # torch.load (pickle) -- only point this client at a trusted server.
                rsp = recv_msg(conn, map_location=device)
                if rsp['status'] == 'ok':
                    param_version = rsp['param_version']
                    break
            time.sleep(3)
        mortal.load_state_dict(rsp['mortal'])
        dqn.load_state_dict(rsp['dqn'])
        logging.info('param has been updated')

        rankings, file_list = train_player.train_play(mortal, dqn, device)
        avg_rank, avg_pt = _rank_stats(rankings, pts)

        # Moving window: keep only the most recent `history_window` sessions.
        history.append(np.array(rankings))
        if len(history) > history_window:
            del history[0]
        sum_rankings = np.sum(history, axis=0)
        ma_avg_rank, ma_avg_pt = _rank_stats(sum_rankings, pts)

        logging.info(f'trainee rankings: {rankings} ({avg_rank:.6}, {avg_pt:.6}pt)')
        logging.info(f'last {len(history)} sessions: {sum_rankings} ({ma_avg_rank:.6}, {ma_avg_pt:.6}pt)')

        # Upload raw log bytes keyed by file basename.
        logs = {}
        for filename in file_list:
            with open(filename, 'rb') as f:
                logs[path.basename(filename)] = f.read()
        with socket.socket() as conn:
            conn.connect(remote)
            send_msg(conn, {
                'type': 'submit_replay',
                'logs': logs,
                'param_version': param_version,
            })
            logging.info('logs have been submitted')

        # Release memory held by the finished session before starting the next one.
        gc.collect()
        # torch.cuda.synchronize() initializes CUDA and raises when it is unavailable,
        # so only touch the CUDA runtime when the models actually live on a GPU.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize(device)
if __name__ == '__main__':
    # main() loops forever; Ctrl-C is the expected way to stop the client,
    # so an interrupt is treated as a clean shutdown rather than an error.
    try:
        main()
    except KeyboardInterrupt:
        pass  # intentional: exit quietly without printing a traceback