Files
Mortal-Copied/mortal/server.py
e2hang b7a7d7404a
Some checks failed
deploy-docs / build (push) Has been cancelled
build-libriichi / build (push) Has been cancelled
Mortal
2025-10-07 20:30:03 +08:00

173 lines
5.5 KiB
Python

import prelude
import logging
import shutil
import torch
import sys
import os
from os import path
from io import BytesIO
from typing import *
from collections import OrderedDict
from dataclasses import dataclass
from socketserver import ThreadingTCPServer, BaseRequestHandler
from threading import Lock
from common import send_msg, recv_msg, UnexpectedEOF
from config import config
@dataclass
class State:
    """Mutable state shared by all request-handler threads.

    The fields are split between two locks: ``dir_lock`` guards the
    bookkeeping of the replay buffer directory and ``param_lock`` guards
    the model parameters and their version counters.
    """
    # directory that submitted replay files are written into
    buffer_dir: str
    # directory that buffered files are moved to when the trainer drains
    drain_dir: str
    # buffer_size threshold: at or above it, get_param replies
    # 'samples overflow'; with force_sequential, drain also waits for it
    capacity: int
    # when true, get_param replies 'trainer is busy' until an idle parameter
    # version newer than the client's exists, and drain waits for a full buffer
    force_sequential: bool
    dir_lock: Lock
    param_lock: Lock
    # fields below are protected by dir_lock
    # number of files currently expected in buffer_dir
    buffer_size: int
    # counter used as file-name prefix; bumped once per submit_replay
    submission_id: int
    # fields below are protected by param_lock
    # None until the first submit_param arrives
    mortal_param: Optional[OrderedDict]
    dqn_param: Optional[OrderedDict]
    # incremented on every submit_param
    param_version: int
    # param_version of the latest submit_param flagged 'is_idle'
    idle_param_version: int
# Process-wide server state; None until main() replaces it with a State.
S = None
class Handler(BaseRequestHandler):
def handle(self):
msg = self.recv_msg()
match msg['type']:
# called by workers
case 'get_param':
self.handle_get_param(msg)
case 'submit_replay':
self.handle_submit_replay(msg)
# called by trainer
case 'submit_param':
self.handle_submit_param(msg)
case 'drain':
self.handle_drain()
def handle_get_param(self, msg):
with S.dir_lock:
overflow = S.buffer_size >= S.capacity
with S.param_lock:
has_param = S.mortal_param is not None and S.dqn_param is not None
if overflow:
self.send_msg({'status': 'samples overflow'})
return
if not has_param:
self.send_msg({'status': 'empty param'})
return
client_param_version = msg['param_version']
buf = BytesIO()
with S.param_lock:
if S.force_sequential and S.idle_param_version <= client_param_version:
res = {'status': 'trainer is busy'}
else:
res = {
'status': 'ok',
'mortal': S.mortal_param,
'dqn': S.dqn_param,
'param_version': S.param_version,
}
torch.save(res, buf)
self.send_msg(buf.getbuffer(), packed=True)
def handle_submit_replay(self, msg):
with S.dir_lock:
for filename, content in msg['logs'].items():
filepath = path.join(S.buffer_dir, f'{S.submission_id}_{filename}')
with open(filepath, 'wb') as f:
f.write(content)
S.buffer_size += len(msg['logs'])
S.submission_id += 1
logging.info(f'total buffer size: {S.buffer_size}')
def handle_submit_param(self, msg):
with S.param_lock:
S.mortal_param = msg['mortal']
S.dqn_param = msg['dqn']
S.param_version += 1
if msg['is_idle']:
S.idle_param_version = S.param_version
def handle_drain(self):
drained_size = 0
with S.dir_lock:
buffer_list = os.listdir(S.buffer_dir)
raw_count = len(buffer_list)
assert raw_count == S.buffer_size
if (not S.force_sequential or raw_count >= S.capacity) and raw_count > 0:
old_drain_list = os.listdir(S.drain_dir)
for filename in old_drain_list:
filepath = path.join(S.drain_dir, filename)
os.remove(filepath)
for filename in buffer_list:
src = path.join(S.buffer_dir, filename)
dst = path.join(S.drain_dir, filename)
shutil.move(src, dst)
drained_size = raw_count
S.buffer_size = 0
logging.info(f'files transferred to trainer: {drained_size}')
logging.info(f'total buffer size: {S.buffer_size}')
self.send_msg({
'count': drained_size,
'drain_dir': S.drain_dir,
})
def send_msg(self, msg, packed=False):
return send_msg(self.request, msg, packed)
def recv_msg(self):
return recv_msg(self.request)
class Server(ThreadingTCPServer):
    """Threading TCP server that stays quiet about two expected errors."""

    def handle_error(self, request, client_address):
        """Suppress the default error report for selected exception types.

        When the exception being handled is exactly ``BrokenPipeError`` or
        ``UnexpectedEOF`` nothing is reported; every other exception is
        passed on to the base class implementation.
        """
        exc_type = sys.exc_info()[0]
        is_silenced = exc_type is BrokenPipeError or exc_type is UnexpectedEOF
        if not is_silenced:
            return super().handle_error(request, client_address)
        return None
def main():
    """Build the global state from the config, reset both directories and serve.

    Reads ``config['online']['server']`` for the buffer settings and
    ``config['online']['remote']`` for the address to bind to. Any existing
    buffer or drain directory is deleted and recreated empty before the
    server starts; the call then blocks in ``serve_forever``.
    """
    global S

    server_cfg = config['online']['server']
    S = State(
        buffer_dir=path.abspath(server_cfg['buffer_dir']),
        drain_dir=path.abspath(server_cfg['drain_dir']),
        capacity=server_cfg['capacity'],
        force_sequential=server_cfg['force_sequential'],
        dir_lock=Lock(),
        param_lock=Lock(),
        buffer_size=0,
        submission_id=0,
        mortal_param=None,
        dqn_param=None,
        param_version=0,
        idle_param_version=0,
    )

    remote_cfg = config['online']['remote']
    bind_addr = (remote_cfg['host'], remote_cfg['port'])

    # Start from a clean slate: remove both directories first, then create
    # both, in that order.
    managed_dirs = (S.buffer_dir, S.drain_dir)
    for directory in managed_dirs:
        if path.isdir(directory):
            shutil.rmtree(directory)
    for directory in managed_dirs:
        os.makedirs(directory)

    with Server(bind_addr, Handler, bind_and_activate=False) as server:
        # These attributes must be set before binding, which is why the
        # server is created with bind_and_activate=False.
        server.allow_reuse_address = True
        server.daemon_threads = True
        server.server_bind()
        server.server_activate()

        host, port = bind_addr
        logging.info(f'listening on {host}:{port}')
        server.serve_forever()
# Script entry point: run the server until it is interrupted.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; exit without a traceback.
        pass