import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from typing import *
from functools import partial
from itertools import permutations
from libriichi.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE

class ChannelAttention(nn.Module):
    # CBAM-style channel attention: a shared MLP is applied to both the
    # average-pooled and max-pooled features along the last dimension, and the
    # summed, sigmoid-gated result rescales each channel.
    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        self.shared_mlp = nn.Sequential(
            nn.Linear(channels, channels // ratio, bias=bias),
            actv_builder(),
            nn.Linear(channels // ratio, channels, bias=bias),
        )
        if bias:
            for mod in self.modules():
                if isinstance(mod, nn.Linear):
                    nn.init.constant_(mod.bias, 0)

    def forward(self, x: Tensor):
        avg_out = self.shared_mlp(x.mean(-1))
        max_out = self.shared_mlp(x.amax(-1))
        weight = (avg_out + max_out).sigmoid()
        x = weight.unsqueeze(-1) * x
        return x

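# Usage sketch (illustrative only, not part of the original module): the sizes
# below are arbitrary; the module expects a (batch, channels, length) input and
# returns a tensor of the same shape with channels rescaled.
def _channel_attention_example():
    ca = ChannelAttention(channels=64, ratio=16)
    x = torch.randn(8, 64, 34)  # e.g. 64 feature planes over the 34 tile kinds
    out = ca(x)
    assert out.shape == x.shape
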
class ResBlock(nn.Module):
    # Residual block over 1D feature maps, with channel attention applied to
    # the residual branch. `norm_builder` and `actv_builder` must be
    # zero-argument callables (e.g. functools.partial) so that each call
    # produces a fresh layer; `pre_actv` switches to pre-activation ordering.
    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        if pre_actv:
            self.res_unit = nn.Sequential(
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
            )
        else:
            self.res_unit = nn.Sequential(
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
            )
            self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        out = self.res_unit(x)
        out = self.ca(out)
        out = out + x
        if not self.pre_actv:
            out = self.actv(out)
        return out

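# Illustrative sketch (not part of the original module): the builders are
# passed as zero-argument factories, typically via functools.partial, so every
# block constructs its own layer instances.
def _res_block_example():
    block = ResBlock(
        channels = 64,
        norm_builder = partial(nn.BatchNorm1d, 64, momentum=0.01),
        actv_builder = partial(nn.ReLU, inplace=True),
        pre_actv = True,
    )
    x = torch.randn(8, 64, 34)
    out = block(x)
    assert out.shape == x.shape
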
class ResNet(nn.Module):
    # Backbone: an input conv, `num_blocks` residual blocks, a 32-channel
    # projection conv, then flatten + linear mapping the (32, 34) feature map
    # to a 1024-dim feature vector.
    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()

        blocks = []
        for _ in range(num_blocks):
            blocks.append(ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            ))

        layers = [nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)]
        if pre_actv:
            layers += [*blocks, norm_builder(), actv_builder()]
        else:
            layers += [norm_builder(), actv_builder(), *blocks]
        layers += [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

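# Illustrative sketch (not part of the original module): the input is expected
# to be (batch, in_channels, 34) and the output is a (batch, 1024) feature
# vector. The channel and block counts below are arbitrary example values.
def _resnet_example():
    net = ResNet(
        in_channels = 10,
        conv_channels = 64,
        num_blocks = 2,
        norm_builder = partial(nn.BatchNorm1d, 64, momentum=0.01),
        actv_builder = partial(nn.ReLU, inplace=True),
        pre_actv = True,
    )
    phi = net(torch.randn(8, 10, 34))
    assert phi.shape == (8, 1024)
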
class Brain(nn.Module):
    # Observation encoder shared by the heads below. Version 1 produces a
    # (mu, logsig) pair over a 512-dim latent; versions 2-4 return the
    # activated 1024-dim encoder features directly.
    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01)
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        match version:
            case 1:
                actv_builder = partial(nn.ReLU, inplace=True)
                pre_actv = False
                self.latent_net = nn.Sequential(
                    nn.Linear(1024, 512),
                    nn.ReLU(inplace=True),
                )
                self.mu_head = nn.Linear(512, 512)
                self.logsig_head = nn.Linear(512, 512)
            case 2:
                pass
            case 3 | 4:
                norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01, eps=1e-3)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()

        # when True, BatchNorm layers are kept in eval mode even during
        # training, so the EMA/CMA running stats are always used
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)
        phi = self.encoder(obs)

        match self.version:
            case 1:
                latent_out = self.latent_net(phi)
                mu = self.mu_head(latent_out)
                logsig = self.logsig_head(latent_out)
                return mu, logsig
            case 2 | 3 | 4:
                return self.actv(phi)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        super().train(mode)
        if self._freeze_bn:
            for mod in self.modules():
                if isinstance(mod, nn.BatchNorm1d):
                    mod.eval()
                    # I don't think this benefits
                    # mod.requires_grad_(False)
        return self

    def reset_running_stats(self):
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm1d):
                mod.reset_running_stats()

    def freeze_bn(self, value: bool):
        self._freeze_bn = value
        return self.train(self.training)

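# Illustrative sketch (not part of the original module): constructing a Brain
# and running a forward pass. `conv_channels` and `num_blocks` below are
# arbitrary example values, not the configuration used in training.
def _brain_example():
    version = 4
    brain = Brain(conv_channels=64, num_blocks=2, version=version).eval()
    in_channels = obs_shape(version)[0]
    obs = torch.randn(2, in_channels, 34)
    with torch.no_grad():
        phi = brain(obs)  # (2, 1024) for versions 2-4
    # version 1 would instead return a (mu, logsig) pair of (batch, 512) tensors
    assert phi.shape == (2, 1024)
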
class AuxNet(nn.Module):
    # Auxiliary prediction heads: a single shared linear layer whose output is
    # split into one chunk per entry of `dims`.
    def __init__(self, dims=None):
        super().__init__()
        self.dims = dims
        self.net = nn.Linear(1024, sum(dims), bias=False)

    def forward(self, x):
        return self.net(x).split(self.dims, dim=-1)

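# Illustrative sketch (not part of the original module): the dims below are
# arbitrary; each head receives its own slice of the shared linear output.
def _aux_net_example():
    aux = AuxNet(dims=(4, 34))
    a, b = aux(torch.randn(2, 1024))
    assert a.shape == (2, 4) and b.shape == (2, 34)
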
class DQN(nn.Module):
    # Dueling Q-value head: Q(s, a) = V(s) + A(s, a) - mean of A(s, a') over
    # the legal actions a'; Q values of illegal actions are set to -inf.
    def __init__(self, *, version=1):
        super().__init__()
        self.version = version
        match version:
            case 1:
                self.v_head = nn.Linear(512, 1)
                self.a_head = nn.Linear(512, ACTION_SPACE)
            case 2 | 3:
                hidden_size = 512 if version == 2 else 256
                self.v_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, 1),
                )
                self.a_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, ACTION_SPACE),
                )
            case 4:
                self.net = nn.Linear(1024, 1 + ACTION_SPACE)
                nn.init.constant_(self.net.bias, 0)

    def forward(self, phi, mask):
        if self.version == 4:
            v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1)
        else:
            v = self.v_head(phi)
            a = self.a_head(phi)
        a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True)
        mask_sum = mask.sum(-1, keepdim=True)
        a_mean = a_sum / mask_sum
        q = (v + a - a_mean).masked_fill(~mask, -torch.inf)
        return q

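# Illustrative sketch (not part of the original module): the boolean mask marks
# legal actions; only those contribute to the advantage mean, and illegal
# actions come out as -inf so a downstream argmax/softmax can never pick them.
def _dqn_example():
    dqn = DQN(version=4)
    phi = torch.randn(2, 1024)
    mask = torch.zeros(2, ACTION_SPACE, dtype=torch.bool)
    mask[:, :3] = True  # pretend only the first three actions are legal
    q = dqn(phi, mask)
    assert q.shape == (2, ACTION_SPACE)
    assert torch.isinf(q[:, 3:]).all()
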
class GRP(nn.Module):
    # Game result predictor: a GRU over the per-round state sequence whose
    # final hidden states are mapped to logits over the 24 = 4! possible final
    # rank orderings of the four players.
    def __init__(self, hidden_size=64, num_layers=2):
        super().__init__()
        self.rnn = nn.GRU(input_size=GRP_SIZE, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * num_layers, hidden_size * num_layers),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size * num_layers, 24),
        )
        for mod in self.modules():
            mod.to(torch.float64)

        # perms are the permutations of all possible rank-by-player results
        perms = torch.tensor(list(permutations(range(4))))
        perms_t = perms.transpose(0, 1)
        self.register_buffer('perms', perms)     # (24, 4)
        self.register_buffer('perms_t', perms_t) # (4, 24)

    # input: [grand_kyoku, honba, kyotaku, s[0], s[1], s[2], s[3]]
    # grand_kyoku: E1 = 0, S4 = 7, W4 = 11
    # s is 2.5 at E1
    # s[0] is the score of player id 0
    def forward(self, inputs: List[Tensor]):
        lengths = torch.tensor([t.shape[0] for t in inputs], dtype=torch.int64)
        inputs = pad_sequence(inputs, batch_first=True)
        packed_inputs = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
        return self.forward_packed(packed_inputs)

    def forward_packed(self, packed_inputs):
        _, state = self.rnn(packed_inputs)
        state = state.transpose(0, 1).flatten(1)
        logits = self.fc(state)
        return logits

    # (N, 24) -> (N, player, rank_prob)
    def calc_matrix(self, logits: Tensor):
        batch_size = logits.shape[0]
        probs = logits.softmax(-1)
        matrix = torch.zeros(batch_size, 4, 4, dtype=probs.dtype)
        for player in range(4):
            for rank in range(4):
                cond = self.perms_t[player] == rank
                matrix[:, player, rank] = probs[:, cond].sum(-1)
        return matrix

    # (N, 4) -> (N)
    def get_label(self, rank_by_player: Tensor):
        batch_size = rank_by_player.shape[0]
        perms = self.perms.expand(batch_size, -1, -1).transpose(0, 1)
        mappings = (perms == rank_by_player).all(-1).nonzero()

        labels = torch.zeros(batch_size, dtype=torch.int64, device=mappings.device)
        labels[mappings[:, 1]] = mappings[:, 0]
        return labels

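# Illustrative sketch (not part of the original module): the sequence lengths
# and feature values below are random; it shows the expected shapes and how
# get_label maps a rank-by-player vector to an index into the 24 permutations.
def _grp_example():
    grp = GRP(hidden_size=64, num_layers=2)
    inputs = [
        torch.randn(5, GRP_SIZE, dtype=torch.float64),  # a game with 5 rounds
        torch.randn(3, GRP_SIZE, dtype=torch.float64),  # a game with 3 rounds
    ]
    logits = grp(inputs)              # (2, 24)
    matrix = grp.calc_matrix(logits)  # (2, 4, 4); each player's rank probs sum to 1
    assert torch.allclose(matrix.sum(-1), torch.ones(2, 4, dtype=matrix.dtype))
    # e.g. [0, 1, 2, 3] means player 0 took rank 0, player 1 rank 1, and so on
    labels = grp.get_label(torch.tensor([[0, 1, 2, 3], [3, 2, 1, 0]]))
    assert labels.shape == (2,)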