173 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			173 lines
		
	
	
		
			5.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import prelude
 | 
						|
 | 
						|
import logging
 | 
						|
import shutil
 | 
						|
import torch
 | 
						|
import sys
 | 
						|
import os
 | 
						|
from os import path
 | 
						|
from io import BytesIO
 | 
						|
from typing import *
 | 
						|
from collections import OrderedDict
 | 
						|
from dataclasses import dataclass
 | 
						|
from socketserver import ThreadingTCPServer, BaseRequestHandler
 | 
						|
from threading import Lock
 | 
						|
from common import send_msg, recv_msg, UnexpectedEOF
 | 
						|
from config import config
 | 
						|
 | 
						|
@dataclass
class State:
    """Server-wide state shared by all request-handler threads.

    The first four fields are configuration, filled in once by ``main()`` and
    not reassigned anywhere else in this module. The remaining fields are
    mutable and are guarded by the lock noted in the comment above them.
    """

    # -- configuration --
    buffer_dir: str  # absolute path where submitted replay files accumulate
    drain_dir: str  # absolute path buffered files are moved to on 'drain'
    capacity: int  # buffer_size at or above which get_param replies 'samples overflow'
    force_sequential: bool  # gate both draining and parameter hand-out (see Handler)

    # -- locks --
    dir_lock: Lock
    param_lock: Lock

    # -- guarded by dir_lock --
    buffer_size: int  # number of files currently expected in buffer_dir
    submission_id: int  # prefix used for the next submission's file names

    # -- guarded by param_lock --
    mortal_param: Optional[OrderedDict]  # last 'mortal' payload from the trainer; None until then
    dqn_param: Optional[OrderedDict]  # last 'dqn' payload from the trainer; None until then
    param_version: int  # incremented on every submit_param
    idle_param_version: int  # param_version of the latest submission flagged is_idle


# Module-wide state; main() replaces this with a State before serving requests.
S: Optional[State] = None
 | 
						|
 | 
						|
class Handler(BaseRequestHandler):
 | 
						|
    def handle(self):
 | 
						|
        msg = self.recv_msg()
 | 
						|
        match msg['type']:
 | 
						|
            # called by workers
 | 
						|
            case 'get_param':
 | 
						|
                self.handle_get_param(msg)
 | 
						|
            case 'submit_replay':
 | 
						|
                self.handle_submit_replay(msg)
 | 
						|
            # called by trainer
 | 
						|
            case 'submit_param':
 | 
						|
                self.handle_submit_param(msg)
 | 
						|
            case 'drain':
 | 
						|
                self.handle_drain()
 | 
						|
 | 
						|
    def handle_get_param(self, msg):
 | 
						|
        with S.dir_lock:
 | 
						|
            overflow = S.buffer_size >= S.capacity
 | 
						|
            with S.param_lock:
 | 
						|
                has_param = S.mortal_param is not None and S.dqn_param is not None
 | 
						|
        if overflow:
 | 
						|
            self.send_msg({'status': 'samples overflow'})
 | 
						|
            return
 | 
						|
        if not has_param:
 | 
						|
            self.send_msg({'status': 'empty param'})
 | 
						|
            return
 | 
						|
 | 
						|
        client_param_version = msg['param_version']
 | 
						|
        buf = BytesIO()
 | 
						|
        with S.param_lock:
 | 
						|
            if S.force_sequential and S.idle_param_version <= client_param_version:
 | 
						|
                res = {'status': 'trainer is busy'}
 | 
						|
            else:
 | 
						|
                res = {
 | 
						|
                    'status': 'ok',
 | 
						|
                    'mortal': S.mortal_param,
 | 
						|
                    'dqn': S.dqn_param,
 | 
						|
                    'param_version': S.param_version,
 | 
						|
                }
 | 
						|
            torch.save(res, buf)
 | 
						|
        self.send_msg(buf.getbuffer(), packed=True)
 | 
						|
 | 
						|
    def handle_submit_replay(self, msg):
 | 
						|
        with S.dir_lock:
 | 
						|
            for filename, content in msg['logs'].items():
 | 
						|
                filepath = path.join(S.buffer_dir, f'{S.submission_id}_{filename}')
 | 
						|
                with open(filepath, 'wb') as f:
 | 
						|
                    f.write(content)
 | 
						|
            S.buffer_size += len(msg['logs'])
 | 
						|
            S.submission_id += 1
 | 
						|
            logging.info(f'total buffer size: {S.buffer_size}')
 | 
						|
 | 
						|
    def handle_submit_param(self, msg):
 | 
						|
        with S.param_lock:
 | 
						|
            S.mortal_param = msg['mortal']
 | 
						|
            S.dqn_param = msg['dqn']
 | 
						|
            S.param_version += 1
 | 
						|
            if msg['is_idle']:
 | 
						|
                S.idle_param_version = S.param_version
 | 
						|
 | 
						|
    def handle_drain(self):
 | 
						|
        drained_size = 0
 | 
						|
        with S.dir_lock:
 | 
						|
            buffer_list = os.listdir(S.buffer_dir)
 | 
						|
            raw_count = len(buffer_list)
 | 
						|
            assert raw_count == S.buffer_size
 | 
						|
            if (not S.force_sequential or raw_count >= S.capacity) and raw_count > 0:
 | 
						|
                old_drain_list = os.listdir(S.drain_dir)
 | 
						|
                for filename in old_drain_list:
 | 
						|
                    filepath = path.join(S.drain_dir, filename)
 | 
						|
                    os.remove(filepath)
 | 
						|
                for filename in buffer_list:
 | 
						|
                    src = path.join(S.buffer_dir, filename)
 | 
						|
                    dst = path.join(S.drain_dir, filename)
 | 
						|
                    shutil.move(src, dst)
 | 
						|
                drained_size = raw_count
 | 
						|
                S.buffer_size = 0
 | 
						|
                logging.info(f'files transferred to trainer: {drained_size}')
 | 
						|
                logging.info(f'total buffer size: {S.buffer_size}')
 | 
						|
        self.send_msg({
 | 
						|
            'count': drained_size,
 | 
						|
            'drain_dir': S.drain_dir,
 | 
						|
        })
 | 
						|
 | 
						|
    def send_msg(self, msg, packed=False):
 | 
						|
        return send_msg(self.request, msg, packed)
 | 
						|
 | 
						|
    def recv_msg(self):
 | 
						|
        return recv_msg(self.request)
 | 
						|
 | 
						|
