142 lines
5.9 KiB
Python
142 lines
5.9 KiB
Python
"""
|
||
cfr_buffer.py — CFR 经验回放池 (极致内存优化版)
|
||
|
||
使用 PyTorch Tensor 预分配连续内存,实现 Ring Buffer(环形缓冲区)。
|
||
彻底抛弃 Python 原生 List 嵌套,解决上百万条经验导致的几十 GB 内存占用爆炸问题。
|
||
|
||
内存占用物理估算(1000万条):
|
||
- info_states: 10M * 55 * 4B = 2.2 GB
|
||
- legal_masks: 10M * 6 * 4B = 240 MB
|
||
- regrets: 10M * 6 * 4B = 240 MB
|
||
- strategies: 10M * 6 * 4B = 240 MB
|
||
总计固定占用不足 3 GB 系统内存 (RAM),极其高效!
|
||
"""
|
||
|
||
from typing import List, Tuple
|
||
import torch
|
||
|
||
class CFRBuffer:
|
||
"""CFR 经验回放池 (Ring Buffer)"""
|
||
|
||
def __init__(self, max_size: int = 10_000_000, feat_dim: int = 55, num_actions: int = 5) -> None:
|
||
self.max_size = max_size
|
||
self.feat_dim = feat_dim
|
||
self.num_actions = num_actions
|
||
|
||
self.ptr = 0 # 写入指针
|
||
self.size = 0 # 当前实际数据量
|
||
|
||
# 核心:初始化时直接向系统申请连续的 C++ 底层内存块
|
||
self._info_states = torch.zeros((max_size, feat_dim), dtype=torch.float32)
|
||
self._legal_masks = torch.zeros((max_size, num_actions), dtype=torch.float32)
|
||
self._regrets = torch.zeros((max_size, num_actions), dtype=torch.float32)
|
||
self._strategies = torch.zeros((max_size, num_actions), dtype=torch.float32)
|
||
|
||
def __len__(self) -> int:
|
||
return self.size
|
||
|
||
def add(
|
||
self,
|
||
info_state: List[float],
|
||
legal_mask: List[int],
|
||
regrets: List[float],
|
||
strategy: List[float],
|
||
) -> None:
|
||
"""向缓冲区添加一条经验。游标 ptr 循环覆盖旧数据。"""
|
||
# 直接将 python list 转换为 tensor 并写入预分配的内存对应行
|
||
self._info_states[self.ptr] = torch.tensor(info_state, dtype=torch.float32)
|
||
self._legal_masks[self.ptr] = torch.tensor(legal_mask, dtype=torch.float32)
|
||
self._regrets[self.ptr] = torch.tensor(regrets, dtype=torch.float32)
|
||
self._strategies[self.ptr] = torch.tensor(strategy, dtype=torch.float32)
|
||
|
||
# 环形指针推进
|
||
self.ptr = (self.ptr + 1) % self.max_size
|
||
self.size = min(self.size + 1, self.max_size)
|
||
|
||
def sample(
|
||
self, batch_size: int
|
||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||
"""从连续内存中极速切片采样"""
|
||
if self.size == 0:
|
||
raise ValueError("CFRBuffer 为空,无法采样")
|
||
|
||
actual_bs = min(batch_size, self.size)
|
||
# 极速生成随机索引
|
||
indices = torch.randint(0, self.size, (actual_bs,))
|
||
|
||
# 直接通过高级索引切片,零额外对象创建开销
|
||
return (
|
||
self._info_states[indices],
|
||
self._legal_masks[indices],
|
||
self._regrets[indices],
|
||
self._strategies[indices]
|
||
)
|
||
|
||
def state_dict(self) -> dict:
|
||
"""导出当前有效数据用于检查点保存。
|
||
注意:为了避免检查点文件无意义的庞大,只保存 [0 : self.size] 的有效切片。
|
||
"""
|
||
return {
|
||
"ptr": self.ptr,
|
||
"size": self.size,
|
||
"info_states": self._info_states[:self.size].clone(),
|
||
"legal_masks": self._legal_masks[:self.size].clone(),
|
||
"regrets": self._regrets[:self.size].clone(),
|
||
"strategies": self._strategies[:self.size].clone(),
|
||
}
|
||
|
||
def load_state_dict(self, state: dict) -> None:
|
||
"""从检查点恢复数据(支持新旧版本兼容)"""
|
||
if "size" in state:
|
||
# ── 新版本 Tensor Ring 格式 ──
|
||
self.size = state["size"]
|
||
self.ptr = state["ptr"]
|
||
|
||
# 将保存的有效切片灌回预分配的大内存中
|
||
self._info_states[:self.size] = state["info_states"]
|
||
self._legal_masks[:self.size] = state["legal_masks"]
|
||
self._regrets[:self.size] = state["regrets"]
|
||
self._strategies[:self.size] = state["strategies"]
|
||
else:
|
||
# ── 旧版本纯 List 格式(向下兼容补丁) ──
|
||
print("[Buffer] 检测到旧版本检查点,正在将其转换为高速连续 Tensor 格式...")
|
||
old_info = state["info_states"] # 旧版已经是 Tensor 了
|
||
old_legal = torch.tensor(state["legal_masks"], dtype=torch.float32)
|
||
old_regrets = torch.tensor(state["regrets"], dtype=torch.float32)
|
||
old_strats = torch.tensor(state["strategies"], dtype=torch.float32)
|
||
|
||
self.size = old_legal.shape[0]
|
||
# 如果之前 buffer 没满,游标就在末尾;如果满了,游标回绕到 0
|
||
self.ptr = self.size % self.max_size
|
||
|
||
if self.size > 0:
|
||
self._info_states[:self.size] = old_info
|
||
self._legal_masks[:self.size] = old_legal
|
||
self._regrets[:self.size] = old_regrets
|
||
self._strategies[:self.size] = old_strats
|
||
print(f"[Buffer] 成功灌入 {self.size} 条历史经验。")
|
||
|
||
def clear(self) -> None:
|
||
self.ptr = 0
|
||
self.size = 0
|
||
|
||
|
||
# ─────────────────────── 实例化测试 ───────────────────────
|
||
if __name__ == "__main__":
|
||
print("=" * 60)
|
||
print("CFRBuffer (Tensor Ring Buffer) 单元测试")
|
||
print("=" * 60)
|
||
|
||
buf = CFRBuffer(max_size=100)
|
||
|
||
for i in range(120):
|
||
buf.add([float(i)] * 55, [1]*6, [0.0]*6, [0.1]*6)
|
||
|
||
assert len(buf) == 100, f"长度应为 100,实际 {len(buf)}"
|
||
assert buf.ptr == 20, f"游标应循环至 20,实际 {buf.ptr}"
|
||
# 第 0 行原本是 0,现在应该被覆盖为 100
|
||
assert buf._info_states[0][0].item() == 100.0, "环形覆盖逻辑错误"
|
||
|
||
sd = buf.state_dict()
|
||
assert sd["info_states"].shape[0] == 100, "导出形状错误"
|
||
print("✓ 核心逻辑完全正确!") |