Mortal
This commit is contained in:
172
mortal/server.py
Normal file
172
mortal/server.py
Normal file
@@ -0,0 +1,172 @@
|
||||
import prelude
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
from os import path
|
||||
from io import BytesIO
|
||||
from typing import *
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from socketserver import ThreadingTCPServer, BaseRequestHandler
|
||||
from threading import Lock
|
||||
from common import send_msg, recv_msg, UnexpectedEOF
|
||||
from config import config
|
||||
|
||||
@dataclass
class State:
    """Mutable state shared by all request-handler threads of the server.

    The two locks partition the mutable fields: ``dir_lock`` covers the
    replay-buffer bookkeeping and ``param_lock`` covers the published
    parameters and their version counters.
    """

    # --- configuration -----------------------------------------------------
    # Directory that receives replay files submitted by workers.
    buffer_dir: str
    # Directory that buffered files are moved into when the trainer drains.
    drain_dir: str
    # Buffered-file count at which workers are told 'samples overflow'.
    capacity: int
    # When true, workers only get parameters newer than the last version the
    # trainer marked idle, and a drain only happens once the buffer is full.
    force_sequential: bool

    # --- locks ---------------------------------------------------------------
    dir_lock: Lock
    param_lock: Lock

    # --- fields below are protected by dir_lock ------------------------------
    # Number of files currently sitting in buffer_dir.
    buffer_size: int
    # Counter used as the filename prefix of each submission.
    submission_id: int

    # --- fields below are protected by param_lock ----------------------------
    # Latest parameters submitted by the trainer; None until the first submit.
    mortal_param: Optional[OrderedDict]
    dqn_param: Optional[OrderedDict]
    # Incremented on every parameter submission.
    param_version: int
    # Value of param_version at the last submission flagged 'is_idle'.
    idle_param_version: int
|
||||
# Process-wide server state. It stays None until main() replaces it with a
# State instance, which happens before the server starts accepting requests.
S = None
|
||||
|
||||
class Handler(BaseRequestHandler):
|
||||
def handle(self):
|
||||
msg = self.recv_msg()
|
||||
match msg['type']:
|
||||
# called by workers
|
||||
case 'get_param':
|
||||
self.handle_get_param(msg)
|
||||
case 'submit_replay':
|
||||
self.handle_submit_replay(msg)
|
||||
# called by trainer
|
||||
case 'submit_param':
|
||||
self.handle_submit_param(msg)
|
||||
case 'drain':
|
||||
self.handle_drain()
|
||||
|
||||
def handle_get_param(self, msg):
|
||||
with S.dir_lock:
|
||||
overflow = S.buffer_size >= S.capacity
|
||||
with S.param_lock:
|
||||
has_param = S.mortal_param is not None and S.dqn_param is not None
|
||||
if overflow:
|
||||
self.send_msg({'status': 'samples overflow'})
|
||||
return
|
||||
if not has_param:
|
||||
self.send_msg({'status': 'empty param'})
|
||||
return
|
||||
|
||||
client_param_version = msg['param_version']
|
||||
buf = BytesIO()
|
||||
with S.param_lock:
|
||||
if S.force_sequential and S.idle_param_version <= client_param_version:
|
||||
res = {'status': 'trainer is busy'}
|
||||
else:
|
||||
res = {
|
||||
'status': 'ok',
|
||||
'mortal': S.mortal_param,
|
||||
'dqn': S.dqn_param,
|
||||
'param_version': S.param_version,
|
||||
}
|
||||
torch.save(res, buf)
|
||||
self.send_msg(buf.getbuffer(), packed=True)
|
||||
|
||||
def handle_submit_replay(self, msg):
|
||||
with S.dir_lock:
|
||||
for filename, content in msg['logs'].items():
|
||||
filepath = path.join(S.buffer_dir, f'{S.submission_id}_{filename}')
|
||||
with open(filepath, 'wb') as f:
|
||||
f.write(content)
|
||||
S.buffer_size += len(msg['logs'])
|
||||
S.submission_id += 1
|
||||
logging.info(f'total buffer size: {S.buffer_size}')
|
||||
|
||||
def handle_submit_param(self, msg):
|
||||
with S.param_lock:
|
||||
S.mortal_param = msg['mortal']
|
||||
S.dqn_param = msg['dqn']
|
||||
S.param_version += 1
|
||||
if msg['is_idle']:
|
||||
S.idle_param_version = S.param_version
|
||||
|
||||
def handle_drain(self):
|
||||
drained_size = 0
|
||||
with S.dir_lock:
|
||||
buffer_list = os.listdir(S.buffer_dir)
|
||||
raw_count = len(buffer_list)
|
||||
assert raw_count == S.buffer_size
|
||||
if (not S.force_sequential or raw_count >= S.capacity) and raw_count > 0:
|
||||
old_drain_list = os.listdir(S.drain_dir)
|
||||
for filename in old_drain_list:
|
||||
filepath = path.join(S.drain_dir, filename)
|
||||
os.remove(filepath)
|
||||
for filename in buffer_list:
|
||||
src = path.join(S.buffer_dir, filename)
|
||||
dst = path.join(S.drain_dir, filename)
|
||||
shutil.move(src, dst)
|
||||
drained_size = raw_count
|
||||
S.buffer_size = 0
|
||||
logging.info(f'files transferred to trainer: {drained_size}')
|
||||
logging.info(f'total buffer size: {S.buffer_size}')
|
||||
self.send_msg({
|
||||
'count': drained_size,
|
||||
'drain_dir': S.drain_dir,
|
||||
})
|
||||
|
||||
def send_msg(self, msg, packed=False):
|
||||
return send_msg(self.request, msg, packed)
|
||||
|
||||
def recv_msg(self):
|
||||
return recv_msg(self.request)
|
||||
|
||||
class Server(ThreadingTCPServer):
    """Threading TCP server that does not report two specific error types."""

    def handle_error(self, request, client_address):
        """Skip error reporting for BrokenPipeError and UnexpectedEOF.

        Only an exact type match is skipped; every other exception is passed
        on to the default ``handle_error`` of the base class.
        """
        exc_type = sys.exc_info()[0]
        if exc_type is not BrokenPipeError and exc_type is not UnexpectedEOF:
            return super().handle_error(request, client_address)
        return None
