232 lines
6.3 KiB
Python
232 lines
6.3 KiB
Python
"""Training script for Card Model.

Usage:
    python train_card_model.py

Trains a CardModel to predict equity and equity histogram from
hole cards + board cards using Monte Carlo-generated training data.
Loss = EMD_loss + lambda * MSE_loss.
"""
|
|
|
|
import os
|
|
import time
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import DataLoader
|
|
|
|
from card_model.config import (
|
|
BATCH_SIZE,
|
|
DATA_DIR,
|
|
LAMBDA_MSE,
|
|
LEARNING_RATE,
|
|
NUM_EPOCHS,
|
|
NUM_ROLLOUTS,
|
|
NUM_TRAIN_SAMPLES,
|
|
NUM_VAL_SAMPLES,
|
|
NUM_WORKERS,
|
|
WEIGHT_DECAY,
|
|
)
|
|
from card_model.dataset import PokerCardDataset
|
|
from card_model.model import CardModel
|
|
|
|
|
|
def emd_loss_1d(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Earth Mover's Distance between batched 1D histograms.

    For 1D distributions over ordered bins, the Wasserstein-1 distance
    reduces to the L1 distance between cumulative distribution functions:

        EMD = mean(|cumsum(pred) - cumsum(target)|)

    Args:
        pred: [batch, num_bins] predicted histogram (after softmax).
        target: [batch, num_bins] target histogram (sums to 1).

    Returns:
        Scalar loss tensor.
    """
    cdf_gap = pred.cumsum(dim=-1) - target.cumsum(dim=-1)
    return cdf_gap.abs().mean()
|
def train_one_epoch(model, dataloader, optimizer, device):
    """Run one full optimization pass over *dataloader*.

    Args:
        model: CardModel producing (equity, histogram) predictions.
        dataloader: yields (x_hole, x_board, y_equity, y_histogram) batches.
        optimizer: optimizer stepping the model's parameters.
        device: torch device batches are moved to before the forward pass.

    Returns:
        Tuple (avg_total_loss, avg_emd_loss, avg_mse_loss), each averaged
        over the number of batches seen.
    """
    model.train()
    running_total = 0.0
    running_emd = 0.0
    running_mse = 0.0
    batches_seen = 0

    for hole, board, equity_target, hist_target in dataloader:
        hole = hole.to(device)
        board = board.to(device)
        equity_target = equity_target.to(device)
        hist_target = hist_target.to(device)

        equity_pred, hist_pred = model(hole, board)

        # Scalar equity head trains with MSE, histogram head with EMD;
        # the two are combined via the LAMBDA_MSE weight from config.
        mse = F.mse_loss(equity_pred, equity_target)
        emd = emd_loss_1d(hist_pred, hist_target)
        loss = emd + LAMBDA_MSE * mse

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_total += loss.item()
        running_emd += emd.item()
        running_mse += mse.item()
        batches_seen += 1

    # max(..., 1) guards against an empty dataloader.
    denom = max(batches_seen, 1)
    return (
        running_total / denom,
        running_emd / denom,
        running_mse / denom,
    )
|
@torch.no_grad()
def validate(model, dataloader, device):
    """Score the model on a held-out set without computing gradients.

    Args:
        model: CardModel producing (equity, histogram) predictions.
        dataloader: yields (x_hole, x_board, y_equity, y_histogram) batches.
        device: torch device batches are moved to before the forward pass.

    Returns:
        Tuple (avg_total_loss, avg_emd_loss, avg_mse_loss), each averaged
        over the number of batches seen.
    """
    model.eval()
    running_total = 0.0
    running_emd = 0.0
    running_mse = 0.0
    batches_seen = 0

    for hole, board, equity_target, hist_target in dataloader:
        hole = hole.to(device)
        board = board.to(device)
        equity_target = equity_target.to(device)
        hist_target = hist_target.to(device)

        equity_pred, hist_pred = model(hole, board)

        # Same loss mix as training so val numbers are directly comparable.
        mse = F.mse_loss(equity_pred, equity_target)
        emd = emd_loss_1d(hist_pred, hist_target)
        loss = emd + LAMBDA_MSE * mse

        running_total += loss.item()
        running_emd += emd.item()
        running_mse += mse.item()
        batches_seen += 1

    # max(..., 1) guards against an empty dataloader.
    denom = max(batches_seen, 1)
    return (
        running_total / denom,
        running_emd / denom,
        running_mse / denom,
    )
|
def main():
    """Generate data, train the CardModel, and save best/final checkpoints.

    Pipeline:
      1. Build train/val datasets via Monte Carlo rollouts (cached as .npz).
      2. Train with AdamW + cosine LR annealing for NUM_EPOCHS.
      3. Save the best-validation-loss checkpoint and the final weights
         under DATA_DIR.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data ---
    os.makedirs(DATA_DIR, exist_ok=True)
    train_save = os.path.join(DATA_DIR, "train_data.npz")
    val_save = os.path.join(DATA_DIR, "val_data.npz")

    print("\n=== Generating training data ===")
    train_dataset = PokerCardDataset(
        num_samples=NUM_TRAIN_SAMPLES,
        num_rollouts=NUM_ROLLOUTS,
        num_workers=NUM_WORKERS,
        save_path=train_save,
    )

    print("\n=== Generating validation data ===")
    val_dataset = PokerCardDataset(
        num_samples=NUM_VAL_SAMPLES,
        num_rollouts=NUM_ROLLOUTS,
        num_workers=NUM_WORKERS,
        save_path=val_save,
    )

    # pin_memory only speeds up host-to-device copies, so enable it for CUDA.
    pin = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=pin,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=pin,
    )

    # --- Model ---
    model = CardModel().to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {num_params:,}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=NUM_EPOCHS, eta_min=1e-5
    )

    # --- Training loop ---
    print(f"\n{'Epoch':>5} | {'Train Loss':>10} | {'EMD':>8} | {'MSE':>8} | "
          f"{'Val Loss':>10} | {'Val EMD':>8} | {'Val MSE':>8} | {'Time':>6}")
    print("-" * 85)

    best_val_loss = float("inf")
    # Single source of truth for the best-checkpoint path (used by both the
    # save inside the loop and the summary print below).
    best_path = os.path.join(DATA_DIR, "best_card_model.pt")

    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.time()

        train_total, train_emd, train_mse = train_one_epoch(
            model, train_loader, optimizer, device
        )
        val_total, val_emd, val_mse = validate(model, val_loader, device)

        scheduler.step()

        elapsed = time.time() - t0

        print(
            f"{epoch:5d} | {train_total:10.4f} | {train_emd:8.4f} | {train_mse:8.4f} | "
            f"{val_total:10.4f} | {val_emd:8.4f} | {val_mse:8.4f} | {elapsed:5.1f}s"
        )

        # Checkpoint whenever validation improves.
        if val_total < best_val_loss:
            best_val_loss = val_total
            torch.save(model.state_dict(), best_path)

    # Always persist the final weights alongside the best checkpoint.
    final_path = os.path.join(DATA_DIR, "final_card_model.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    print(f"Best model saved to: {best_path}")
    print(f"Final model saved to: {final_path}")
|
# Script entry point: run the full data-generation + training pipeline.
if __name__ == "__main__":
    main()
|