244 lines
7.8 KiB
Python
244 lines
7.8 KiB
Python
"""Data generator for Card Model using OpenSpiel Monte Carlo rollouts.

Generates (hole_cards, board_cards, equity, histogram) samples by:

1. Randomly sampling a game state (preflop / flop / turn).
2. Running N Monte Carlo rollouts to the river.
3. Computing equity (mean win rate) and a histogram (binned distribution
   of the rollout outcomes).
"""
|
|
|
|
import random
|
|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import pyspiel
|
|
|
|
from .config import (
|
|
BOARD_SIZE,
|
|
HOLE_SIZE,
|
|
HUNL_GAME_STRING,
|
|
NUM_BINS,
|
|
NUM_ROLLOUTS,
|
|
PAD_TOKEN,
|
|
)
|
|
|
|
|
|
def _card_id_to_rank_suit(card_id: int) -> Tuple[int, int]:
|
|
"""Convert OpenSpiel card ID to (rank, suit).
|
|
|
|
Mapping: card_id = rank * 4 + suit
|
|
rank: 0=2, 1=3, ..., 8=T, 9=J, 10=Q, 11=K, 12=A
|
|
suit: 0=c, 1=d, 2=h, 3=s
|
|
"""
|
|
return card_id // 4, card_id % 4
|
|
|
|
|
|
def _rank_suit_to_card_id(rank: int, suit: int) -> int:
|
|
"""Convert (rank, suit) back to OpenSpiel card ID."""
|
|
return rank * 4 + suit
|
|
|
|
|
|
_RANK_CHARS = "23456789TJQKA"
_SUIT_CHARS = "cdhs"


def _parse_card_string(cards: str) -> list:
    """Parse a concatenated two-characters-per-card string into card IDs.

    Each card is a rank character from ``23456789TJQKA`` followed by a suit
    character from ``cdhs``, e.g. ``"KcTc3d"`` -> three card IDs.

    Raises:
        ValueError: if the string has odd length or contains an unknown
            rank/suit character.
    """
    if len(cards) % 2:
        raise ValueError(f"Malformed card string (odd length): {cards!r}")
    card_ids = []
    for pos in range(0, len(cards), 2):
        rank_char, suit_char = cards[pos], cards[pos + 1]
        rank = _RANK_CHARS.find(rank_char)
        suit = _SUIT_CHARS.find(suit_char)
        if rank < 0 or suit < 0:
            raise ValueError(
                f"Unknown card {rank_char + suit_char!r} in card string {cards!r}"
            )
        card_ids.append(_rank_suit_to_card_id(rank, suit))
    return card_ids


def extract_cards_from_state(state) -> Tuple[list, list]:
    """Extract player 0's hole cards and board cards from an OpenSpiel state.

    Reads the card strings from the state's ``to_dict()`` result (keys
    ``"player_hands"`` and ``"board_cards"``) and maps them back to integer
    IDs (0-51).

    Args:
        state: An OpenSpiel UniversalPokerState object.

    Returns:
        hole_cards: List of integers (0-51) for player 0's hole cards.
        board_cards: List of 0-5 integers (0-51) for the current board cards
            (empty preflop).

    Raises:
        ValueError: if either card string is malformed.
    """
    d = state.to_dict()
    hole_cards = _parse_card_string(d["player_hands"][0])
    # The board string is empty (or falsy) before the flop.
    board_cards = _parse_card_string(d["board_cards"] or "")
    return hole_cards, board_cards
|
|
|
|
|
|
def _deal_random_chance_outcomes(state) -> None:
    """Resolve consecutive chance nodes by picking an outcome uniformly at random.

    The reported outcome probabilities are ignored. That is fine when they are
    uniform, as they are for dealing from the remaining deck.
    """
    while state.is_chance_node():
        action, _prob = random.choice(state.chance_outcomes())
        state.apply_action(action)


def _sample_random_state(game) -> Tuple[object, list, list, set]:
    """Sample a random game state at preflop, flop, or turn.

    Deals hole cards, then picks uniformly how many betting rounds to play
    (0=preflop, 1=flop, 2=turn). It advances through them with check/call
    only, dealing community cards after each round. If a terminal state is
    reached, the sample is discarded and retried. This uses a loop, not
    recursion, so repeated retries cannot overflow the stack.

    Returns:
        state: The OpenSpiel state at the chosen street.
        hole_cards: Player 0's hole card IDs.
        board_cards: Board card IDs.
        used_cards: Set of player 0's hole cards and the board cards.
    """
    while True:
        state = game.new_initial_state()
        _deal_random_chance_outcomes(state)  # hole cards

        target_street = random.randint(0, 2)
        for _ in range(target_street):
            # Every player checks/calls until the betting round closes.
            while not state.is_chance_node() and not state.is_terminal():
                state.apply_action(1)  # 1 = check/call
            if state.is_terminal():
                break
            _deal_random_chance_outcomes(state)  # community cards

        if not state.is_terminal():
            break

    hole_cards, board_cards = extract_cards_from_state(state)
    used_cards = set(hole_cards) | set(board_cards)
    return state, hole_cards, board_cards, used_cards