|
||||
|
||||
def main():
    """Create the global state, reset both work directories and serve forever.

    Settings come from ``config['online']``: the ``'server'`` section fills
    the ``State`` and the ``'remote'`` section provides the bind address.
    Existing buffer and drain directories are deleted and recreated empty.
    """
    global S

    server_cfg = config['online']['server']
    S = State(
        buffer_dir=path.abspath(server_cfg['buffer_dir']),
        drain_dir=path.abspath(server_cfg['drain_dir']),
        capacity=server_cfg['capacity'],
        force_sequential=server_cfg['force_sequential'],
        dir_lock=Lock(),
        param_lock=Lock(),
        buffer_size=0,
        submission_id=0,
        mortal_param=None,
        dqn_param=None,
        param_version=0,
        idle_param_version=0,
    )

    remote_cfg = config['online']['remote']
    host, port = remote_cfg['host'], remote_cfg['port']
    bind_addr = (host, port)

    # Remove both directories before creating either one, in this order.
    work_dirs = (S.buffer_dir, S.drain_dir)
    for directory in work_dirs:
        if path.isdir(directory):
            shutil.rmtree(directory)
    for directory in work_dirs:
        os.makedirs(directory)

    # Binding is deferred so the socket options can be set on the instance
    # before server_bind() runs.
    with Server(bind_addr, Handler, bind_and_activate=False) as server:
        server.allow_reuse_address = True
        server.daemon_threads = True
        server.server_bind()
        server.server_activate()
        logging.info(f'listening on {host}:{port}')
        server.serve_forever()
|
||||
|
||||
if __name__ == '__main__':
    # A KeyboardInterrupt (Ctrl-C) ends the process quietly instead of
    # printing a traceback.
    try:
        main()
    except KeyboardInterrupt:
        pass
|
||||
Reference in New Issue
Block a user