173 lines
5.5 KiB
Python
173 lines
5.5 KiB
Python
import prelude
|
|
|
|
import logging
|
|
import shutil
|
|
import torch
|
|
import sys
|
|
import os
|
|
from os import path
|
|
from io import BytesIO
|
|
from typing import *
|
|
from collections import OrderedDict
|
|
from dataclasses import dataclass
|
|
from socketserver import ThreadingTCPServer, BaseRequestHandler
|
|
from threading import Lock
|
|
from common import send_msg, recv_msg, UnexpectedEOF
|
|
from config import config
|
|
|
|
@dataclass
class State:
    """Everything the request handlers share, bundled into one object.

    The first group of fields holds paths, limits and the two locks. The
    remaining fields are mutable and must only be touched while holding the
    lock named in the comment above them.
    """

    # Directory where submitted replay files accumulate.
    buffer_dir: str
    # Directory that buffered files are moved into on drain.
    drain_dir: str
    # Buffered-file count at which get_param reports 'samples overflow'.
    capacity: int
    # Gates get_param on idle_param_version and drain on a full buffer.
    force_sequential: bool

    dir_lock: Lock
    param_lock: Lock

    # ---- guarded by dir_lock ----
    buffer_size: int
    submission_id: int

    # ---- guarded by param_lock ----
    mortal_param: Optional[OrderedDict]
    dqn_param: Optional[OrderedDict]
    param_version: int
    idle_param_version: int
|
|
# Module-wide shared State instance; stays None until main() assigns it.
S = None
|
|
|
|
class Handler(BaseRequestHandler):
|
|
def handle(self):
|
|
msg = self.recv_msg()
|
|
match msg['type']:
|
|
# called by workers
|
|
case 'get_param':
|
|
self.handle_get_param(msg)
|
|
case 'submit_replay':
|
|
self.handle_submit_replay(msg)
|
|
# called by trainer
|
|
case 'submit_param':
|
|
self.handle_submit_param(msg)
|
|
case 'drain':
|
|
self.handle_drain()
|
|
|
|
def handle_get_param(self, msg):
|
|
with S.dir_lock:
|
|
overflow = S.buffer_size >= S.capacity
|
|
with S.param_lock:
|
|
has_param = S.mortal_param is not None and S.dqn_param is not None
|
|
if overflow:
|
|
self.send_msg({'status': 'samples overflow'})
|
|
return
|
|
if not has_param:
|
|
self.send_msg({'status': 'empty param'})
|
|
return
|
|
|
|
client_param_version = msg['param_version']
|
|
buf = BytesIO()
|
|
with S.param_lock:
|
|
if S.force_sequential and S.idle_param_version <= client_param_version:
|
|
res = {'status': 'trainer is busy'}
|
|
else:
|
|
res = {
|
|
'status': 'ok',
|
|
'mortal': S.mortal_param,
|
|
'dqn': S.dqn_param,
|
|
'param_version': S.param_version,
|
|
}
|
|
torch.save(res, buf)
|
|
self.send_msg(buf.getbuffer(), packed=True)
|
|
|
|
def handle_submit_replay(self, msg):
|
|
with S.dir_lock:
|
|
for filename, content in msg['logs'].items():
|
|
filepath = path.join(S.buffer_dir, f'{S.submission_id}_{filename}')
|
|
with open(filepath, 'wb') as f:
|
|
f.write(content)
|
|
S.buffer_size += len(msg['logs'])
|
|
S.submission_id += 1
|
|
logging.info(f'total buffer size: {S.buffer_size}')
|
|
|
|
def handle_submit_param(self, msg):
|
|
with S.param_lock:
|
|
S.mortal_param = msg['mortal']
|
|
S.dqn_param = msg['dqn']
|
|
S.param_version += 1
|
|
if msg['is_idle']:
|
|
S.idle_param_version = S.param_version
|
|
|
|
def handle_drain(self):
|
|
drained_size = 0
|
|
with S.dir_lock:
|
|
buffer_list = os.listdir(S.buffer_dir)
|
|
raw_count = len(buffer_list)
|
|
assert raw_count == S.buffer_size
|
|
if (not S.force_sequential or raw_count >= S.capacity) and raw_count > 0:
|
|
old_drain_list = os.listdir(S.drain_dir)
|
|
for filename in old_drain_list:
|
|
filepath = path.join(S.drain_dir, filename)
|
|
os.remove(filepath)
|
|
for filename in buffer_list:
|
|
src = path.join(S.buffer_dir, filename)
|
|
dst = path.join(S.drain_dir, filename)
|
|
shutil.move(src, dst)
|
|
drained_size = raw_count
|
|
S.buffer_size = 0
|
|
logging.info(f'files transferred to trainer: {drained_size}')
|
|
logging.info(f'total buffer size: {S.buffer_size}')
|
|
self.send_msg({
|
|
'count': drained_size,
|
|
'drain_dir': S.drain_dir,
|
|
})
|
|
|
|
def send_msg(self, msg, packed=False):
|
|
return send_msg(self.request, msg, packed)
|
|
|
|
def recv_msg(self):
|
|
return recv_msg(self.request)
|
|
|
|
class Server(ThreadingTCPServer):
    """Threading TCP server that does not report two specific handler errors."""

    def handle_error(self, request, client_address):
        """Swallow BrokenPipeError / UnexpectedEOF; defer everything else.

        UnexpectedEOF is presumably raised when the peer closes the connection
        mid-message (confirm in the `common` module). Any other exception is
        passed to the base class's handle_error.
        """
        exc_type = sys.exc_info()[0]
        peer_went_away = exc_type is BrokenPipeError or exc_type is UnexpectedEOF
        if not peer_went_away:
            return super().handle_error(request, client_address)
        return None
|
|
|
|
def main():
    """Create the shared state, reset both working directories and serve.

    Reads `config['online']['server']` for the state fields and
    `config['online']['remote']` for the bind address. Any existing buffer
    and drain directories are deleted and recreated empty before the server
    starts listening. Blocks in `serve_forever()`.
    """
    global S

    server_cfg = config['online']['server']
    S = State(
        buffer_dir=path.abspath(server_cfg['buffer_dir']),
        drain_dir=path.abspath(server_cfg['drain_dir']),
        capacity=server_cfg['capacity'],
        force_sequential=server_cfg['force_sequential'],
        dir_lock=Lock(),
        param_lock=Lock(),
        buffer_size=0,
        submission_id=0,
        mortal_param=None,
        dqn_param=None,
        param_version=0,
        idle_param_version=0,
    )

    remote_cfg = config['online']['remote']
    host, port = remote_cfg['host'], remote_cfg['port']
    bind_addr = (host, port)

    # Remove both directories first, then recreate both, in that order.
    work_dirs = (S.buffer_dir, S.drain_dir)
    for stale_dir in work_dirs:
        if path.isdir(stale_dir):
            shutil.rmtree(stale_dir)
    for fresh_dir in work_dirs:
        os.makedirs(fresh_dir)

    # Bind manually so the socket options are set before server_bind().
    server = Server(bind_addr, Handler, bind_and_activate=False)
    with server:
        server.allow_reuse_address = True
        server.daemon_threads = True
        server.server_bind()
        server.server_activate()
        logging.info(f'listening on {host}:{port}')
        server.serve_forever()
|
|
|
|
# Script entry point: run the server until it is interrupted.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C ends the process without a traceback.
        pass
|