new/card_model/train_card_model.py
"""Training script for Card Model.
Usage:
python train_card_model.py
Trains a CardModel to predict equity and equity histogram from
hole cards + board cards using Monte Carlo-generated training data.
Loss = EMD_loss + lambda * MSE_loss.
"""
import os
import time
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from card_model.config import (
    BATCH_SIZE,
    DATA_DIR,
    LAMBDA_MSE,
    LEARNING_RATE,
    NUM_EPOCHS,
    NUM_ROLLOUTS,
    NUM_TRAIN_SAMPLES,
    NUM_VAL_SAMPLES,
    NUM_WORKERS,
    WEIGHT_DECAY,
)
from card_model.dataset import PokerCardDataset
from card_model.model import CardModel


def emd_loss_1d(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Compute 1D Wasserstein (Earth Mover's Distance) loss.
For 1D distributions with ordered bins, EMD equals the L1 distance
between CDFs:
EMD = mean(|cumsum(pred) - cumsum(target)|)
Args:
pred: [batch, num_bins] predicted histogram (after softmax).
target: [batch, num_bins] target histogram (sums to 1).
Returns:
Scalar loss value.
"""
cdf_pred = torch.cumsum(pred, dim=-1)
cdf_target = torch.cumsum(target, dim=-1)
return torch.mean(torch.abs(cdf_pred - cdf_target))
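
# Worked example (illustrative comment only, not executed): for pred = [1, 0, 0]
# and target = [0, 0, 1], the CDFs are [1, 1, 1] and [0, 0, 1], so the loss is
# mean(|[1, 1, 0]|) = 2/3; moving the same mass only one bin, to [0, 1, 0],
# gives 1/3. A per-bin MSE would score both targets identically, so it does not
# reflect how far probability mass has to move across the ordered equity bins.
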
def train_one_epoch(model, dataloader, optimizer, device):
    """Train for one epoch.

    Returns:
        avg_total_loss, avg_emd_loss, avg_mse_loss
    """
    model.train()
    total_loss_sum = 0.0
    emd_loss_sum = 0.0
    mse_loss_sum = 0.0
    num_batches = 0

    for x_hole, x_board, y_equity, y_histogram in dataloader:
        x_hole = x_hole.to(device)
        x_board = x_board.to(device)
        y_equity = y_equity.to(device)
        y_histogram = y_histogram.to(device)

        pred_equity, pred_histogram = model(x_hole, x_board)

        # MSE loss for equity
        mse = F.mse_loss(pred_equity, y_equity)
        # EMD loss for histogram
        emd = emd_loss_1d(pred_histogram, y_histogram)
        # Total loss
        loss = emd + LAMBDA_MSE * mse

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss_sum += loss.item()
        emd_loss_sum += emd.item()
        mse_loss_sum += mse.item()
        num_batches += 1

    return (
        total_loss_sum / max(num_batches, 1),
        emd_loss_sum / max(num_batches, 1),
        mse_loss_sum / max(num_batches, 1),
    )


@torch.no_grad()
def validate(model, dataloader, device):
    """Evaluate on validation set.

    Returns:
        avg_total_loss, avg_emd_loss, avg_mse_loss
    """
    model.eval()
    total_loss_sum = 0.0
    emd_loss_sum = 0.0
    mse_loss_sum = 0.0
    num_batches = 0

    for x_hole, x_board, y_equity, y_histogram in dataloader:
        x_hole = x_hole.to(device)
        x_board = x_board.to(device)
        y_equity = y_equity.to(device)
        y_histogram = y_histogram.to(device)

        pred_equity, pred_histogram = model(x_hole, x_board)
        mse = F.mse_loss(pred_equity, y_equity)
        emd = emd_loss_1d(pred_histogram, y_histogram)
        loss = emd + LAMBDA_MSE * mse

        total_loss_sum += loss.item()
        emd_loss_sum += emd.item()
        mse_loss_sum += mse.item()
        num_batches += 1

    return (
        total_loss_sum / max(num_batches, 1),
        emd_loss_sum / max(num_batches, 1),
        mse_loss_sum / max(num_batches, 1),
    )

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # --- Data ---
    os.makedirs(DATA_DIR, exist_ok=True)
    train_save = os.path.join(DATA_DIR, "train_data.npz")
    val_save = os.path.join(DATA_DIR, "val_data.npz")

    # Use fewer rollouts for quick iteration; increase for production
    num_rollouts = NUM_ROLLOUTS

    print("\n=== Generating training data ===")
    train_dataset = PokerCardDataset(
        num_samples=NUM_TRAIN_SAMPLES,
        num_rollouts=num_rollouts,
        num_workers=NUM_WORKERS,
        save_path=train_save,
    )

    print("\n=== Generating validation data ===")
    val_dataset = PokerCardDataset(
        num_samples=NUM_VAL_SAMPLES,
        num_rollouts=num_rollouts,
        num_workers=NUM_WORKERS,
        save_path=val_save,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=(device.type == "cuda"),
    )

    # --- Model ---
    model = CardModel().to(device)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"\nModel parameters: {num_params:,}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )
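    # Cosine annealing: the learning rate decays from LEARNING_RATE down to
    # eta_min over NUM_EPOCHS epochs (scheduler.step() is called once per
    # epoch in the loop below).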
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=NUM_EPOCHS, eta_min=1e-5
    )
    # --- Training loop ---
    print(f"\n{'Epoch':>5} | {'Train Loss':>10} | {'EMD':>8} | {'MSE':>8} | "
          f"{'Val Loss':>10} | {'Val EMD':>8} | {'Val MSE':>8} | {'Time':>6}")
    print("-" * 85)

    best_val_loss = float("inf")
    for epoch in range(1, NUM_EPOCHS + 1):
        t0 = time.time()
        train_total, train_emd, train_mse = train_one_epoch(
            model, train_loader, optimizer, device
        )
        val_total, val_emd, val_mse = validate(model, val_loader, device)
        scheduler.step()
        elapsed = time.time() - t0

        print(
            f"{epoch:5d} | {train_total:10.4f} | {train_emd:8.4f} | {train_mse:8.4f} | "
            f"{val_total:10.4f} | {val_emd:8.4f} | {val_mse:8.4f} | {elapsed:5.1f}s"
        )

        # Save best model
        if val_total < best_val_loss:
            best_val_loss = val_total
            save_path = os.path.join(DATA_DIR, "best_card_model.pt")
            torch.save(model.state_dict(), save_path)

    # Save final model
    final_path = os.path.join(DATA_DIR, "final_card_model.pt")
    torch.save(model.state_dict(), final_path)

    print(f"\nTraining complete. Best val loss: {best_val_loss:.4f}")
    print(f"Best model saved to: {os.path.join(DATA_DIR, 'best_card_model.pt')}")
    print(f"Final model saved to: {final_path}")

if __name__ == "__main__":
    main()