|
|
|
|
|
|
def _monte_carlo_rollout(game, hole_cards, board_cards, used_cards) -> float:
    """Play one random runout to showdown and score it for player 0.

    The cards not in ``used_cards`` are shuffled. The first ones complete the
    board to ``BOARD_SIZE`` cards, and the next ``HOLE_SIZE`` become the
    opponent's hand. A fresh game state is then driven to a terminal node:
    those cards are fed to its chance nodes in order, and every player
    check/calls at each decision.

    Args:
        game: OpenSpiel game object.
        hole_cards: Player 0's hole card IDs.
        board_cards: Current board card IDs.
        used_cards: Set of card IDs already in play (excluded from the draw).

    Returns:
        1.0 if player 0 wins, 0.5 for a tie, 0.0 for a loss.

    Raises:
        RuntimeError: if the game asks for a card that is not among its legal
            chance outcomes, or for more cards than were prepared. Either means
            the assumed deal order below does not match the game. Previously
            this silently dealt a random card and produced a mislabeled sample.
    """
    remaining_deck = [c for c in range(52) if c not in used_cards]
    random.shuffle(remaining_deck)

    # Board is completed first, then the opponent's hand is drawn.
    n_missing = max(0, BOARD_SIZE - len(board_cards))
    full_board = list(board_cards) + remaining_deck[:n_missing]
    opp_hole = remaining_deck[n_missing:n_missing + HOLE_SIZE]

    # Assumed OpenSpiel deal order: P0 hole cards, P1 hole cards,
    # then flop(3), turn(1), river(1).
    deal_order = list(hole_cards) + opp_hole + full_board
    next_card = 0

    state = game.new_initial_state()
    while not state.is_terminal():
        if not state.is_chance_node():
            state.apply_action(1)  # 1 = check/call
            continue

        if next_card >= len(deal_order):
            raise RuntimeError(
                f"Game requested more chance outcomes than the "
                f"{len(deal_order)} prepared cards"
            )
        target = deal_order[next_card]
        legal = {action for action, _prob in state.chance_outcomes()}
        if target not in legal:
            raise RuntimeError(
                f"Card {target} is not a legal chance outcome; "
                f"the assumed deal order does not match the game"
            )
        state.apply_action(target)
        next_card += 1

    payoff = state.returns()[0]
    if payoff > 0:
        return 1.0
    if payoff == 0:
        return 0.5
    return 0.0
|
|
|
|
|
|
def _rollouts_to_histogram(results: np.ndarray) -> np.ndarray:
    """Bin rollout outcomes in [0, 1] into a normalized NUM_BINS histogram.

    Falls back to a uniform distribution when nothing falls into any bin
    (e.g. an empty ``results`` array).
    """
    bin_edges = np.linspace(0, 1, NUM_BINS + 1)
    counts, _ = np.histogram(results, bins=bin_edges)
    histogram = counts.astype(np.float32)
    total = histogram.sum()
    if total > 0:
        histogram /= total
        return histogram
    return np.ones(NUM_BINS, dtype=np.float32) / NUM_BINS


def generate_sample(game=None) -> Tuple[np.ndarray, np.ndarray, np.float32, np.ndarray]:
    """Generate a single training sample via Monte Carlo rollouts.

    Args:
        game: Optional pre-loaded OpenSpiel game. If None, one is loaded from
            HUNL_GAME_STRING.

    Returns:
        x_hole: np.ndarray, int64, player 0's hole card IDs.
        x_board: np.ndarray of shape [BOARD_SIZE], int64, board card IDs
            padded with PAD_TOKEN.
        y_equity: np.float32 scalar, mean rollout score (win=1, tie=0.5, loss=0).
        y_histogram: np.ndarray of shape [NUM_BINS], float32, normalized
            histogram of the rollout scores.
    """
    if game is None:
        game = pyspiel.load_game(HUNL_GAME_STRING)

    # The sampled state object itself is not needed, only the cards.
    _state, hole_cards, board_cards, used_cards = _sample_random_state(game)

    results = np.zeros(NUM_ROLLOUTS, dtype=np.float32)
    for i in range(NUM_ROLLOUTS):
        results[i] = _monte_carlo_rollout(game, hole_cards, board_cards, used_cards)

    y_equity = float(np.mean(results))

    # NOTE(review): each rollout score is exactly 0.0, 0.5 or 1.0, so this
    # histogram populates at most three bins. If a genuine equity
    # distribution is intended, each rollout would need to yield a
    # per-runout equity rather than a win/tie/loss outcome. Confirm intent.
    y_histogram = _rollouts_to_histogram(results)

    x_board = np.array(
        board_cards + [PAD_TOKEN] * (BOARD_SIZE - len(board_cards)),
        dtype=np.int64,
    )
    x_hole = np.array(hole_cards, dtype=np.int64)

    return x_hole, x_board, np.float32(y_equity), y_histogram
|