"""CardModel network architecture for poker equity prediction. Dual-head model: - equity_head: Sigmoid output -> scalar equity prediction - histogram_head: Softmax output -> 50-bin equity distribution """ import torch import torch.nn as nn from .config import EMBEDDING_DIM, MLP_HIDDEN, NUM_BINS, VOCAB_SIZE, PAD_TOKEN class CardModel(nn.Module): """Neural network predicting equity and equity histogram from cards. Architecture: 1. Embedding layer for 52 cards + 1 PAD token 2. Hole cards: embed -> sum -> hole_emb 3. Board cards: embed -> sum -> board_emb 4. Concatenate [hole_emb, board_emb] -> MLP backbone 5. Dual output heads: - equity_head: MLP -> Sigmoid -> scalar - histogram_head: MLP -> Softmax -> 50-dim distribution """ def __init__( self, vocab_size: int = VOCAB_SIZE, embedding_dim: int = EMBEDDING_DIM, mlp_hidden: list = None, num_bins: int = NUM_BINS, ): """Initialize CardModel. Args: vocab_size: Number of token types (52 cards + 1 pad). embedding_dim: Dimension of card embeddings. mlp_hidden: List of hidden layer sizes for the backbone MLP. num_bins: Number of bins for the equity histogram. """ super().__init__() if mlp_hidden is None: mlp_hidden = list(MLP_HIDDEN) # Embedding: 53 tokens (0-51 cards + 52 pad) -> embedding_dim self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=PAD_TOKEN) # Input to backbone: concat of hole_emb and board_emb backbone_input_dim = embedding_dim * 2 # Build backbone MLP layers = [] in_dim = backbone_input_dim for hidden_dim in mlp_hidden: layers.append(nn.Linear(in_dim, hidden_dim)) layers.append(nn.ReLU()) layers.append(nn.LayerNorm(hidden_dim)) in_dim = hidden_dim self.backbone = nn.Sequential(*layers) # Equity head: scalar output with Sigmoid self.equity_head = nn.Sequential( nn.Linear(mlp_hidden[-1], 32), nn.ReLU(), nn.Linear(32, 1), nn.Sigmoid(), ) # Histogram head: num_bins output with Softmax self.histogram_head = nn.Sequential( nn.Linear(mlp_hidden[-1], 64), nn.ReLU(), nn.Linear(64, num_bins), nn.Softmax(dim=-1), ) def forward(self, x_hole: torch.Tensor, x_board: torch.Tensor): """Forward pass. Args: x_hole: [batch, 2] int64 tensor of hole card IDs. x_board: [batch, 5] int64 tensor of board card IDs (padded with 52). Returns: pred_equity: [batch, 1] float32, predicted equity in [0, 1]. pred_histogram: [batch, 50] float32, predicted equity distribution (sums to 1). """ # Embed hole cards and sum -> [batch, embedding_dim] hole_emb = self.embedding(x_hole).sum(dim=1) # Embed board cards and sum -> [batch, embedding_dim] # PAD_TOKEN (52) has zero embedding due to padding_idx board_emb = self.embedding(x_board).sum(dim=1) # Concatenate -> [batch, embedding_dim * 2] combined = torch.cat([hole_emb, board_emb], dim=-1) # Backbone features = self.backbone(combined) # Dual heads pred_equity = self.equity_head(features) pred_histogram = self.histogram_head(features) return pred_equity, pred_histogram