def train():
    import prelude

    import logging
    import sys
    import os
    import gc
    import gzip
    import json
    import shutil
    import random
    import torch
    from os import path
    from glob import glob
    from datetime import datetime
    from itertools import chain
    from torch import optim, nn
    from torch.amp import GradScaler
    from torch.nn.utils import clip_grad_norm_
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    from common import submit_param, parameter_count, drain, filtered_trimmed_lines, tqdm
    from player import TestPlayer
    from dataloader import FileDatasetsIter, worker_init_fn
    from lr_scheduler import LinearWarmUpCosineAnnealingLR
    from model import Brain, DQN, AuxNet
    from libriichi.consts import obs_shape
    from config import config

    version = config['control']['version']

    online = config['control']['online']
    batch_size = config['control']['batch_size']
    opt_step_every = config['control']['opt_step_every']
    save_every = config['control']['save_every']
    test_every = config['control']['test_every']
    submit_every = config['control']['submit_every']
    test_games = config['test_play']['games']
    min_q_weight = config['cql']['min_q_weight']
    next_rank_weight = config['aux']['next_rank_weight']
    assert save_every % opt_step_every == 0
    assert test_every % save_every == 0

    device = torch.device(config['control']['device'])
    torch.backends.cudnn.benchmark = config['control']['enable_cudnn_benchmark']
    enable_amp = config['control']['enable_amp']
    enable_compile = config['control']['enable_compile']

    pts = config['env']['pts']
    gamma = config['env']['gamma']
    file_batch_size = config['dataset']['file_batch_size']
    reserve_ratio = config['dataset']['reserve_ratio']
    num_workers = config['dataset']['num_workers']
    num_epochs = config['dataset']['num_epochs']
    enable_augmentation = config['dataset']['enable_augmentation']
    augmented_first = config['dataset']['augmented_first']
    eps = config['optim']['eps']
    betas = config['optim']['betas']
    weight_decay = config['optim']['weight_decay']
    max_grad_norm = config['optim']['max_grad_norm']

    mortal = Brain(version=version, **config['resnet']).to(device)
    dqn = DQN(version=version).to(device)
    aux_net = AuxNet((4,)).to(device)
    all_models = (mortal, dqn, aux_net)
    if enable_compile:
        for m in all_models:
            m.compile()

    logging.info(f'version: {version}')
    logging.info(f'obs shape: {obs_shape(version)}')
    logging.info(f'mortal params: {parameter_count(mortal):,}')
    logging.info(f'dqn params: {parameter_count(dqn):,}')
    logging.info(f'aux params: {parameter_count(aux_net):,}')

    mortal.freeze_bn(config['freeze_bn']['mortal'])

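    # Group parameters for AdamW: only the weights of Linear and Conv1d modules
    # go into the weight-decayed group; biases, norm parameters and everything
    # else stay undecayed. The optimizer is built with lr=1 and weight_decay=0,
    # so the per-group weight_decay above and the LinearWarmUpCosineAnnealingLR
    # scheduler effectively determine the actual values.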
    decay_params = []
    no_decay_params = []
    for model in all_models:
        params_dict = {}
        to_decay = set()
        for mod_name, mod in model.named_modules():
            for name, param in mod.named_parameters(prefix=mod_name, recurse=False):
                params_dict[name] = param
                if isinstance(mod, (nn.Linear, nn.Conv1d)) and name.endswith('weight'):
                    to_decay.add(name)
        decay_params.extend(params_dict[name] for name in sorted(to_decay))
        no_decay_params.extend(params_dict[name] for name in sorted(params_dict.keys() - to_decay))
    param_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params},
    ]
    optimizer = optim.AdamW(param_groups, lr=1, weight_decay=0, betas=betas, eps=eps)
    scheduler = LinearWarmUpCosineAnnealingLR(optimizer, **config['optim']['scheduler'])
    scaler = GradScaler(device.type, enabled=enable_amp)
    test_player = TestPlayer()
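    # Best test-play performance so far, initialized to the worst possible
    # values (average rank 4, average pt -135) so that the first evaluated
    # checkpoint always counts as a new best.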
    best_perf = {
        'avg_rank': 4.,
        'avg_pt': -135.,
    }

    steps = 0
    state_file = config['control']['state_file']
    best_state_file = config['control']['best_state_file']
    if path.exists(state_file):
        state = torch.load(state_file, weights_only=True, map_location=device)
        timestamp = datetime.fromtimestamp(state['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
        logging.info(f'loaded: {timestamp}')
        mortal.load_state_dict(state['mortal'])
        dqn.load_state_dict(state['current_dqn'])
        aux_net.load_state_dict(state['aux_net'])
        if not online or state['config']['control']['online']:
            optimizer.load_state_dict(state['optimizer'])
            scheduler.load_state_dict(state['scheduler'])
        scaler.load_state_dict(state['scaler'])
        best_perf = state['best_perf']
        steps = state['steps']

    optimizer.zero_grad(set_to_none=True)
    mse = nn.MSELoss()
    ce = nn.CrossEntropyLoss()

    if device.type == 'cuda':
        logging.info(f'device: {device} ({torch.cuda.get_device_name(device)})')
    else:
        logging.info(f'device: {device}')

    if online:
        submit_param(mortal, dqn, is_idle=True)
        logging.info('param has been submitted')

    writer = SummaryWriter(config['control']['tensorboard_dir'])
    stats = {
        'dqn_loss': 0,
        'cql_loss': 0,
        'next_rank_loss': 0,
    }
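    # Per-step predicted Q values and their Monte-Carlo targets over the current
    # save_every window; logged as (downsampled) TensorBoard histograms every
    # save_every steps.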
    all_q = torch.zeros((save_every, batch_size), device=device, dtype=torch.float32)
    all_q_target = torch.zeros((save_every, batch_size), device=device, dtype=torch.float32)
    idx = 0

    def train_epoch():
        nonlocal steps
        nonlocal idx

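        # Online mode trains only on the 'trainee' player's perspectives, reading
        # the freshly drained game files; offline mode instead reads player name
        # lists and a (cached) file index built from the configured globs.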
        player_names = []
        if online:
            player_names = ['trainee']
            dirname = drain()
            file_list = list(map(lambda p: path.join(dirname, p), os.listdir(dirname)))
        else:
            player_names_set = set()
            for filename in config['dataset']['player_names_files']:
                with open(filename) as f:
                    player_names_set.update(filtered_trimmed_lines(f))
            player_names = list(player_names_set)
            logging.info(f'loaded {len(player_names):,} players')

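            # Build the file index on first run and cache it; when player name
            # lists are given, keep only files whose game involves at least one
            # of those players.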
            file_index = config['dataset']['file_index']
            if path.exists(file_index):
                index = torch.load(file_index, weights_only=True)
                file_list = index['file_list']
            else:
                logging.info('building file index...')
                file_list = []
                for pat in config['dataset']['globs']:
                    file_list.extend(glob(pat, recursive=True))
                if len(player_names_set) > 0:
                    filtered = []
                    for filename in tqdm(file_list, unit='file'):
                        with gzip.open(filename, 'rt') as f:
                            start = json.loads(next(f))
                            if not set(start['names']).isdisjoint(player_names_set):
                                filtered.append(filename)
                    file_list = filtered
                file_list.sort(reverse=True)
                torch.save({'file_list': file_list}, file_index)
        logging.info(f'file list size: {len(file_list):,}')

        before_next_test_play = (test_every - steps % test_every) % test_every
        logging.info(f'total steps: {steps:,} (~{before_next_test_play:,})')

        if num_workers > 1:
            random.shuffle(file_list)
        file_data = FileDatasetsIter(
            version = version,
            file_list = file_list,
            pts = pts,
            file_batch_size = file_batch_size,
            reserve_ratio = reserve_ratio,
            player_names = player_names,
            num_epochs = num_epochs,
            enable_augmentation = enable_augmentation,
            augmented_first = augmented_first,
        )
        data_loader = iter(DataLoader(
            dataset = file_data,
            batch_size = batch_size,
            drop_last = False,
            num_workers = num_workers,
            pin_memory = True,
            worker_init_fn = worker_init_fn,
        ))

        remaining_obs = []
        remaining_actions = []
        remaining_masks = []
        remaining_steps_to_done = []
        remaining_kyoku_rewards = []
        remaining_player_ranks = []
        remaining_bs = 0
        pb = tqdm(total=save_every, desc='TRAIN', initial=steps % save_every)

        def train_batch(obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks):
            nonlocal steps
            nonlocal idx
            nonlocal pb

            obs = obs.to(dtype=torch.float32, device=device)
            actions = actions.to(dtype=torch.int64, device=device)
            masks = masks.to(dtype=torch.bool, device=device)
            steps_to_done = steps_to_done.to(dtype=torch.int64, device=device)
            kyoku_rewards = kyoku_rewards.to(dtype=torch.float64, device=device)
            player_ranks = player_ranks.to(dtype=torch.int64, device=device)
            assert masks[range(batch_size), actions].all()

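            # Monte-Carlo Q target: the kyoku reward discounted by the number of
            # steps remaining until the episode is done.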
            q_target_mc = gamma ** steps_to_done * kyoku_rewards
            q_target_mc = q_target_mc.to(torch.float32)

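            # Forward pass under autocast. The total loss is the DQN regression
            # loss towards the Monte-Carlo target, plus a CQL conservative penalty
            # (offline only) weighted by min_q_weight, plus the auxiliary next-rank
            # prediction loss weighted by next_rank_weight.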
            with torch.autocast(device.type, enabled=enable_amp):
                phi = mortal(obs)
                q_out = dqn(phi, masks)
                q = q_out[range(batch_size), actions]
                dqn_loss = 0.5 * mse(q, q_target_mc)
                cql_loss = 0
                if not online:
                    cql_loss = q_out.logsumexp(-1).mean() - q.mean()

                next_rank_logits, = aux_net(phi)
                next_rank_loss = ce(next_rank_logits, player_ranks)

                loss = sum((
                    dqn_loss,
                    cql_loss * min_q_weight,
                    next_rank_loss * next_rank_weight,
                ))
            scaler.scale(loss / opt_step_every).backward()

            with torch.inference_mode():
                stats['dqn_loss'] += dqn_loss
                if not online:
                    stats['cql_loss'] += cql_loss
                stats['next_rank_loss'] += next_rank_loss
                all_q[idx] = q
                all_q_target[idx] = q_target_mc

            steps += 1
            idx += 1
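            # Gradient accumulation: the loss above was scaled by 1 / opt_step_every,
            # and the optimizer only steps once every opt_step_every batches, with
            # optional gradient clipping after unscaling.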
            if idx % opt_step_every == 0:
                if max_grad_norm > 0:
                    scaler.unscale_(optimizer)
                    params = chain.from_iterable(g['params'] for g in optimizer.param_groups)
                    clip_grad_norm_(params, max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)
            scheduler.step()
            pb.update(1)

            if online and steps % submit_every == 0:
                submit_param(mortal, dqn, is_idle=False)
                logging.info('param has been submitted')

            if steps % save_every == 0:
                pb.close()

                # downsample to reduce tensorboard event size
                all_q_1d = all_q.cpu().numpy().flatten()[::128]
                all_q_target_1d = all_q_target.cpu().numpy().flatten()[::128]

                writer.add_scalar('loss/dqn_loss', stats['dqn_loss'] / save_every, steps)
                if not online:
                    writer.add_scalar('loss/cql_loss', stats['cql_loss'] / save_every, steps)
                writer.add_scalar('loss/next_rank_loss', stats['next_rank_loss'] / save_every, steps)
                writer.add_scalar('hparam/lr', scheduler.get_last_lr()[0], steps)
                writer.add_histogram('q_predicted', all_q_1d, steps)
                writer.add_histogram('q_target', all_q_target_1d, steps)
                writer.flush()

                for k in stats:
                    stats[k] = 0
                idx = 0

                before_next_test_play = (test_every - steps % test_every) % test_every
                logging.info(f'total steps: {steps:,} (~{before_next_test_play:,})')

                state = {
                    'mortal': mortal.state_dict(),
                    'current_dqn': dqn.state_dict(),
                    'aux_net': aux_net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'scaler': scaler.state_dict(),
                    'steps': steps,
                    'timestamp': datetime.now().timestamp(),
                    'best_perf': best_perf,
                    'config': config,
                }
                torch.save(state, state_file)

                if online and steps % submit_every != 0:
                    submit_param(mortal, dqn, is_idle=False)
                    logging.info('param has been submitted')

                if steps % test_every == 0:
                    stat = test_player.test_play(test_games // 4, mortal, dqn, device)
                    mortal.train()
                    dqn.train()

                    avg_pt = stat.avg_pt([90, 45, 0, -135]) # for display only, never used in training
                    better = avg_pt >= best_perf['avg_pt'] and stat.avg_rank <= best_perf['avg_rank']
                    if better:
                        past_best = best_perf.copy()
                        best_perf['avg_pt'] = avg_pt
                        best_perf['avg_rank'] = stat.avg_rank

                    logging.info(f'avg rank: {stat.avg_rank:.6}')
                    logging.info(f'avg pt: {avg_pt:.6}')
                    writer.add_scalar('test_play/avg_ranking', stat.avg_rank, steps)
                    writer.add_scalar('test_play/avg_pt', avg_pt, steps)
                    writer.add_scalars('test_play/ranking', {
                        '1st': stat.rank_1_rate,
                        '2nd': stat.rank_2_rate,
                        '3rd': stat.rank_3_rate,
                        '4th': stat.rank_4_rate,
                    }, steps)
                    writer.add_scalars('test_play/behavior', {
                        'agari': stat.agari_rate,
                        'houjuu': stat.houjuu_rate,
                        'fuuro': stat.fuuro_rate,
                        'riichi': stat.riichi_rate,
                    }, steps)
                    writer.add_scalars('test_play/agari_point', {
                        'overall': stat.avg_point_per_agari,
                        'riichi': stat.avg_point_per_riichi_agari,
                        'fuuro': stat.avg_point_per_fuuro_agari,
                        'dama': stat.avg_point_per_dama_agari,
                    }, steps)
                    writer.add_scalar('test_play/houjuu_point', stat.avg_point_per_houjuu, steps)
                    writer.add_scalar('test_play/point_per_round', stat.avg_point_per_round, steps)
                    writer.add_scalars('test_play/key_step', {
                        'agari_jun': stat.avg_agari_jun,
                        'houjuu_jun': stat.avg_houjuu_jun,
                        'riichi_jun': stat.avg_riichi_jun,
                    }, steps)
                    writer.add_scalars('test_play/riichi', {
                        'agari_after_riichi': stat.agari_rate_after_riichi,
                        'houjuu_after_riichi': stat.houjuu_rate_after_riichi,
                        'chasing_riichi': stat.chasing_riichi_rate,
                        'riichi_chased': stat.riichi_chased_rate,
                    }, steps)
                    writer.add_scalar('test_play/riichi_point', stat.avg_riichi_point, steps)
                    writer.add_scalars('test_play/fuuro', {
                        'agari_after_fuuro': stat.agari_rate_after_fuuro,
                        'houjuu_after_fuuro': stat.houjuu_rate_after_fuuro,
                    }, steps)
                    writer.add_scalar('test_play/fuuro_num', stat.avg_fuuro_num, steps)
                    writer.add_scalar('test_play/fuuro_point', stat.avg_fuuro_point, steps)
                    writer.flush()

                    if better:
                        torch.save(state, state_file)
                        logging.info(
                            'a new record has been made, '
                            f'pt: {past_best["avg_pt"]:.4} -> {best_perf["avg_pt"]:.4}, '
                            f'rank: {past_best["avg_rank"]:.4} -> {best_perf["avg_rank"]:.4}, '
                            f'saving to {best_state_file}'
                        )
                        shutil.copy(state_file, best_state_file)
                    if online:
                        # BUG: This is a bug with unknown reason. When training
                        # in online mode, the process will get stuck here. This
                        # is the reason why `main` spawns a sub process to train
                        # in online mode instead of going for training directly.
                        sys.exit(0)
                pb = tqdm(total=save_every, desc='TRAIN')

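        # Batches smaller than batch_size (e.g. tail batches from the workers) are
        # buffered, concatenated afterwards and re-chunked, so train_batch always
        # receives exactly batch_size samples.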
        for obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks in data_loader:
            bs = obs.shape[0]
            if bs != batch_size:
                remaining_obs.append(obs)
                remaining_actions.append(actions)
                remaining_masks.append(masks)
                remaining_steps_to_done.append(steps_to_done)
                remaining_kyoku_rewards.append(kyoku_rewards)
                remaining_player_ranks.append(player_ranks)
                remaining_bs += bs
                continue
            train_batch(obs, actions, masks, steps_to_done, kyoku_rewards, player_ranks)

        remaining_batches = remaining_bs // batch_size
        if remaining_batches > 0:
            obs = torch.cat(remaining_obs, dim=0)
            actions = torch.cat(remaining_actions, dim=0)
            masks = torch.cat(remaining_masks, dim=0)
            steps_to_done = torch.cat(remaining_steps_to_done, dim=0)
            kyoku_rewards = torch.cat(remaining_kyoku_rewards, dim=0)
            player_ranks = torch.cat(remaining_player_ranks, dim=0)
            start = 0
            end = batch_size
            while end <= remaining_bs:
                train_batch(
                    obs[start:end],
                    actions[start:end],
                    masks[start:end],
                    steps_to_done[start:end],
                    kyoku_rewards[start:end],
                    player_ranks[start:end],
                )
                start = end
                end += batch_size
        pb.close()

        if online:
            submit_param(mortal, dqn, is_idle=True)
            logging.info('param has been submitted')

    while True:
        train_epoch()
        gc.collect()
        # torch.cuda.empty_cache()
        # torch.cuda.synchronize()
        if not online:
            # only run one epoch for offline for easier control
            break

def main():
    import os
    import sys
    import time
    from subprocess import Popen
    from config import config

    # do not set this env manually
    is_sub_proc_key = 'MORTAL_IS_SUB_PROC'
    online = config['control']['online']
    if not online or os.environ.get(is_sub_proc_key, '0') == '1':
        train()
        return

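    # Online mode: spawn this same script as a marked child process and restart
    # it every time it exits cleanly. This works around the hang noted in
    # train() when training online in-process; a non-zero exit code from the
    # child is propagated instead of restarting.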
    cmd = (sys.executable, __file__)
    env = {
        is_sub_proc_key: '1',
        **os.environ.copy(),
    }
    while True:
        child = Popen(
            cmd,
            stdin = sys.stdin,
            stdout = sys.stdout,
            stderr = sys.stderr,
            env = env,
        )
        if (code := child.wait()) != 0:
            sys.exit(code)
        time.sleep(3)

if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        pass