84 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import socket
 | 
						|
import struct
 | 
						|
import time
 | 
						|
from typing import *
 | 
						|
from io import BytesIO
 | 
						|
from functools import partial
 | 
						|
from tqdm.auto import tqdm as orig_tqdm
 | 
						|
from config import config
 | 
						|
 | 
						|
# Project-wide progress-bar preset built on tqdm.auto's tqdm: progress is
# labelled in batches (unit='batch'), dynamic_ncols is enabled, and only ASCII
# bar characters are used (ascii=True).
tqdm = partial(
    orig_tqdm,
    unit='batch',
    dynamic_ncols=True,
    ascii=True,
)
 | 
						|
 | 
						|
def parameter_count(module):
    """Return the total number of scalar elements in ``module``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted; frozen ones are skipped.
    """
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
 | 
						|
 | 
						|
def filtered_trimmed_lines(lines):
    """Lazily strip surrounding whitespace from each item of ``lines`` and
    drop the items that are empty after stripping."""
    stripped = (line.strip() for line in lines)
    return (line for line in stripped if line)
 | 
						|
 | 
						|
def iter_grads(parameters, take=False):
    """Yield the ``.grad`` of every parameter in ``parameters`` that has one.

    Parameters whose ``grad`` is ``None`` are skipped.

    Args:
        parameters: iterable of tensors/parameters (e.g. ``module.parameters()``).
        take: if False, yield each ``p.grad`` tensor itself (no copy).
            If True, yield a clone of it and zero ``p.grad`` in place, so the
            caller takes the gradient values and zeros are left behind.

    With ``take=True`` each gradient is zeroed *before* its clone is yielded.
    Every gradient handed out has therefore already been cleared, even if the
    caller stops iterating early.
    """
    for p in parameters:
        grad = p.grad
        if grad is None:
            continue
        if not take:
            yield grad
            continue
        # Zero in place instead of setting p.grad = None: this keeps the
        # existing grad tensor and its layout, which makes it easier to
        # assign values back later. Zero before yielding so that a caller
        # who stops consuming cannot leave this gradient un-cleared.
        taken = grad.clone()
        grad.zero_()
        yield taken
 | 
						|
 | 
						|
def drain():
    """Send ``drain`` requests to the remote server until one reports work.

    Each attempt opens a new connection to
    ``(config['online']['remote']['host'], config['online']['remote']['port'])``,
    sends ``{'type': 'drain'}`` and reads one reply. While the reply's
    ``'count'`` is 0, it waits 5 seconds and tries again. Otherwise it
    returns the reply's ``'drain_dir'``.
    """
    remote_cfg = config['online']['remote']
    address = (remote_cfg['host'], remote_cfg['port'])
    while True:
        with socket.socket() as conn:
            conn.connect(address)
            send_msg(conn, {'type': 'drain'})
            reply = recv_msg(conn)
        if reply['count'] != 0:
            return reply['drain_dir']
        time.sleep(5)
 | 
						|
 | 
						|
def submit_param(mortal, dqn, is_idle=False):
    """Send the ``state_dict()`` of ``mortal`` and ``dqn`` to the remote server.

    Opens a fresh connection to the host/port in ``config['online']['remote']``
    and sends a single ``'submit_param'`` message that also carries
    ``is_idle``. No reply is read. How the server interprets ``is_idle`` is
    not visible here.
    """
    remote_cfg = config['online']['remote']
    address = (remote_cfg['host'], remote_cfg['port'])
    with socket.socket() as conn:
        conn.connect(address)
        message = dict(
            type='submit_param',
            mortal=mortal.state_dict(),
            dqn=dqn.state_dict(),
            is_idle=is_idle,
        )
        send_msg(conn, message)
 | 
						|
 | 
						|
def send_msg(conn: socket.socket, msg, packed=False):
    """Send one length-prefixed message over ``conn``.

    Wire format: an 8-byte little-endian unsigned length (``'<Q'``), then that
    many payload bytes.

    Args:
        conn: connected socket to write to.
        msg: the object to send. When ``packed`` is False it is serialized
            with ``torch.save``. When ``packed`` is True it is sent as-is and
            must already be a bytes-like object.
        packed: whether ``msg`` is already serialized.
    """
    def serialize(obj):
        stream = BytesIO()
        torch.save(obj, stream)
        return stream.getbuffer()

    payload = msg if packed else serialize(msg)
    conn.sendall(struct.pack('<Q', len(payload)))
    conn.sendall(payload)
 | 
						|
 | 
						|
def recv_msg(conn: socket.socket, map_location=torch.device('cpu')):
    """Receive one length-prefixed message from ``conn`` and deserialize it.

    Reads an 8-byte little-endian unsigned length header (``'<Q'``, the
    same framing ``send_msg`` writes), then exactly that many payload bytes.
    The payload is decoded with ``torch.load``.

    Args:
        conn: connected socket to read from.
        map_location: forwarded to ``torch.load``; defaults to the CPU device.

    Returns:
        The deserialized object.
    """
    header = recv_binary(conn, 8)
    size = struct.unpack('<Q', header)[0]
    payload = recv_binary(conn, size)
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, so a malicious peer could execute code in this process.
    # Only use this with trusted endpoints.
    return torch.load(
        BytesIO(payload),
        weights_only=False,  # TODO: weights_only=True
        map_location=map_location,
    )
 | 
						|
 | 
						|
def recv_binary(conn: socket.socket, size):
    """Read exactly ``size`` bytes from ``conn`` and return them as ``bytes``.

    Keeps calling ``recv_into`` on the unfilled tail of a preallocated buffer
    until it is full.

    Raises:
        UnexpectedEOF: if ``recv_into`` returns 0 (the peer closed the
            connection) before ``size`` bytes have arrived.
    """
    assert size > 0
    out = bytearray(size)
    view = memoryview(out)
    received = 0
    while received < size:
        n = conn.recv_into(view[received:])
        if n == 0:
            raise UnexpectedEOF()
        received += n
    return bytes(out)
 | 
						|
 | 
						|
class UnexpectedEOF(Exception):
    """Raised by ``recv_binary`` when the peer closes the connection before
    the requested number of bytes has been received.

    ``msg`` is optional and defaults to the original fixed text, so existing
    ``UnexpectedEOF()`` call sites behave exactly as before. Accepting it also
    lets the exception round-trip through ``pickle`` and ``copy``.
    ``BaseException`` re-creates instances as ``cls(*self.args)``, which
    raised TypeError when ``__init__`` took no arguments.
    """

    def __init__(self, msg='unexpected EOF'):
        super().__init__(msg)
 |