Mortal

Commit b7a7d7404a by e2hang, 2025-10-07 20:30:03 +08:00
441 changed files with 23367 additions and 0 deletions

mortal/model.py (normal file, 286 lines)

@@ -0,0 +1,286 @@
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence
from typing import *
from functools import partial
from itertools import permutations
from libriichi.consts import obs_shape, oracle_obs_shape, ACTION_SPACE, GRP_SIZE

# Channel attention: re-weight each channel with a sigmoid gate computed from
# its average- and max-pooled activations passed through a shared MLP.
class ChannelAttention(nn.Module):
    def __init__(self, channels, ratio=16, actv_builder=nn.ReLU, bias=True):
        super().__init__()
        self.shared_mlp = nn.Sequential(
            nn.Linear(channels, channels // ratio, bias=bias),
            actv_builder(),
            nn.Linear(channels // ratio, channels, bias=bias),
        )
        if bias:
            for mod in self.modules():
                if isinstance(mod, nn.Linear):
                    nn.init.constant_(mod.bias, 0)

    def forward(self, x: Tensor):
        avg_out = self.shared_mlp(x.mean(-1))
        max_out = self.shared_mlp(x.amax(-1))
        weight = (avg_out + max_out).sigmoid()
        x = weight.unsqueeze(-1) * x
        return x

class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()
        self.pre_actv = pre_actv

        if pre_actv:
            self.res_unit = nn.Sequential(
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
            )
        else:
            self.res_unit = nn.Sequential(
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
                actv_builder(),
                nn.Conv1d(channels, channels, kernel_size=3, padding=1, bias=False),
                norm_builder(),
            )
            self.actv = actv_builder()
        self.ca = ChannelAttention(channels, actv_builder=actv_builder, bias=True)

    def forward(self, x):
        out = self.res_unit(x)
        out = self.ca(out)
        out = out + x
        if not self.pre_actv:
            out = self.actv(out)
        return out

class ResNet(nn.Module):
    def __init__(
        self,
        in_channels,
        conv_channels,
        num_blocks,
        *,
        norm_builder = nn.Identity,
        actv_builder = nn.ReLU,
        pre_actv = False,
    ):
        super().__init__()

        blocks = []
        for _ in range(num_blocks):
            blocks.append(ResBlock(
                conv_channels,
                norm_builder = norm_builder,
                actv_builder = actv_builder,
                pre_actv = pre_actv,
            ))

        layers = [nn.Conv1d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False)]
        if pre_actv:
            layers += [*blocks, norm_builder(), actv_builder()]
        else:
            layers += [norm_builder(), actv_builder(), *blocks]
        layers += [
            nn.Conv1d(conv_channels, 32, kernel_size=3, padding=1),
            actv_builder(),
            nn.Flatten(),
            nn.Linear(32 * 34, 1024),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
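
# Brain wraps the ResNet encoder. Version 1 maps the 1024-d encoding to a
# stochastic 512-d latent and returns (mu, logsig); versions 2, 3 and 4 simply
# return the activated 1024-d feature vector.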
class Brain(nn.Module):
    def __init__(self, *, conv_channels, num_blocks, is_oracle=False, version=1):
        super().__init__()
        self.is_oracle = is_oracle
        self.version = version

        in_channels = obs_shape(version)[0]
        if is_oracle:
            in_channels += oracle_obs_shape(version)[0]

        norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01)
        actv_builder = partial(nn.Mish, inplace=True)
        pre_actv = True

        match version:
            case 1:
                actv_builder = partial(nn.ReLU, inplace=True)
                pre_actv = False
                self.latent_net = nn.Sequential(
                    nn.Linear(1024, 512),
                    nn.ReLU(inplace=True),
                )
                self.mu_head = nn.Linear(512, 512)
                self.logsig_head = nn.Linear(512, 512)
            case 2:
                pass
            case 3 | 4:
                norm_builder = partial(nn.BatchNorm1d, conv_channels, momentum=0.01, eps=1e-3)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

        self.encoder = ResNet(
            in_channels = in_channels,
            conv_channels = conv_channels,
            num_blocks = num_blocks,
            norm_builder = norm_builder,
            actv_builder = actv_builder,
            pre_actv = pre_actv,
        )
        self.actv = actv_builder()

        # when True, BatchNorm layers always use their running (EMA or CMA) statistics
        self._freeze_bn = False

    def forward(self, obs: Tensor, invisible_obs: Optional[Tensor] = None) -> Union[Tuple[Tensor, Tensor], Tensor]:
        if self.is_oracle:
            assert invisible_obs is not None
            obs = torch.cat((obs, invisible_obs), dim=1)

        phi = self.encoder(obs)

        match self.version:
            case 1:
                latent_out = self.latent_net(phi)
                mu = self.mu_head(latent_out)
                logsig = self.logsig_head(latent_out)
                return mu, logsig
            case 2 | 3 | 4:
                return self.actv(phi)
            case _:
                raise ValueError(f'Unexpected version {self.version}')

    def train(self, mode=True):
        super().train(mode)
        if self._freeze_bn:
            for mod in self.modules():
                if isinstance(mod, nn.BatchNorm1d):
                    mod.eval()
                    # I don't think this benefits
                    # mod.requires_grad_(False)
        return self

    def reset_running_stats(self):
        for mod in self.modules():
            if isinstance(mod, nn.BatchNorm1d):
                mod.reset_running_stats()

    def freeze_bn(self, value: bool):
        self._freeze_bn = value
        return self.train(self.training)

class AuxNet(nn.Module):
    def __init__(self, dims=None):
        super().__init__()
        self.dims = dims
        self.net = nn.Linear(1024, sum(dims), bias=False)

    def forward(self, x):
        return self.net(x).split(self.dims, dim=-1)

class DQN(nn.Module):
    def __init__(self, *, version=1):
        super().__init__()
        self.version = version
        match version:
            case 1:
                self.v_head = nn.Linear(512, 1)
                self.a_head = nn.Linear(512, ACTION_SPACE)
            case 2 | 3:
                hidden_size = 512 if version == 2 else 256
                self.v_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, 1),
                )
                self.a_head = nn.Sequential(
                    nn.Linear(1024, hidden_size),
                    nn.Mish(inplace=True),
                    nn.Linear(hidden_size, ACTION_SPACE),
                )
            case 4:
                self.net = nn.Linear(1024, 1 + ACTION_SPACE)
                nn.init.constant_(self.net.bias, 0)

    def forward(self, phi, mask):
        if self.version == 4:
            v, a = self.net(phi).split((1, ACTION_SPACE), dim=-1)
        else:
            v = self.v_head(phi)
            a = self.a_head(phi)
        # dueling aggregation restricted to legal actions:
        # Q = V + A - mean(A over legal actions); illegal actions are masked to -inf
        a_sum = a.masked_fill(~mask, 0.).sum(-1, keepdim=True)
        mask_sum = mask.sum(-1, keepdim=True)
        a_mean = a_sum / mask_sum
        q = (v + a - a_mean).masked_fill(~mask, -torch.inf)
        return q

class GRP(nn.Module):
    def __init__(self, hidden_size=64, num_layers=2):
        super().__init__()
        self.rnn = nn.GRU(input_size=GRP_SIZE, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * num_layers, hidden_size * num_layers),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size * num_layers, 24),
        )
        for mod in self.modules():
            mod.to(torch.float64)

        # perms are the permutations of all possible rank-by-player result
        perms = torch.tensor(list(permutations(range(4))))
        perms_t = perms.transpose(0, 1)
        self.register_buffer('perms', perms)      # (24, 4)
        self.register_buffer('perms_t', perms_t)  # (4, 24)

    # input: [grand_kyoku, honba, kyotaku, s[0], s[1], s[2], s[3]]
    # grand_kyoku: E1 = 0, S4 = 7, W4 = 11
    # s is 2.5 at E1
    # s[0] is score of player id 0
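    # assumed example: at the start of East 1, with all four players on 25000
    # points and no honba or kyotaku on the table, a step would be
    # [0, 0, 0, 2.5, 2.5, 2.5, 2.5]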
    def forward(self, inputs: List[Tensor]):
        lengths = torch.tensor([t.shape[0] for t in inputs], dtype=torch.int64)
        inputs = pad_sequence(inputs, batch_first=True)
        packed_inputs = pack_padded_sequence(inputs, lengths, batch_first=True, enforce_sorted=False)
        return self.forward_packed(packed_inputs)

    def forward_packed(self, packed_inputs):
        _, state = self.rnn(packed_inputs)
        state = state.transpose(0, 1).flatten(1)
        logits = self.fc(state)
        return logits

    # (N, 24) -> (N, player, rank_prob)
    def calc_matrix(self, logits: Tensor):
        batch_size = logits.shape[0]
        probs = logits.softmax(-1)
        matrix = torch.zeros(batch_size, 4, 4, dtype=probs.dtype)
        for player in range(4):
            for rank in range(4):
                cond = self.perms_t[player] == rank
                matrix[:, player, rank] = probs[:, cond].sum(-1)
        return matrix

    # (N, 4) -> (N)
    def get_label(self, rank_by_player: Tensor):
        batch_size = rank_by_player.shape[0]
        perms = self.perms.expand(batch_size, -1, -1).transpose(0, 1)
        mappings = (perms == rank_by_player).all(-1).nonzero()
        labels = torch.zeros(batch_size, dtype=torch.int64, device=mappings.device)
        labels[mappings[:, 1]] = mappings[:, 0]
        return labels
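
A minimal forward-pass sketch of how these pieces fit together (not taken from the commit): the channel and block counts below are placeholders rather than the project's actual sizes, and the file is assumed to be importable as "model" from within the mortal/ directory.

import torch
from libriichi.consts import obs_shape, ACTION_SPACE
from model import Brain, DQN

version = 4
brain = Brain(conv_channels=192, num_blocks=40, version=version).eval()
dqn = DQN(version=version).eval()

batch_size = 8
in_channels = obs_shape(version)[0]
obs = torch.rand(batch_size, in_channels, 34)                  # (N, C, 34) observation planes
mask = torch.ones(batch_size, ACTION_SPACE, dtype=torch.bool)  # legal-action mask (all legal here)

with torch.no_grad():
    phi = brain(obs)    # (N, 1024) features for versions 2-4
    q = dqn(phi, mask)  # (N, ACTION_SPACE); illegal actions would be -inf
action = q.argmax(-1)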