"""PyTorch Dataset for the Poker Card Model.

Samples are produced by ``generate_sample`` either in the current process or
with a pool of worker processes, and can optionally be cached to / loaded from
an ``.npz`` file on disk.
"""

import glob
import os
import random
import tempfile
from multiprocessing import Pool
from typing import Optional

import numpy as np
import torch
from torch.utils.data import Dataset

from .config import (
    BOARD_SIZE,
    HOLE_SIZE,
    NUM_BINS,
    NUM_ROLLOUTS,
    NUM_VAL_SAMPLES,
    NUM_TRAIN_SAMPLES,
    PAD_TOKEN,
)
from .data_generator import generate_sample


class PokerCardDataset(Dataset):
    """Dataset for Card Model equity prediction.

    Each sample contains:
        x_hole:      [2]  int64   - player 0's hole card IDs (0-51)
        x_board:     [5]  int64   - board card IDs (0-51), padded with PAD_TOKEN=52
        y_equity:    [1]  float32 - mean equity (win rate)
        y_histogram: [50] float32 - normalized equity distribution over bins
    """

    def __init__(
        self,
        num_samples: int,
        num_rollouts: int = NUM_ROLLOUTS,
        num_workers: int = 0,
        save_path: Optional[str] = None,
    ):
        """Initialize dataset.

        Args:
            num_samples: Number of samples to generate. Ignored (apart from a
                warning on mismatch) when the data is loaded from ``save_path``.
            num_rollouts: Monte Carlo rollouts per sample.
            num_workers: Number of parallel workers for data generation.
                0 = single process (mainly for debugging).
            save_path: If provided, cache generated data to this path (.npz).
                If the file exists, loads from disk instead of regenerating.

        Raises:
            ValueError: If data has to be generated and ``num_samples`` is not
                positive.
        """
        self.num_samples = num_samples
        self.num_rollouts = num_rollouts
        # NOTE(review): num_rollouts is stored and printed but never passed to
        # generate_sample(), and it is not part of the cache key. Confirm
        # whether data_generator should receive it.

        # Try loading from cache.
        if save_path and os.path.exists(save_path):
            (self.x_hole, self.x_board,
             self.y_equity, self.y_histogram) = _read_npz(save_path)
            loaded = len(self.x_hole)
            if loaded != num_samples:
                print(f"Warning: requested {num_samples} samples but {save_path} "
                      f"contains {loaded}; using the cached data.")
            # Reflect what was actually loaded, not what was requested.
            self.num_samples = loaded
            print(f"Loaded {loaded} samples from {save_path}")
            return

        if num_samples <= 0:
            raise ValueError(f"num_samples must be positive, got {num_samples}")

        # Generate data.
        print(f"Generating {num_samples} samples with {num_rollouts} rollouts each "
              f"({num_workers} workers)...")
        if num_workers > 0:
            self._generate_parallel(num_samples, num_workers)
        else:
            self._generate_sequential(num_samples)

        # Save to cache.
        if save_path:
            self._save(save_path)
            print(f"Saved {num_samples} samples to {save_path}")

    def _save(self, path: str) -> None:
        """Write the four sample arrays to ``path`` as an uncompressed .npz."""
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        tmp_path = path + ".tmp"
        # Write through an open file object: when given a filename, np.savez
        # appends ".npz" if it is missing, which would make the file written
        # differ from the path checked by os.path.exists() in __init__.
        # Writing to a temp file and renaming means an interrupted save never
        # leaves a truncated cache behind.
        with open(tmp_path, "wb") as f:
            np.savez(
                f,
                x_hole=self.x_hole,
                x_board=self.x_board,
                y_equity=self.y_equity,
                y_histogram=self.y_histogram,
            )
        os.replace(tmp_path, path)

    def _generate_sequential(self, num_samples: int):
        """Generate samples sequentially (single process)."""
        (self.x_hole, self.x_board,
         self.y_equity, self.y_histogram) = _collect_samples(
            _load_game(), num_samples, progress=True)

    def _generate_parallel(self, num_samples: int, num_workers: int):
        """Generate samples with a process pool.

        Each worker writes its chunk to a temporary .npz file instead of
        returning the arrays, so large results never pass through the Pool's
        result pipe; the parent then reads the chunk files and concatenates
        them in chunk order.
        """
        # More workers than samples would create empty chunks, and stacking an
        # empty list raises inside the worker.
        num_workers = min(num_workers, num_samples)
        chunk_size, remainder = divmod(num_samples, num_workers)
        # One seed per chunk, drawn from the parent's global NumPy RNG: distinct
        # across chunks (forked workers otherwise inherit the same NumPy RNG
        # state) and reproducible if the caller seeded np.random beforehand.
        seeds = np.random.randint(0, 2**31 - 1, size=num_workers)

        # Scratch directory under the current directory for the chunk files;
        # TemporaryDirectory removes it on exit even if a worker raises.
        with tempfile.TemporaryDirectory(prefix="poker_data_", dir=".") as temp_dir:
            chunks = [
                (
                    chunk_size + (1 if i < remainder else 0),
                    os.path.join(temp_dir, f"chunk_{i}.npz"),
                    int(seeds[i]),
                )
                for i in range(num_workers)
            ]
            print(f" 分配任务完毕,共有 {num_workers} 个 Worker。数据将临时存入 {temp_dir} ...")

            # imap_unordered yields as each chunk finishes, so progress is
            # reported per completed worker.
            with Pool(num_workers) as pool:
                for i, _ in enumerate(pool.imap_unordered(_generate_chunk_to_disk, chunks)):
                    print(f" [进度] {i + 1}/{num_workers} 个 CPU 核心已完成计算并落盘。")

            print(" 所有 CPU 计算完毕,正在合并碎片数据 (可能需要几十秒)...")
            parts = []
            for _, chunk_file, _ in chunks:
                parts.append(_read_npz(chunk_file))
                # Free disk space as soon as a chunk has been read (the file
                # handle is already closed by _read_npz).
                os.remove(chunk_file)

        # parts is a list of 4-tuples; zip(*parts) groups each field together.
        (self.x_hole, self.x_board,
         self.y_equity, self.y_histogram) = [
            np.concatenate(field_arrays, axis=0) for field_arrays in zip(*parts)
        ]

    def __len__(self) -> int:
        return len(self.x_hole)

    def __getitem__(self, idx: int):
        return (
            torch.tensor(self.x_hole[idx], dtype=torch.int64),
            torch.tensor(self.x_board[idx], dtype=torch.int64),
            torch.tensor(self.y_equity[idx], dtype=torch.float32),
            torch.tensor(self.y_histogram[idx], dtype=torch.float32),
        )


