Files
new/cfr_buffer.py
2026-05-06 17:36:51 +08:00

142 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
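cfr_buffer.py — CFR experience replay buffer (memory-optimized version).

Implements a ring buffer on top of pre-allocated, contiguous PyTorch tensors.
This replaces nested native Python lists, whose memory usage grew to tens of GB
once the buffer held millions of experiences.

Estimated memory footprint for 10 million entries (float32, 4 bytes per value).
The figures assume feat_dim = 55 and 6 actions; note that the constructor's
default num_actions is 5:
- info_states: 10M * 55 * 4B = 2.2 GB
- legal_masks: 10M * 6 * 4B = 240 MB
- regrets: 10M * 6 * 4B = 240 MB
- strategies: 10M * 6 * 4B = 240 MB
Total fixed footprint: just under 3 GB of system memory (RAM).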
"""
from typing import List, Tuple
import torch
class CFRBuffer:
    """Fixed-capacity ring buffer of CFR training experiences.

    Four float32 tensors of ``max_size`` rows are allocated up front. ``add``
    writes one row at ``ptr`` and advances it modulo ``max_size``, so once the
    buffer is full the oldest row is overwritten first. ``size`` is the number
    of rows that currently hold data.
    """

    def __init__(self, max_size: int = 10_000_000, feat_dim: int = 55, num_actions: int = 5) -> None:
        """Allocate the backing storage.

        Args:
            max_size: Capacity of the ring in rows; must be positive.
            feat_dim: Width of one info-state row.
            num_actions: Width of one legal-mask / regret / strategy row.

        Raises:
            ValueError: If ``max_size`` is not positive. A zero capacity would
                otherwise only surface later as a ``ZeroDivisionError`` in
                ``add``.
        """
        if max_size <= 0:
            raise ValueError(f"max_size 必须为正整数,实际为 {max_size}")
        self.max_size = max_size
        self.feat_dim = feat_dim
        self.num_actions = num_actions
        self.ptr = 0  # next row to be written
        self.size = 0  # number of rows currently holding data
        # All storage is allocated once here; rows are overwritten in place.
        self._info_states = torch.zeros((max_size, feat_dim), dtype=torch.float32)
        self._legal_masks = torch.zeros((max_size, num_actions), dtype=torch.float32)
        self._regrets = torch.zeros((max_size, num_actions), dtype=torch.float32)
        self._strategies = torch.zeros((max_size, num_actions), dtype=torch.float32)

    def __len__(self) -> int:
        """Return the number of stored experiences."""
        return self.size

    def add(
        self,
        info_state: List[float],
        legal_mask: List[int],
        regrets: List[float],
        strategy: List[float],
    ) -> None:
        """Store one experience at the write cursor, overwriting the oldest row when full.

        Args:
            info_state: Feature values for the row (``feat_dim`` entries).
            legal_mask: Legal-action mask (``num_actions`` entries).
            regrets: Regret values (``num_actions`` entries).
            strategy: Strategy values (``num_actions`` entries).
        """
        row = self.ptr
        self._info_states[row] = torch.tensor(info_state, dtype=torch.float32)
        self._legal_masks[row] = torch.tensor(legal_mask, dtype=torch.float32)
        self._regrets[row] = torch.tensor(regrets, dtype=torch.float32)
        self._strategies[row] = torch.tensor(strategy, dtype=torch.float32)
        # Advance the cursor around the ring; size saturates at capacity.
        self.ptr = (row + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Draw a random batch of stored experiences.

        Row indices are drawn with ``torch.randint``, i.e. uniformly and with
        replacement, from the rows that hold data.

        Args:
            batch_size: Requested number of rows; capped at the current size.

        Returns:
            ``(info_states, legal_masks, regrets, strategies)`` for the drawn rows.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("CFRBuffer 为空,无法采样")
        actual_bs = min(batch_size, self.size)
        indices = torch.randint(0, self.size, (actual_bs,))
        return (
            self._info_states[indices],
            self._legal_masks[indices],
            self._regrets[indices],
            self._strategies[indices],
        )

    def state_dict(self) -> dict:
        """Export the stored rows for checkpointing.

        Only rows ``[0, size)`` are cloned into the result, so a partially
        filled buffer does not produce a checkpoint as large as its capacity.
        """
        return {
            "ptr": self.ptr,
            "size": self.size,
            "info_states": self._info_states[:self.size].clone(),
            "legal_masks": self._legal_masks[:self.size].clone(),
            "regrets": self._regrets[:self.size].clone(),
            "strategies": self._strategies[:self.size].clone(),
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore the buffer from a checkpoint (current or legacy format).

        A checkpoint containing a ``"size"`` key is treated as the current
        tensor format; anything else is treated as the legacy list format.

        When the checkpoint does not match this buffer's capacity, the rows
        are re-laid out oldest-first and, if there are more than ``max_size``
        of them, only the newest ``max_size`` rows are kept. This keeps
        ``ptr`` and ``size`` consistent with the data actually stored, so
        later ``add`` calls never expose a row that was not written.

        Args:
            state: Mapping produced by ``state_dict`` or by the legacy buffer.
        """
        if "size" in state:
            # ── current tensor ring format ──
            count = int(state["size"])
            cursor = int(state["ptr"])
            columns = [
                state["info_states"],
                state["legal_masks"],
                state["regrets"],
                state["strategies"],
            ]
            # The saved layout can be reused verbatim only if the ring was
            # saved at this exact capacity or had not wrapped yet.
            reusable = count == self.max_size or (
                count < self.max_size and cursor == count
            )
            if not reusable:
                if 0 < cursor < count:
                    # Saved ring had wrapped: the oldest row sits at `cursor`.
                    columns = [torch.cat((c[cursor:], c[:cursor])) for c in columns]
                if count > self.max_size:
                    columns = [c[-self.max_size:] for c in columns]
                    count = self.max_size
                cursor = count % self.max_size
            self._write_rows(columns, count)
            self.size = count
            self.ptr = cursor
        else:
            # ── legacy list format (backward-compatibility path) ──
            print("[Buffer] 检测到旧版本检查点,正在将其转换为高速连续 Tensor 格式...")
            old_info = state["info_states"]
            if not isinstance(old_info, torch.Tensor):
                old_info = torch.tensor(old_info, dtype=torch.float32)
            old_legal = torch.tensor(state["legal_masks"], dtype=torch.float32)
            old_regrets = torch.tensor(state["regrets"], dtype=torch.float32)
            old_strats = torch.tensor(state["strategies"], dtype=torch.float32)
            columns = [old_info, old_legal, old_regrets, old_strats]
            count = old_legal.shape[0]
            if count > self.max_size:
                # Legacy data is stored in insertion order; keep the newest rows.
                columns = [c[-self.max_size:] for c in columns]
                count = self.max_size
            # An empty legacy list converts to a 1-D tensor that cannot be
            # assigned into a 2-D slice, so the copy is skipped when empty.
            if count > 0:
                self._write_rows(columns, count)
            self.size = count
            # Cursor sits after the last row, or wraps to 0 when exactly full.
            self.ptr = count % self.max_size
            print(f"[Buffer] 成功灌入 {self.size} 条历史经验。")

    def _write_rows(self, columns, count: int) -> None:
        """Copy ``count`` rows of each column into the front of the storage tensors."""
        info, legal, regrets, strats = columns
        self._info_states[:count] = info
        self._legal_masks[:count] = legal
        self._regrets[:count] = regrets
        self._strategies[:count] = strats

    def clear(self) -> None:
        """Logically empty the buffer; old rows stay in memory until overwritten."""
        self.ptr = 0
        self.size = 0
# ─────────────────────── 实例化测试 ───────────────────────
if __name__ == "__main__":
    # Smoke test: fill a 100-row buffer with 120 rows and check the wrap-around.
    print("=" * 60)
    print("CFRBuffer (Tensor Ring Buffer) 单元测试")
    print("=" * 60)
    # The rows written below carry 6 action values, so the action width is
    # passed explicitly: the constructor default is 5, and the first add()
    # would otherwise fail with a size mismatch.
    buf = CFRBuffer(max_size=100, num_actions=6)
    for i in range(120):
        buf.add([float(i)] * 55, [1] * 6, [0.0] * 6, [0.1] * 6)
    assert len(buf) == 100, f"长度应为 100实际 {len(buf)}"
    assert buf.ptr == 20, f"游标应循环至 20实际 {buf.ptr}"
    # Row 0 first held 0.0 and must have been overwritten by the 101st write (100.0).
    assert buf._info_states[0][0].item() == 100.0, "环形覆盖逻辑错误"
    sd = buf.state_dict()
    assert sd["info_states"].shape[0] == 100, "导出形状错误"
    print("✓ 核心逻辑完全正确!")