"""PyTorch Dataset for Poker Card Model.
|
||
|
||
Generates samples using multi-process data generation and supports
|
||
both in-memory generation and disk persistence.
|
||
"""
import glob
import os
import tempfile
from multiprocessing import Pool
from typing import Optional

import numpy as np
import torch
from torch.utils.data import Dataset

from .config import (
    BOARD_SIZE,
    HOLE_SIZE,
    NUM_BINS,
    NUM_ROLLOUTS,
    NUM_TRAIN_SAMPLES,
    NUM_VAL_SAMPLES,
    PAD_TOKEN,
)
from .data_generator import generate_sample


class PokerCardDataset(Dataset):
    """Dataset for Card Model equity prediction.

    Each sample contains:
        x_hole:      [2] int64   - player 0's hole card IDs (0-51)
        x_board:     [5] int64   - board card IDs (0-51), padded with PAD_TOKEN=52
        y_equity:    [1] float32 - mean equity (win rate)
        y_histogram: [50] float32 - normalized equity distribution over bins
    """

    def __init__(
        self,
        num_samples: int,
        num_rollouts: int = NUM_ROLLOUTS,
        num_workers: int = 0,
        save_path: Optional[str] = None,
    ):
        """Initialize dataset.

        Args:
            num_samples: Number of samples to generate.
            num_rollouts: Monte Carlo rollouts per sample.
            num_workers: Number of parallel workers for data generation.
                0 = single process (mainly for debugging).
            save_path: If provided, cache generated data to this path (.npz).
                If the file exists, loads from disk instead of regenerating.
        """
        self.num_samples = num_samples
        self.num_rollouts = num_rollouts

        # Try loading from cache first to skip the expensive generation step.
        if save_path and os.path.exists(save_path):
            data = np.load(save_path)
            self.x_hole = data["x_hole"]
            self.x_board = data["x_board"]
            self.y_equity = data["y_equity"]
            self.y_histogram = data["y_histogram"]
            # NOTE(review): the cache is trusted blindly — if it was written
            # with a different num_samples the count reported below is wrong.
            # Consider checking len(self.x_hole) == num_samples here.
            print(f"Loaded {num_samples} samples from {save_path}")
            return

        if num_samples <= 0:
            # Degenerate case: both generators would crash in np.stack([]) on
            # an empty sample list, so build correctly-shaped empty arrays.
            self.x_hole = np.zeros((0, HOLE_SIZE), dtype=np.int64)
            self.x_board = np.zeros((0, BOARD_SIZE), dtype=np.int64)
            self.y_equity = np.zeros((0, 1), dtype=np.float32)
            self.y_histogram = np.zeros((0, NUM_BINS), dtype=np.float32)
        else:
            # Generate data
            print(f"Generating {num_samples} samples with {num_rollouts} rollouts each "
                  f"({num_workers} workers)...")

            if num_workers > 0:
                self._generate_parallel(num_samples, num_workers)
            else:
                self._generate_sequential(num_samples)

        # Save to cache
        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            np.savez(
                save_path,
                x_hole=self.x_hole,
                x_board=self.x_board,
                y_equity=self.y_equity,
                y_histogram=self.y_histogram,
            )
            print(f"Saved {num_samples} samples to {save_path}")

    def _generate_sequential(self, num_samples: int):
        """Generate samples sequentially (single process)."""
        # pyspiel is imported lazily so merely importing this module does not
        # require OpenSpiel to be installed.
        import pyspiel
        from .config import HUNL_GAME_STRING
        game = pyspiel.load_game(HUNL_GAME_STRING)

        x_hole_list = []
        x_board_list = []
        y_equity_list = []
        y_histogram_list = []

        for i in range(num_samples):
            if (i + 1) % 100 == 0:
                print(f" Generated {i + 1}/{num_samples} samples...")
            x_h, x_b, y_e, y_hist = generate_sample(game)
            x_hole_list.append(x_h)
            x_board_list.append(x_b)
            y_equity_list.append(y_e)
            y_histogram_list.append(y_hist)

        self.x_hole = np.stack(x_hole_list)
        self.x_board = np.stack(x_board_list)
        self.y_equity = np.array(y_equity_list, dtype=np.float32).reshape(-1, 1)
        self.y_histogram = np.stack(y_histogram_list)

    def _generate_parallel(self, num_samples: int, num_workers: int):
        """Deadlock-safe multi-process generation.

        Workers write their chunk straight to disk; the main process then
        loads and merges the chunk files, so no large arrays ever travel
        through the multiprocessing pipe.
        """
        # A worker handed 0 samples would crash in np.stack([]) inside
        # _generate_chunk_to_disk, so never spawn more workers than samples.
        num_workers = min(num_workers, num_samples)
        chunk_size = num_samples // num_workers
        remainder = num_samples % num_workers
        chunks = []

        # Temp directory (under the current dir) holding each worker's chunk.
        temp_dir = tempfile.mkdtemp(prefix="poker_data_", dir=".")

        for i in range(num_workers):
            # Spread the remainder over the first `remainder` workers.
            c_size = chunk_size + (1 if i < remainder else 0)
            chunk_file = os.path.join(temp_dir, f"chunk_{i}.npz")
            chunks.append((c_size, chunk_file))

        print(f" 分配任务完毕,共有 {num_workers} 个 Worker。数据将临时存入 {temp_dir} ...")

        # imap_unordered yields as each chunk finishes -> live progress output.
        with Pool(num_workers) as pool:
            for i, _ in enumerate(pool.imap_unordered(_generate_chunk_to_disk, chunks)):
                print(f" [进度] {i + 1}/{num_workers} 个 CPU 核心已完成计算并落盘。")

        print(" 所有 CPU 计算完毕,正在合并碎片数据 (可能需要几十秒)...")

        x_hole_list, x_board_list, y_equity_list, y_histogram_list = [], [], [], []

        # Read every temp chunk, merge, and clean up as we go.
        for _, chunk_file in chunks:
            data = np.load(chunk_file)
            x_hole_list.append(data["x_hole"])
            x_board_list.append(data["x_board"])
            y_equity_list.append(data["y_equity"])
            y_histogram_list.append(data["y_histogram"])
            os.remove(chunk_file)  # delete each chunk as soon as it is merged

        os.rmdir(temp_dir)  # remove the (now empty) temp directory

        self.x_hole = np.concatenate(x_hole_list, axis=0)
        self.x_board = np.concatenate(x_board_list, axis=0)
        self.y_equity = np.concatenate(y_equity_list, axis=0)
        self.y_histogram = np.concatenate(y_histogram_list, axis=0)

    def __len__(self) -> int:
        # Length is defined by the generated/loaded arrays, not num_samples,
        # so a cache of a different size is still iterated safely.
        return len(self.x_hole)

    def __getitem__(self, idx: int):
        """Return one sample as (x_hole, x_board, y_equity, y_histogram) tensors."""
        return (
            torch.tensor(self.x_hole[idx], dtype=torch.int64),
            torch.tensor(self.x_board[idx], dtype=torch.int64),
            torch.tensor(self.y_equity[idx], dtype=torch.float32),
            torch.tensor(self.y_histogram[idx], dtype=torch.float32),
        )


def _generate_chunk_to_disk(args):
    """Worker entry point: generate samples and persist them to a temp file.

    The bulk arrays are written straight to an .npz on disk instead of being
    returned through the multiprocessing pipe (avoids pipe-size deadlocks).

    Args:
        args: Tuple of (num_samples, save_path).

    Returns:
        True on completion; the real payload lives on disk at save_path.
    """
    num_samples, save_path = args
    # Heavy imports stay inside the worker: each subprocess loads its own
    # game object and the parent never needs pyspiel at import time.
    import pyspiel
    from .config import HUNL_GAME_STRING
    from .data_generator import generate_sample
    game = pyspiel.load_game(HUNL_GAME_STRING)

    holes, boards, equities, histograms = [], [], [], []

    for _ in range(num_samples):
        hole, board, equity, histogram = generate_sample(game)
        holes.append(hole)
        boards.append(board)
        equities.append(equity)
        histograms.append(histogram)

    # Dump directly to disk — never return bulk data through the pipe.
    np.savez_compressed(
        save_path,
        x_hole=np.stack(holes),
        x_board=np.stack(boards),
        y_equity=np.array(equities, dtype=np.float32).reshape(-1, 1),
        y_histogram=np.stack(histograms)
    )
    return True