def _read_npz(path: str):
    """Read ``(x_hole, x_board, y_equity, y_histogram)`` from an .npz file.

    The NpzFile is used as a context manager so its file handle is closed
    once the arrays have been read into memory.
    """
    with np.load(path) as data:
        return (
            data["x_hole"],
            data["x_board"],
            data["y_equity"],
            data["y_histogram"],
        )


def _load_game():
    """Load the HUNL game via pyspiel (pyspiel is imported lazily here)."""
    import pyspiel
    from .config import HUNL_GAME_STRING

    return pyspiel.load_game(HUNL_GAME_STRING)


def _collect_samples(game, num_samples: int, progress: bool = False):
    """Call ``generate_sample(game)`` ``num_samples`` times and stack results.

    Args:
        game: The loaded pyspiel game passed through to ``generate_sample``.
        num_samples: Number of samples to draw (must be positive; stacking an
            empty list raises ``ValueError``).
        progress: If True, print a progress line every 100 samples.

    Returns:
        ``(x_hole, x_board, y_equity, y_histogram)``: ``y_equity`` is float32
        with shape ``(num_samples, 1)``; the others are ``np.stack`` of the
        per-sample arrays.
    """
    x_hole_list, x_board_list, y_equity_list, y_histogram_list = [], [], [], []
    for i in range(num_samples):
        if progress and (i + 1) % 100 == 0:
            print(f" Generated {i + 1}/{num_samples} samples...")
        x_h, x_b, y_e, y_hist = generate_sample(game)
        x_hole_list.append(x_h)
        x_board_list.append(x_b)
        y_equity_list.append(y_e)
        y_histogram_list.append(y_hist)

    return (
        np.stack(x_hole_list),
        np.stack(x_board_list),
        np.array(y_equity_list, dtype=np.float32).reshape(-1, 1),
        np.stack(y_histogram_list),
    )


def _generate_chunk_to_disk(args):
    """Pool worker: generate one chunk of samples and save it to a .npz file.

    Args:
        args: ``(num_samples, save_path)`` or ``(num_samples, save_path, seed)``.
            When ``seed`` is given, this worker process seeds NumPy's and
            Python's global RNGs with it before generating.

    Returns:
        True once the chunk has been written to ``save_path``.
    """
    num_samples, save_path, *rest = args
    seed = rest[0] if rest else None
    if seed is not None:
        # NOTE(review): assumes generate_sample draws its randomness from the
        # global np.random / random state; confirm against data_generator.
        np.random.seed(seed)
        random.seed(seed)

    x_hole, x_board, y_equity, y_histogram = _collect_samples(_load_game(), num_samples)

    # Write the chunk straight to disk rather than returning the (potentially
    # large) arrays through the Pool's result pipe.
    np.savez_compressed(
        save_path,
        x_hole=x_hole,
        x_board=x_board,
        y_equity=y_equity,
        y_histogram=y_histogram,
    )
    return True