import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from typing import *
from functools import partial
from itertools import permutations
from libriichi.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE

class ChannelAttention(nn.Module):
    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        self.shared_mlp = nn.Sequential(
            nn.Linear(channels, channels // ratio, bias=bias),
            actv_builder(),
            nn.Linear(channels // ratio, channels, bias=bias),
        )
        if bias:
            for mod in self.modules():
                if isinstance(mod, nn.Linear):
                    nn.init.constant_(mod.bias, 0)

    def forward(self, x: Tensor):
        avg_out = self.shared_mlp(x.mean(-1))
        max_out = self.shared_mlp(x.amax(-1))
        weight = (avg_out + max_out).sigmoid()
        x = weight.unsqueeze(-1) * x
        return x

class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        if pre_actv:
            self.res_unit = nn.Sequential(
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
            )
        else:
            self.res_unit = nn.Sequential(
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
            )
            self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        out = self.res_unit(x)
        out = self.ca(out)
        out = out + x
        if not self.pre_actv:
            out = self.actv(out)
        return out

class ResNet(nn.Module):
    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()

        blocks = []
        for _ in range(num_blocks):
            blocks.append(ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            ))

        layers = [nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)]
        if pre_actv:
            layers += [*blocks, norm_builder(), actv_builder()]
        else:
            layers += [norm_builder(), actv_builder(), *blocks]
        layers += [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
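# Shape sketch for the encoder above (the channel and block counts here are
# arbitrary example values, not a training configuration). Observations are
# laid out as (batch, channels, 34 tile columns), which is what the final
# Linear(32 * 34, 1024) layer implies, so the encoder maps
# (N, in_channels, 34) -> (N, 1024):
#
#     net = ResNet(in_channels=8, conv_channels=16, num_blocks=2)
#     x = torch.rand(2, 8, 34)
#     assert net(x).shape == (2, 1024)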
class Brain(nn.Module):
    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01)
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        match version:
            case 1:
                actv_builder = partial(nn.ReLU, inplace=True)
                pre_actv = False
                self.latent_net = nn.Sequential(
                    nn.Linear(1024, 512),
                    nn.ReLU(inplace=True),
                )
                self.mu_head = nn.Linear(512, 512)
                self.logsig_head = nn.Linear(512, 512)
            case 2:
                pass
            case 3 | 4:
                norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01, eps=1e-3)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()

        # always use EMA or CMA when True
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)
        phi = self.encoder(obs)

        match self.version:
            case 1:
                latent_out = self.latent_net(phi)
                mu = self.mu_head(latent_out)
                logsig = self.logsig_head(latent_out)
                return mu, logsig
            case 2 | 3 | 4:
                return self.actv(phi)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        super().train(mode)
        if self._freeze_bn:
            for mod in self.modules():
                if isinstance(mod, nn.BatchNorm1d):
                    mod.eval()
                    # I don't think this benefits
                    # module.requires_grad_(False)
        return self

    def reset_running_stats(self):
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm1d):
                mod.reset_running_stats()

    def freeze_bn(self, value: bool):
        self._freeze_bn = value
        return self.train(self.training)

class AuxNet(nn.Module):
    def __init__(self, dims=None):
        super().__init__()
        self.dims = dims
        self.net = nn.Linear(1024, sum(dims), bias=False)

    def forward(self, x):
        return self.net(x).split(self.dims, dim=-1)

class DQN(nn.Module):
    def __init__(self, *, version=1):
        super().__init__()
        self.version = version
        match version:
            case 1:
                self.v_head = nn.Linear(512, 1)
                self.a_head = nn.Linear(512, ACTION_SPACE)
            case 2 | 3:
                hidden_size = 512 if version == 2 else 256
                self.v_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, 1),
                )
                self.a_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, ACTION_SPACE),
                )
            case 4:
                self.net = nn.Linear(1024, 1 + ACTION_SPACE)
                nn.init.constant_(self.net.bias, 0)

    def forward(self, phi, mask):
        if self.version == 4:
            v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1)
        else:
            v = self.v_head(phi)
            a = self.a_head(phi)

        a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True)
        mask_sum = mask.sum(-1, keepdim=True)
        a_mean = a_sum / mask_sum
        q = (v + a - a_mean).masked_fill(~mask, -torch.inf)
        return q
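# The DQN head above is a dueling architecture restricted to legal actions:
# the advantage is centered on its mean over legal actions only, and illegal
# actions are pushed to -inf so they can never be selected. A tiny worked
# example with hypothetical numbers (3 actions, last one illegal):
#
#     v = 1.0, a = [2.0, 4.0, 6.0], mask = [True, True, False]
#     a_mean = (2.0 + 4.0) / 2 = 3.0
#     q = [1 + 2 - 3, 1 + 4 - 3, -inf] = [0.0, 2.0, -inf]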
class GRP(nn.Module):
    def __init__(self, hidden_size=64, num_layers=2):
        super().__init__()
        self.rnn = nn.GRU(input_size=GRP_SIZE, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * num_layers, hidden_size * num_layers),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size * num_layers, 24),
        )
        for mod in self.modules():
            mod.to(torch.float64)

        # perms are the permutations of all possible rank-by-player results
        perms = torch.tensor(list(permutations(range(4))))
        perms_t = perms.transpose(0, 1)
        self.register_buffer('perms', perms)      # (24, 4)
        self.register_buffer('perms_t', perms_t)  # (4, 24)

    # input: [grand_kyoku, honba, kyotaku, s[0], s[1], s[2], s[3]]
    # grand_kyoku: E1 = 0, S4 = 7, W4 = 11
    # s is 2.5 at E1
    # s[0] is the score of player id 0
    def forward(self, inputs: List[Tensor]):
        lengths = torch.tensor([t.shape[0] for t in inputs], dtype=torch.int64)
        inputs = pad_sequence(inputs, batch_first=True)
        packed_inputs = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
        return self.forward_packed(packed_inputs)

    def forward_packed(self, packed_inputs):
        _, state = self.rnn(packed_inputs)
        state = state.transpose(0, 1).flatten(1)
        logits = self.fc(state)
        return logits

    # (N, 24) -> (N, player, rank_prob)
    def calc_matrix(self, logits: Tensor):
        batch_size = logits.shape[0]
        probs = logits.softmax(-1)
        matrix = torch.zeros(batch_size, 4, 4, dtype=probs.dtype)
        for player in range(4):
            for rank in range(4):
                cond = self.perms_t[player] == rank
                matrix[:, player, rank] = probs[:, cond].sum(-1)
        return matrix

    # (N, 4) -> (N)
    def get_label(self, rank_by_player: Tensor):
        batch_size = rank_by_player.shape[0]
        perms = self.perms.expand(batch_size, -1, -1).transpose(0, 1)
        mappings = (perms == rank_by_player).all(-1).nonzero()
        labels = torch.zeros(batch_size, dtype=torch.int64, device=mappings.device)
        labels[mappings[:, 1]] = mappings[:, 0]
        return labels
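# Minimal smoke-test sketch of how the modules above compose. The batch size,
# conv_channels, num_blocks and the legal-action mask are arbitrary example
# values, and it assumes obs_shape(version) yields the per-sample
# (channels, 34) observation shape the encoder expects.
if __name__ == '__main__':
    version = 4
    brain = Brain(conv_channels=64, num_blocks=2, version=version).eval()
    dqn = DQN(version=version)

    obs = torch.rand(2, *obs_shape(version))
    mask = torch.zeros(2, ACTION_SPACE, dtype=torch.bool)
    mask[:, :3] = True  # pretend only the first three actions are legal
    with torch.no_grad():
        phi = brain(obs)    # (2, 1024)
        q = dqn(phi, mask)  # (2, ACTION_SPACE), illegal actions are -inf
    print(q.shape)

    grp = GRP()
    seqs = [
        torch.rand(5, GRP_SIZE, dtype=torch.float64),  # a game with 5 observed states
        torch.rand(3, GRP_SIZE, dtype=torch.float64),  # another with 3
    ]
    with torch.no_grad():
        logits = grp(seqs)                # (2, 24)
        matrix = grp.calc_matrix(logits)  # (2, 4, 4) per-player rank probabilities
    print(matrix.shape)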