class Server(ThreadingTCPServer):
    """ThreadingTCPServer that does not print tracebacks when a client hangs up."""

    def handle_error(self, request, client_address):
        """Suppress errors caused by the peer disconnecting; defer everything else.

        ``socketserver`` calls this from inside its ``except`` block, so
        ``sys.exc_info()`` describes the failure. ``issubclass`` is used instead
        of identity checks so that subclasses of these exceptions are covered
        too. ``ConnectionResetError`` is treated like ``BrokenPipeError``: both
        mean the client went away.
        """
        typ, _, _ = sys.exc_info()
        if typ is not None and issubclass(
            typ, (BrokenPipeError, ConnectionResetError, UnexpectedEOF)
        ):
            return
        return super().handle_error(request, client_address)
 | 
						|
 | 
						|
def main():
    """Set up the shared state, reset the spool directories, and serve forever.

    ``config['online']['server']`` supplies the buffer/drain directories,
    ``capacity`` and ``force_sequential``. ``config['online']['remote']``
    supplies the listen address. Both directories are deleted (if present)
    and recreated empty on every start, so files from a previous run are
    discarded.
    """
    global S

    server_cfg = config['online']['server']
    S = State(
        buffer_dir=path.abspath(server_cfg['buffer_dir']),
        drain_dir=path.abspath(server_cfg['drain_dir']),
        capacity=server_cfg['capacity'],
        force_sequential=server_cfg['force_sequential'],
        dir_lock=Lock(),
        param_lock=Lock(),
        buffer_size=0,
        submission_id=0,
        mortal_param=None,
        dqn_param=None,
        param_version=0,
        idle_param_version=0,
    )

    remote_cfg = config['online']['remote']
    host, port = remote_cfg['host'], remote_cfg['port']

    # Remove both directories first, then recreate both, in that order.
    spool_dirs = (S.buffer_dir, S.drain_dir)
    for spool in spool_dirs:
        if path.isdir(spool):
            shutil.rmtree(spool)
    for spool in spool_dirs:
        os.makedirs(spool)

    # Bind and activate by hand so the instance attributes below are already
    # set when server_bind() runs.
    with Server((host, port), Handler, bind_and_activate=False) as server:
        server.allow_reuse_address = True
        server.daemon_threads = True
        server.server_bind()
        server.server_activate()
        logging.info(f'listening on {host}:{port}')
        server.serve_forever()
 | 
						|
 | 
						|
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        # Swallow Ctrl-C so that stopping the server does not print a traceback.
        pass
 |