"""Shared helpers: parameter/gradient utilities and a tiny framed socket protocol.

Messages exchanged over the socket are length-prefixed binary blobs; unless
already packed by the caller, the payload is produced with ``torch.save``.
"""

import socket
import struct
import time
from functools import partial
from io import BytesIO
from typing import *

import torch
from tqdm.auto import tqdm as orig_tqdm

from config import config

# Project-wide progress bar preset.
tqdm = partial(orig_tqdm, unit='batch', dynamic_ncols=True, ascii=True)

# Wire format of the length prefix that precedes every payload.
# NOTE(review): the source this file was recovered from was truncated right
# after ``struct.pack('``; the '<' is implied by the truncation point, the 'Q'
# (unsigned 64-bit) width is reconstructed -- confirm against the peer.
_LEN_STRUCT = struct.Struct('<Q')


def parameter_count(module):
    """Return the total number of elements in ``module``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def filtered_trimmed_lines(lines):
    """Return a lazy iterator over the stripped, non-empty entries of ``lines``."""
    return filter(lambda l: l, map(lambda l: l.strip(), lines))


def iter_grads(parameters, take=False):
    """Yield the gradient of every parameter that currently has one.

    Args:
        parameters: iterable of parameters; those whose ``grad`` is ``None``
            are skipped.
        take: when true, yield a clone of each gradient and then zero the
            parameter's gradient in place; when false, yield the live
            ``grad`` tensor itself.
    """
    for p in parameters:
        if p.grad is None:
            continue
        if take:
            # Set to zero instead of None to preserve the layout and make it
            # easier to assign back later
            yield p.grad.clone()
            p.grad.zero_()
        else:
            yield p.grad


def drain():
    """Poll the remote until it reports a non-zero count; return its ``drain_dir``.

    A fresh connection is opened for every attempt. When the reply's
    ``count`` is 0 the function sleeps 5 seconds and retries.
    """
    remote = (config['online']['remote']['host'], config['online']['remote']['port'])
    while True:
        with socket.socket() as conn:
            conn.connect(remote)
            send_msg(conn, {'type': 'drain'})
            msg = recv_msg(conn)
        # The reply is fully read at this point, so the socket is already
        # closed while we wait for the next attempt.
        if msg['count'] == 0:
            time.sleep(5)
            continue
        return msg['drain_dir']


def submit_param(mortal, dqn, is_idle=False):
    """Send the state dicts of ``mortal`` and ``dqn`` to the remote.

    Args:
        mortal: object exposing ``state_dict()``.
        dqn: object exposing ``state_dict()``.
        is_idle: forwarded verbatim in the message under ``'is_idle'``.
    """
    remote = (config['online']['remote']['host'], config['online']['remote']['port'])
    with socket.socket() as conn:
        conn.connect(remote)
        send_msg(conn, {
            'type': 'submit_param',
            'mortal': mortal.state_dict(),
            'dqn': dqn.state_dict(),
            'is_idle': is_idle,
        })


def send_msg(conn: socket.socket, msg, packed=False):
    """Send one length-prefixed message over ``conn``.

    Args:
        conn: connected socket.
        msg: the message. With ``packed=False`` it is serialized via
            ``torch.save``; with ``packed=True`` it is sent as-is and must
            already be a bytes-like object.
        packed: whether ``msg`` is already serialized.
    """
    if packed:
        tx = msg
    else:
        buf = BytesIO()
        torch.save(msg, buf)
        # getbuffer() exposes the serialized bytes without copying them.
        tx = buf.getbuffer()
    # NOTE(review): everything after ``struct.pack('`` was lost in the
    # recovered source; header-then-payload is the reconstructed behavior.
    conn.sendall(_LEN_STRUCT.pack(len(tx)))
    conn.sendall(tx)


def recv_msg(conn: socket.socket, map_location=torch.device('cpu')):
    """Receive one length-prefixed message from ``conn`` and deserialize it.

    NOTE(review): this function was missing from the recovered source (it is
    called by ``drain``) and has been reconstructed as the inverse of
    ``send_msg``; the ``map_location`` parameter is an assumption -- confirm.

    Raises:
        UnexpectedEOF: if the peer closes the connection mid-message.
    """
    (size,) = _LEN_STRUCT.unpack(recv_binary(conn, _LEN_STRUCT.size))
    rx = recv_binary(conn, size)
    # SECURITY: the payload comes from the network and torch.load is
    # pickle-based. weights_only=True restricts unpickling to tensors and
    # plain containers; do not relax it for untrusted peers.
    return torch.load(BytesIO(rx), weights_only=True, map_location=map_location)


def recv_binary(conn: socket.socket, size):
    """Read exactly ``size`` bytes from ``conn`` and return them as ``bytes``.

    NOTE(review): the ``def`` line and the leading assertion were partially
    lost in the recovered source and are reconstructed; the loop below is
    verbatim.

    Raises:
        UnexpectedEOF: if the peer closes the connection before ``size``
            bytes have arrived.
    """
    assert size > 0
    ret = bytearray(size)
    # Shrinking view over the unfilled tail, so recv_into writes in place.
    buf = memoryview(ret)
    while len(buf) > 0:
        n = conn.recv_into(buf)
        if n == 0:
            raise UnexpectedEOF()
        buf = buf[n:]
    return bytes(ret)


class UnexpectedEOF(Exception):
    """Raised when the peer closes the connection before a full read completes."""

    def __init__(self):
        super().__init__('unexpected EOF')