# Source: card_model/dataset.py (exported from a repository web view, 2026-04-21).
# The export banner ("ambiguous Unicode characters" warning, file metadata) has
# been reduced to this comment so the file is valid Python again.
"""PyTorch Dataset for Poker Card Model.
Generates samples using multi-process data generation and supports
both in-memory generation and disk persistence.
"""
import os
from multiprocessing import Pool
from typing import Optional
import tempfile
import glob
import os
import numpy as np
from multiprocessing import Pool
import numpy as np
import torch
from torch.utils.data import Dataset
from .config import (
BOARD_SIZE,
HOLE_SIZE,
NUM_BINS,
NUM_ROLLOUTS,
NUM_VAL_SAMPLES,
NUM_TRAIN_SAMPLES,
PAD_TOKEN,
)
from .data_generator import generate_sample
class PokerCardDataset(Dataset):
    """Dataset for Card Model equity prediction.

    Each sample contains:
        x_hole:     [2]  int64   - player 0's hole card IDs (0-51)
        x_board:    [5]  int64   - board card IDs (0-51), padded with PAD_TOKEN=52
        y_equity:   [1]  float32 - mean equity (win rate)
        y_histogram:[50] float32 - normalized equity distribution over bins
    """

    def __init__(
        self,
        num_samples: int,
        num_rollouts: int = NUM_ROLLOUTS,
        num_workers: int = 0,
        save_path: Optional[str] = None,
    ):
        """Initialize dataset.

        Args:
            num_samples: Number of samples to generate.
            num_rollouts: Monte Carlo rollouts per sample.
            num_workers: Number of parallel workers for data generation.
                0 = single process (mainly for debugging).
            save_path: If provided, cache generated data to this path (.npz).
                If the file exists, loads from disk instead of regenerating.
        """
        self.num_samples = num_samples
        self.num_rollouts = num_rollouts

        # Try loading from cache.
        if save_path and os.path.exists(save_path):
            # Context manager closes the underlying zip file handle of the
            # NpzFile once the arrays have been materialized.
            with np.load(save_path) as data:
                self.x_hole = data["x_hole"]
                self.x_board = data["x_board"]
                self.y_equity = data["y_equity"]
                self.y_histogram = data["y_histogram"]
            # Report the count actually loaded: a stale cache may hold a
            # different number of samples than was requested.
            print(f"Loaded {len(self.x_hole)} samples from {save_path}")
            return

        # Generate data.
        print(f"Generating {num_samples} samples with {num_rollouts} rollouts each "
              f"({num_workers} workers)...")
        if num_workers > 0:
            self._generate_parallel(num_samples, num_workers)
        else:
            self._generate_sequential(num_samples)

        # Save to cache.
        if save_path:
            os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
            np.savez(
                save_path,
                x_hole=self.x_hole,
                x_board=self.x_board,
                y_equity=self.y_equity,
                y_histogram=self.y_histogram,
            )
            print(f"Saved {num_samples} samples to {save_path}")

    def _generate_sequential(self, num_samples: int):
        """Generate samples sequentially (single process)."""
        # Imported lazily: pyspiel is only needed on the generation path,
        # not when loading from a cached .npz.
        import pyspiel
        from .config import HUNL_GAME_STRING

        game = pyspiel.load_game(HUNL_GAME_STRING)
        x_hole_list = []
        x_board_list = []
        y_equity_list = []
        y_histogram_list = []
        for i in range(num_samples):
            if (i + 1) % 100 == 0:
                print(f"  Generated {i + 1}/{num_samples} samples...")
            x_h, x_b, y_e, y_hist = generate_sample(game)
            x_hole_list.append(x_h)
            x_board_list.append(x_b)
            y_equity_list.append(y_e)
            y_histogram_list.append(y_hist)
        self.x_hole = np.stack(x_hole_list)
        self.x_board = np.stack(x_board_list)
        self.y_equity = np.array(y_equity_list, dtype=np.float32).reshape(-1, 1)
        self.y_histogram = np.stack(y_histogram_list)

    def _generate_parallel(self, num_samples: int, num_workers: int):
        """Multi-process generation that avoids pool deadlocks.

        Each worker writes its chunk straight to a temporary .npz on disk and
        the parent process reads the pieces back and merges them, so no bulk
        arrays ever travel through the multiprocessing pipe.
        """
        chunk_size = num_samples // num_workers
        remainder = num_samples % num_workers
        chunks = []
        # Stage worker outputs in a temp dir under the current directory.
        temp_dir = tempfile.mkdtemp(prefix="poker_data_", dir=".")
        for i in range(num_workers):
            c_size = chunk_size + (1 if i < remainder else 0)
            # Skip empty chunks (possible when num_workers > num_samples):
            # a worker handed 0 samples would crash on np.stack([]).
            if c_size == 0:
                continue
            chunk_file = os.path.join(temp_dir, f"chunk_{i}.npz")
            chunks.append((c_size, chunk_file))
        total = len(chunks)
        print(f" 分配任务完毕,共有 {num_workers} 个 Worker。数据将临时存入 {temp_dir} ...")
        # imap_unordered yields as each worker finishes, giving live progress.
        with Pool(num_workers) as pool:
            for i, _ in enumerate(pool.imap_unordered(_generate_chunk_to_disk, chunks)):
                print(f" [进度] {i + 1}/{total} 个 CPU 核心已完成计算并落盘。")
        print(" 所有 CPU 计算完毕,正在合并碎片数据 (可能需要几十秒)...")
        x_hole_list, x_board_list, y_equity_list, y_histogram_list = [], [], [], []
        # Read every temp chunk in order and delete it once merged.
        for _, chunk_file in chunks:
            with np.load(chunk_file) as data:
                x_hole_list.append(data["x_hole"])
                x_board_list.append(data["x_board"])
                y_equity_list.append(data["y_equity"])
                y_histogram_list.append(data["y_histogram"])
            os.remove(chunk_file)  # Drop the fragment as soon as it is merged.
        os.rmdir(temp_dir)  # Remove the (now empty) staging directory.
        self.x_hole = np.concatenate(x_hole_list, axis=0)
        self.x_board = np.concatenate(x_board_list, axis=0)
        self.y_equity = np.concatenate(y_equity_list, axis=0)
        self.y_histogram = np.concatenate(y_histogram_list, axis=0)

    def __len__(self) -> int:
        return len(self.x_hole)

    def __getitem__(self, idx: int):
        # Tensors are built per item; the backing arrays stay as numpy so the
        # dataset can be pickled cheaply to DataLoader workers.
        return (
            torch.tensor(self.x_hole[idx], dtype=torch.int64),
            torch.tensor(self.x_board[idx], dtype=torch.int64),
            torch.tensor(self.y_equity[idx], dtype=torch.float32),
            torch.tensor(self.y_histogram[idx], dtype=torch.float32),
        )
def _generate_chunk_to_disk(args):
    """Worker entry point: generate one chunk of samples and write it straight
    to a temporary .npz file on disk, completely bypassing the multiprocessing
    result pipe (returning bulk arrays through the pipe can deadlock the pool).

    Args:
        args: Tuple of (num_samples, save_path) — chunk size and target file.

    Returns:
        True once the chunk has been persisted.
    """
    num_samples, save_path = args
    # Imported inside the worker so each spawned process sets up its own
    # pyspiel/game state independently of the parent.
    import pyspiel
    from .config import HUNL_GAME_STRING
    from .data_generator import generate_sample

    game = pyspiel.load_game(HUNL_GAME_STRING)
    # Collect (hole, board, equity, histogram) tuples, then transpose them
    # into per-field sequences in one go.
    samples = [generate_sample(game) for _ in range(num_samples)]
    holes, boards, equities, histograms = zip(*samples)
    # Persist directly to disk; never return the data itself through the pool.
    np.savez_compressed(
        save_path,
        x_hole=np.stack(holes),
        x_board=np.stack(boards),
        y_equity=np.array(equities, dtype=np.float32).reshape(-1, 1),
        y_histogram=np.stack(histograms),
    )
    return True