"""Training script for the Card Model.

Usage:
    python train_card_model.py

Trains a CardModel to predict both the equity and the equity histogram of a
hand from its hole cards and board cards, using Monte Carlo-generated
training data.

Loss = EMD_loss + lambda * MSE_loss
"""

import os
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from card_model.config import (
    BATCH_SIZE,
    DATA_DIR,
    LAMBDA_MSE,
    LEARNING_RATE,
    NUM_EPOCHS,
    NUM_ROLLOUTS,
    NUM_TRAIN_SAMPLES,
    NUM_VAL_SAMPLES,
    NUM_WORKERS,
    WEIGHT_DECAY,
)
from card_model.dataset import PokerCardDataset
from card_model.model import CardModel


def emd_loss_1d(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute the 1D Wasserstein (Earth Mover's Distance) loss.

    For 1D distributions over ordered bins, EMD equals the L1 distance
    between the CDFs:

        EMD = mean(|cumsum(pred) - cumsum(target)|)

    Taking the mean over bins (rather than the sum) scales the distance by
    1 / num_bins. The result is also averaged over the batch.

    Args:
        pred: [batch, num_bins] predicted histogram (after softmax).
        target: [batch, num_bins] target histogram (sums to 1).

    Returns:
        Scalar loss value.
    """
    cdf_pred = torch.cumsum(pred, dim=-1)
    cdf_target = torch.cumsum(target, dim=-1)
    return torch.mean(torch.abs(cdf_pred - cdf_target))


def _batch_losses(model, batch, device):
    """Move one batch to ``device``, run the model, and compute the losses.

    Shared by training and validation so both compute the loss identically.

    Args:
        model: network called as ``model(x_hole, x_board)`` that returns
            ``(pred_equity, pred_histogram)``.
        batch: a 4-tuple ``(x_hole, x_board, y_equity, y_histogram)`` as
            yielded by the DataLoader.
        device: device to run on.

    Returns:
        ``(total_loss, emd_loss, mse_loss, batch_size)``. The three losses are
        scalar tensors still attached to the autograd graph (when grad is
        enabled). ``batch_size`` is the size of dim 0 of ``x_hole``.
    """
    # non_blocking lets host->GPU copies overlap compute when the loader
    # uses pinned memory; it is a no-op for CPU targets.
    x_hole, x_board, y_equity, y_histogram = (
        t.to(device, non_blocking=True) for t in batch
    )

    pred_equity, pred_histogram = model(x_hole, x_board)

    # NOTE(review): F.mse_loss broadcasts (with a warning) if pred_equity and
    # y_equity shapes differ, e.g. [B, 1] vs [B]. Confirm the shapes match.
    mse = F.mse_loss(pred_equity, y_equity)
    emd = emd_loss_1d(pred_histogram, y_histogram)
    loss = emd + LAMBDA_MSE * mse
    return loss, emd, mse, x_hole.size(0)


class _RunningLosses:
    """Sample-weighted running averages of the (total, EMD, MSE) losses.

    Each batch's mean loss is weighted by its batch size. A short final
    batch therefore counts proportionally, instead of as much as a full one.
    """

    def __init__(self):
        self._sums = [0.0, 0.0, 0.0]
        self._count = 0

    def update(self, total, emd, mse, batch_size):
        """Add one batch's mean losses (scalar tensors), weighted by batch_size."""
        for i, value in enumerate((total, emd, mse)):
            self._sums[i] += value.item() * batch_size
        self._count += batch_size

    def averages(self):
        """Return ``(avg_total, avg_emd, avg_mse)``; zeros if nothing was added."""
        denom = max(self._count, 1)
        avg_total, avg_emd, avg_mse = (s / denom for s in self._sums)
        return avg_total, avg_emd, avg_mse


def train_one_epoch(model, dataloader, optimizer, device):
    """Train ``model`` for one full pass over ``dataloader``.

    Args:
        model: the CardModel being trained.
        dataloader: yields ``(x_hole, x_board, y_equity, y_histogram)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device to run on.

    Returns:
        avg_total_loss, avg_emd_loss, avg_mse_loss, each averaged per sample.
    """
    model.train()
    running = _RunningLosses()

    for batch in dataloader:
        loss, emd, mse, batch_size = _batch_losses(model, batch, device)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running.update(loss, emd, mse, batch_size)

    return running.averages()


@torch.no_grad()
def validate(model, dataloader, device):
    """Evaluate ``model`` on the validation set without computing gradients.

    Args:
        model: the CardModel to evaluate.
        dataloader: yields ``(x_hole, x_board, y_equity, y_histogram)``.
        device: device to run on.

    Returns:
        avg_total_loss, avg_emd_loss, avg_mse_loss, each averaged per sample.
    """
    model.eval()
    running = _RunningLosses()

    for batch in dataloader:
        loss, emd, mse, batch_size = _batch_losses(model, batch, device)
        running.update(loss, emd, mse, batch_size)

    return running.averages()


def main():
    """Build datasets, train CardModel, and save best and final checkpoints."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    use_cuda = device.type == "cuda"

    # --- Data ---
    os.makedirs(DATA_DIR, exist_ok=True)
    train_save = os.path.join(DATA_DIR, "train_data.npz")
    val_save = os.path.join(DATA_DIR, "val_data.npz")
    best_path = os.path.join(DATA_DIR, "best_card_model.pt")
    final_path = os.path.join(DATA_DIR, "final_card_model.pt")

    # Rollout count comes straight from config. Lower NUM_ROLLOUTS for quick
    # iteration; raise it for production-quality labels.
    num_rollouts = NUM_ROLLOUTS

    print("\n=== Generating training data ===")
    train_dataset = PokerCardDataset(
        num_samples=NUM_TRAIN_SAMPLES,
        num_rollouts=num_rollouts,
        num_workers=NUM_WORKERS,
        save_path=train_save,
    )

    print("\n=== Generating validation data ===")
    val_dataset = PokerCardDataset(
        num_samples=NUM_VAL_SAMPLES,
        num_rollouts=num_rollouts,
        num_workers=NUM_WORKERS,
        save_path=val_save,
    )

    loader_kwargs = {
        "batch_size": BATCH_SIZE,
        "num_workers": NUM_WORKERS,
        # Pinned host memory only helps (and pairs with non_blocking copies)
        # when the target is a CUDA device.
        "pin_memory": use_cuda,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # --- Model ---
    model = CardModel().to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {num_params:,}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )
    # Stepped once per epoch, so T_max is measured in epochs.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=NUM_EPOCHS, eta_min=1e-5
    )

    # --- Training loop ---
    print(f"\n{'Epoch':>5} | {'Train Loss':>10} | {'EMD':>8} | {'MSE':>8} | "
          f"{'Val Loss':>10} | {'Val EMD':>8} | {'Val MSE':>8} | {'Time':>6}")
    print("-" * 85)

    best_val_loss = float("inf")
    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.perf_counter()

        train_total, train_emd, train_mse = train_one_epoch(
            model, train_loader, optimizer, device
        )
        val_total, val_emd, val_mse = validate(model, val_loader, device)
        scheduler.step()

        elapsed = time.perf_counter() - t0

        print(
            f"{epoch:5d} | {train_total:10.4f} | {train_emd:8.4f} | {train_mse:8.4f} | "
            f"{val_total:10.4f} | {val_emd:8.4f} | {val_mse:8.4f} | {elapsed:5.1f}s"
        )

        # Keep the checkpoint with the lowest validation loss seen so far.
        if val_total < best_val_loss:
            best_val_loss = val_total
            torch.save(model.state_dict(), best_path)

    # Save final model
    torch.save(model.state_dict(), final_path)
    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    print(f"Best model saved to: {best_path}")
    print(f"Final model saved to: {final_path}")


if __name__ == "__main__":
    main()