229 lines
10 KiB
Python
229 lines
10 KiB
Python
"""
|
||
cfr_buffer.py — CFR experience replay pool (memory-optimized version)

Implements a ring buffer on top of preallocated, contiguous PyTorch tensors.
It replaces nested native Python lists, whose memory usage exploded to tens
of GB once the pool held millions of samples.

Estimated memory footprint for 10 million samples (float32):
- info_states: 10M * 55 * 4B = 2.2 GB
- legal_masks: 10M * 5 * 4B = 200 MB
- regrets: 10M * 5 * 4B = 200 MB
- strategies: 10M * 5 * 4B = 200 MB
Total fixed footprint: about 2.8 GB of system memory (RAM), i.e. under 3 GB.
|
||
"""
|
||
|
||
from typing import List, Tuple
|
||
import torch
|
||
|
||
class CFRBuffer:
    """Fixed-capacity CFR experience replay pool stored as a ring buffer.

    Four float32 tensors with ``max_size`` rows each are allocated once in
    ``__init__``. New rows are written at ``ptr``; when the end of the storage
    is reached the pointer wraps to row 0 and the oldest rows are overwritten.
    Rows ``[0, size)`` are the ones that hold data.

    With the default arguments the storage takes
    10M * (55 + 5 + 5 + 5) * 4 B = 2.8 GB.
    """

    def __init__(self, max_size: int = 10_000_000, feat_dim: int = 55, num_actions: int = 5) -> None:
        """Allocate the storage.

        Args:
            max_size: Capacity of the buffer in rows; must be positive.
            feat_dim: Width of one info-state row.
            num_actions: Width of one legal-mask / regret / strategy row.

        Raises:
            ValueError: If ``max_size`` is not positive.
        """
        # A non-positive capacity would only fail later (modulo by zero in the
        # pointer arithmetic), so reject it here with a clear message.
        if max_size <= 0:
            raise ValueError(f"max_size must be positive, got {max_size}")

        self.max_size = max_size
        self.feat_dim = feat_dim
        self.num_actions = num_actions

        self.ptr = 0   # index of the next row to write
        self.size = 0  # number of rows currently holding data

        # Allocate every column up front so the footprint stays fixed.
        self._info_states = torch.zeros((max_size, feat_dim), dtype=torch.float32)
        self._legal_masks = torch.zeros((max_size, num_actions), dtype=torch.float32)
        self._regrets = torch.zeros((max_size, num_actions), dtype=torch.float32)
        self._strategies = torch.zeros((max_size, num_actions), dtype=torch.float32)

    def __len__(self) -> int:
        """Return the number of rows currently holding data."""
        return self.size

    def add(
        self,
        info_state: List[float],
        legal_mask: List[int],
        regrets: List[float],
        strategy: List[float],
    ) -> None:
        """Append one sample at ``ptr``, overwriting the oldest row once full."""
        # Each Python list is converted to a float32 tensor and copied into
        # the row the write pointer currently designates.
        self._info_states[self.ptr] = torch.tensor(info_state, dtype=torch.float32)
        self._legal_masks[self.ptr] = torch.tensor(legal_mask, dtype=torch.float32)
        self._regrets[self.ptr] = torch.tensor(regrets, dtype=torch.float32)
        self._strategies[self.ptr] = torch.tensor(strategy, dtype=torch.float32)

        # Advance the ring pointer; size saturates at capacity.
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def add_batch(
        self,
        info_states: List[List[float]],
        legal_masks: List[List[int]],
        regrets: List[List[float]],
        strategies: List[List[float]],
    ) -> None:
        """Append a batch of samples given as nested Python lists.

        The lists are converted to float32 tensors once and handed to
        ``add_batch_tensors``, so both entry points share the same write path,
        including the guard for batches larger than the buffer.
        """
        if len(info_states) == 0:
            return

        self.add_batch_tensors(
            torch.tensor(info_states, dtype=torch.float32),
            torch.tensor(legal_masks, dtype=torch.float32),
            torch.tensor(regrets, dtype=torch.float32),
            torch.tensor(strategies, dtype=torch.float32),
        )

    def add_batch_tensors(
        self,
        t_info: torch.Tensor,
        t_legal: torch.Tensor,
        t_regrets: torch.Tensor,
        t_strats: torch.Tensor,
    ) -> None:
        """Append a batch of samples given as tensors (one sample per row).

        The rows are copied into the preallocated storage by slice assignment.
        If the batch has more rows than the buffer can hold, only the last
        ``max_size`` rows are kept.
        """
        batch_size = t_info.shape[0]
        if batch_size == 0:
            return

        # Oversized batch: everything but the newest max_size rows would be
        # overwritten anyway, so drop the older rows before writing.
        if batch_size > self.max_size:
            t_info = t_info[-self.max_size:]
            t_legal = t_legal[-self.max_size:]
            t_regrets = t_regrets[-self.max_size:]
            t_strats = t_strats[-self.max_size:]
            batch_size = self.max_size

        columns = (
            (self._info_states, t_info),
            (self._legal_masks, t_legal),
            (self._regrets, t_regrets),
            (self._strategies, t_strats),
        )

        end_ptr = self.ptr + batch_size
        if end_ptr <= self.max_size:
            # The batch fits before the end of the storage: one contiguous write.
            for storage, rows in columns:
                storage[self.ptr:end_ptr] = rows
        else:
            # The batch crosses the end: fill up to the last row, then
            # continue from row 0 with the remainder.
            head_len = self.max_size - self.ptr
            tail_len = batch_size - head_len
            for storage, rows in columns:
                storage[self.ptr:self.max_size] = rows[:head_len]
                storage[0:tail_len] = rows[head_len:]

        self.ptr = end_ptr % self.max_size
        self.size = min(self.size + batch_size, self.max_size)

    def sample(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Draw random rows from the stored data.

        Indices are drawn with ``torch.randint`` over ``[0, size)``, so the
        same row can appear more than once. At most ``size`` rows are returned.

        Returns:
            ``(info_states, legal_masks, regrets, strategies)`` gathered at
            the same random indices.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("CFRBuffer 为空,无法采样")

        actual_bs = min(batch_size, self.size)
        indices = torch.randint(0, self.size, (actual_bs,))

        # Index every column with the same indices so rows stay aligned.
        return (
            self._info_states[indices],
            self._legal_masks[indices],
            self._regrets[indices],
            self._strategies[indices],
        )

    def state_dict(self) -> dict:
        """Export the buffer for checkpointing.

        Only the rows ``[0, size)`` are cloned into the result, so a partially
        filled buffer does not produce a checkpoint the size of the full
        allocation.
        """
        return {
            "ptr": self.ptr,
            "size": self.size,
            "info_states": self._info_states[:self.size].clone(),
            "legal_masks": self._legal_masks[:self.size].clone(),
            "regrets": self._regrets[:self.size].clone(),
            "strategies": self._strategies[:self.size].clone(),
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore the buffer from a checkpoint.

        Two layouts are accepted: the current one produced by ``state_dict``
        (recognised by its ``"size"`` key) and a legacy one without that key,
        whose columns are converted to float32 tensors. If the checkpoint
        holds more rows than ``max_size``, only its first ``max_size`` rows
        are loaded and a warning is printed.
        """
        if "size" in state:
            # Current tensor ring format.
            load_size = min(state["size"], self.max_size)
            self.size = load_size

            if load_size < self.max_size:
                # The restored rows do not fill this buffer (for example a
                # full checkpoint from a smaller buffer). Continue writing
                # right after them; reusing the saved pointer would overwrite
                # valid rows while size grows over rows never written.
                self.ptr = load_size
            else:
                self.ptr = state["ptr"] % self.max_size

            # Copy the saved rows back, truncated on both sides to load_size.
            self._info_states[:load_size] = state["info_states"][:load_size]
            self._legal_masks[:load_size] = state["legal_masks"][:load_size]
            self._regrets[:load_size] = state["regrets"][:load_size]
            self._strategies[:load_size] = state["strategies"][:load_size]

            if load_size < state["size"]:
                print(f"[Buffer] 警告: 检查点含 {state['size']} 条数据,当前 max_size={self.max_size},截断加载 {load_size} 条")
        else:
            # Legacy format without "size"/"ptr" keys.
            print("[Buffer] 检测到旧版本检查点,正在将其转换为高速连续 Tensor 格式...")
            # as_tensor accepts both nested lists and existing tensors.
            old_info = torch.as_tensor(state["info_states"], dtype=torch.float32)
            old_legal = torch.as_tensor(state["legal_masks"], dtype=torch.float32)
            old_regrets = torch.as_tensor(state["regrets"], dtype=torch.float32)
            old_strats = torch.as_tensor(state["strategies"], dtype=torch.float32)

            load_size = min(old_legal.shape[0], self.max_size)
            self.size = load_size
            # Not full: the pointer sits right after the data. Full: it wraps to 0.
            self.ptr = self.size % self.max_size

            if self.size > 0:
                self._info_states[:load_size] = old_info[:load_size]
                self._legal_masks[:load_size] = old_legal[:load_size]
                self._regrets[:load_size] = old_regrets[:load_size]
                self._strategies[:load_size] = old_strats[:load_size]

            if load_size < old_legal.shape[0]:
                print(f"[Buffer] 警告: 旧检查点含 {old_legal.shape[0]} 条数据,当前 max_size={self.max_size},截断加载 {load_size} 条")
            print(f"[Buffer] 成功灌入 {self.size} 条历史经验。")

    def clear(self) -> None:
        """Mark the buffer as empty; the storage is kept and reused."""
        self.ptr = 0
        self.size = 0
|
||
|
||
|
||
# ─────────────────────── 实例化测试 ───────────────────────
|
||
if __name__ == "__main__":
    # Self-check executed when the module is run directly: overfill a small
    # buffer, then verify the wrap-around bookkeeping and the checkpoint export.
    separator = "=" * 60
    print(separator)
    print("CFRBuffer (Tensor Ring Buffer) 单元测试")
    print(separator)

    buf = CFRBuffer(max_size=100)

    # 120 writes into 100 rows: the last 20 wrap around onto rows 0..19.
    for step in range(120):
        features = [float(step)] * 55
        buf.add(features, [1] * 5, [0.0] * 5, [0.2] * 5)

    assert len(buf) == 100, f"长度应为 100,实际 {len(buf)}"
    assert buf.ptr == 20, f"游标应循环至 20,实际 {buf.ptr}"
    # Row 0 first held the value 0 and must now hold the value of write 100.
    assert buf._info_states[0][0].item() == 100.0, "环形覆盖逻辑错误"

    exported = buf.state_dict()
    assert exported["info_states"].shape[0] == 100, "导出形状错误"
    print("✓ 核心逻辑完全正确!")