Mortal
This commit is contained in:
88
mortal/client.py
Normal file
88
mortal/client.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import prelude
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
import gc
|
||||
from os import path
|
||||
from model import Brain, DQN
|
||||
from player import TrainPlayer
|
||||
from common import send_msg, recv_msg
|
||||
from config import config
|
||||
|
||||
def main():
    """Run the online training client loop forever.

    Each iteration:
      1. polls the remote server for parameters until it answers with
         status 'ok', then loads them into the local models;
      2. runs ``train_player.train_play`` with those models;
      3. logs the rank statistics of that session and of the last
         ``history_window`` sessions;
      4. uploads the files returned by ``train_play`` to the server.

    All settings are read from the project-level ``config`` mapping.
    The function never returns on its own; it stops only when an exception
    (e.g. KeyboardInterrupt or a socket error) propagates out.
    """
    remote = (config['online']['remote']['host'], config['online']['remote']['port'])
    device = torch.device(config['control']['device'])
    version = config['control']['version']
    num_blocks = config['resnet']['num_blocks']
    conv_channels = config['resnet']['conv_channels']

    mortal = Brain(version=version, num_blocks=num_blocks, conv_channels=conv_channels).to(device).eval()
    dqn = DQN(version=version).to(device)
    if config['online']['enable_compile']:
        mortal.compile()
        dqn.compile()

    train_player = TrainPlayer()
    # -1 is sent on the first request, before any parameters were received.
    param_version = -1

    # NOTE(review): presumably the point value of finishing 1st..4th - confirm.
    pts = np.array([90, 45, 0, -135])
    # Rank numbers 1..4, used as weights for the average rank.
    ranks = np.arange(1, 5)
    history_window = config['online']['history_window']
    history = []

    while True:
        # Poll until the server reports status 'ok' for a parameter request.
        while True:
            with socket.socket() as conn:
                conn.connect(remote)
                msg = {
                    'type': 'get_param',
                    'param_version': param_version,
                }
                send_msg(conn, msg)
                rsp = recv_msg(conn, map_location=device)
                if rsp['status'] == 'ok':
                    param_version = rsp['param_version']
                    break
            # The connection is closed at this point; wait before asking again.
            time.sleep(3)
        mortal.load_state_dict(rsp['mortal'])
        dqn.load_state_dict(rsp['dqn'])
        logging.info('param has been updated')

        # NOTE(review): `rankings` is treated as a length-4 vector, presumably
        # the number of finishes at each rank - confirm against TrainPlayer.
        rankings, file_list = train_player.train_play(mortal, dqn, device)
        avg_rank = rankings @ ranks / rankings.sum()
        avg_pt = rankings @ pts / rankings.sum()

        # Keep only the most recent `history_window` sessions.
        history.append(np.array(rankings))
        if len(history) > history_window:
            del history[0]
        sum_rankings = np.sum(history, axis=0)
        ma_avg_rank = sum_rankings @ ranks / sum_rankings.sum()
        ma_avg_pt = sum_rankings @ pts / sum_rankings.sum()

        logging.info(f'trainee rankings: {rankings} ({avg_rank:.6}, {avg_pt:.6}pt)')
        logging.info(f'last {len(history)} sessions: {sum_rankings} ({ma_avg_rank:.6}, {ma_avg_pt:.6}pt)')

        # Read every produced file fully into memory, keyed by its base name.
        logs = {}
        for filename in file_list:
            with open(filename, 'rb') as f:
                logs[path.basename(filename)] = f.read()

        with socket.socket() as conn:
            conn.connect(remote)
            send_msg(conn, {
                'type': 'submit_replay',
                'logs': logs,
                'param_version': param_version,
            })
            logging.info('logs have been submitted')

        gc.collect()
        # The CUDA housekeeping only makes sense on a CUDA device; calling
        # torch.cuda.synchronize() without CUDA raises and would kill the
        # client when config['control']['device'] is e.g. 'cpu'.
        if device.type == 'cuda':
            torch.cuda.empty_cache()
            torch.cuda.synchronize(device)
|
||||
|
||||
if __name__ == '__main__':
    # main() loops forever, so an interrupt is how the client gets stopped;
    # swallow it here to exit without printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        pass
|
||||
Reference in New Issue
Block a user