73 lines
2.5 KiB
Python
73 lines
2.5 KiB
Python
"""Configuration constants for the Card Model."""

# --- Card encoding ---
# Cards are integer IDs; one extra ID is reserved as a padding token so that
# boards with fewer than the maximum number of cards can still be encoded.
NUM_CARDS = 52  # Standard 52-card deck (IDs 0-51)
PAD_TOKEN = NUM_CARDS  # Padding token for missing board cards (= 52, first ID after the deck)
VOCAB_SIZE = NUM_CARDS + 1  # 52 cards + 1 pad token (= 53)

# --- Card ID mapping (OpenSpiel universal_poker) ---
# ID = rank * 4 + suit, where:
#   rank: 0=2, 1=3, ..., 8=T, 9=J, 10=Q, 11=K, 12=A
#   suit: 0=c, 1=d, 2=h, 3=s
# Both lists are indexed by the rank / suit number from the formula above,
# e.g. ID 51 -> rank 12, suit 3 -> "As".
RANK_NAMES = ['2', '3', '4', '5', '6', '7', '8', '9', 'T', 'J', 'Q', 'K', 'A']
SUIT_NAMES = ['c', 'd', 'h', 's']

# --- Game configuration ---
# Aligned with the Botzone competition rules:
#   - 2-player HUNL, starting stack 20000, SB=50, BB=100
#   - OpenSpiel blind parameter order: [player0_blind, player1_blind]
#   - blind=50 100 -> player0=SB(50), player1=BB(100)
#   - numBoardCards=0 3 1 1 -> cards dealt at Preflop/Flop/Turn/River
# NOTE(review): no `firstPlayer` parameter is given, so whatever acting order
# OpenSpiel defaults to is used in every round -- confirm that this matches
# the intended heads-up order.
# The adjacent string literals below are concatenated into one game string.
HUNL_GAME_STRING = (
    "universal_poker("
    "betting=nolimit,"
    "numPlayers=2,"
    "numRanks=13,"
    "numSuits=4,"
    "numHoleCards=2,"
    "numRounds=4,"
    "numBoardCards=0 3 1 1,"
    "stack=20000 20000,"
    "blind=50 100,"
    "bettingAbstraction=fullgame)"
)

# --- Data generation ---
NUM_ROLLOUTS = 1000  # Monte Carlo rollouts per sample
NUM_BINS = 50  # Equity histogram bins
HOLE_SIZE = 2  # Number of hole cards
BOARD_SIZE = 5  # Max number of board cards (3 flop + 1 turn + 1 river)

# --- Model ---
EMBEDDING_DIM = 64  # Embedding dimension (presumably per card token -- confirm against the model code)
MLP_HIDDEN = [512, 512, 256]  # Hidden sizes of the MLP, in order (assumed one entry per layer)

# --- Paths ---
import os

# Absolute directory containing this module; the data and checkpoint
# directories are resolved relative to it rather than to the current
# working directory.
_PACKAGE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(_PACKAGE_DIR, "data")
CHECKPOINT_DIR = os.path.join(_PACKAGE_DIR, "checkpoints")

# --- Training ---
WEIGHT_DECAY = 1e-4
LAMBDA_MSE = 0.1  # Weight for MSE loss relative to EMD loss

# Large-scale run settings. They replace the earlier small-scale values,
# kept here for reference only:
#   BATCH_SIZE=256, LEARNING_RATE=1e-3, NUM_EPOCHS=50,
#   NUM_TRAIN_SAMPLES=20000, NUM_VAL_SAMPLES=2000, NUM_WORKERS=4
# NUM_ROLLOUTS (1000) and NUM_BINS (50) are deliberately left unchanged to
# preserve the precision of the equity histograms.

# Much larger dataset so the model can memorize even extreme hands
# (e.g. a royal flush).
NUM_TRAIN_SAMPLES = 10_000_000  # 10 million training samples (500x the earlier 20000); slow to generate
NUM_VAL_SAMPLES = 100_000  # 100 thousand validation samples (50x the earlier 2000)
NUM_WORKERS = 22  # DataLoader workers; set to your CPU core count (e.g. 32 or 64)

# Larger batch size to make full use of GPU memory.
BATCH_SIZE = 16384
LEARNING_RATE = 5e-4  # Lowered slightly for smoother convergence
NUM_EPOCHS = 64  # Train for more epochs