"""cfr_buffer.py — CFR experience replay buffer (memory-optimised).

Experiences are stored in pre-allocated, contiguous PyTorch tensors used as a
ring buffer, instead of nested Python lists, so that millions of entries do not
blow up RAM usage.

Memory estimate: every row costs (feat_dim + 3 * num_actions) * 4 bytes
(float32). For example, with feat_dim=55, num_actions=6 and 10M rows:
    - info_states: 10M * 55 * 4B = 2.2 GB
    - legal_masks: 10M *  6 * 4B = 240 MB
    - regrets:     10M *  6 * 4B = 240 MB
    - strategies:  10M *  6 * 4B = 240 MB
i.e. a fixed footprint of just under 3 GB of system RAM. (The constructor's
default num_actions is 5, which is slightly smaller.)
"""

from typing import List, Tuple

import torch


class CFRBuffer:
    """CFR experience replay buffer backed by fixed-size tensors (ring buffer).

    Rows ``[0, size)`` hold valid data. ``ptr`` is the row the next ``add``
    writes to; once the buffer is full it wraps around and overwrites the
    oldest row.
    """

    def __init__(
        self,
        max_size: int = 10_000_000,
        feat_dim: int = 55,
        num_actions: int = 5,
    ) -> None:
        """Allocate all storage up front.

        Args:
            max_size: Capacity in rows.
            feat_dim: Length of each info-state vector.
            num_actions: Length of each legal-mask / regret / strategy vector.
        """
        self.max_size = max_size
        self.feat_dim = feat_dim
        self.num_actions = num_actions

        self.ptr = 0   # next row to write
        self.size = 0  # number of valid rows

        # One contiguous float32 block per field, allocated once.
        self._info_states = torch.zeros((max_size, feat_dim), dtype=torch.float32)
        self._legal_masks = torch.zeros((max_size, num_actions), dtype=torch.float32)
        self._regrets = torch.zeros((max_size, num_actions), dtype=torch.float32)
        self._strategies = torch.zeros((max_size, num_actions), dtype=torch.float32)

    def __len__(self) -> int:
        return self.size

    def add(
        self,
        info_state: List[float],
        legal_mask: List[int],
        regrets: List[float],
        strategy: List[float],
    ) -> None:
        """Write one experience at ``ptr``, overwriting the oldest row when full.

        Each argument's length must match ``feat_dim`` / ``num_actions``;
        otherwise torch raises a shape-mismatch error.
        """
        # as_tensor avoids an extra copy (and a warning) if a tensor is passed.
        self._info_states[self.ptr] = torch.as_tensor(info_state, dtype=torch.float32)
        self._legal_masks[self.ptr] = torch.as_tensor(legal_mask, dtype=torch.float32)
        self._regrets[self.ptr] = torch.as_tensor(regrets, dtype=torch.float32)
        self._strategies[self.ptr] = torch.as_tensor(strategy, dtype=torch.float32)

        # Advance the ring cursor.
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample rows uniformly with replacement from the valid region.

        Returns at most ``size`` rows as
        ``(info_states, legal_masks, regrets, strategies)``.

        Raises:
            ValueError: If the buffer is empty.
        """
        if self.size == 0:
            raise ValueError("CFRBuffer 为空,无法采样")

        actual_bs = min(batch_size, self.size)
        indices = torch.randint(0, self.size, (actual_bs,))

        # Advanced indexing gathers the rows into new tensors.
        return (
            self._info_states[indices],
            self._legal_masks[indices],
            self._regrets[indices],
            self._strategies[indices],
        )

    def state_dict(self) -> dict:
        """Export the valid rows ``[0, size)`` plus cursor state for checkpointing.

        Only the valid slice is exported, so checkpoints of a partially filled
        buffer stay small. Note that the slices are cloned, so exporting a full
        buffer temporarily needs about as much RAM again.
        """
        # clone() so the exported tensors do not share (and, when serialized,
        # drag along) the whole pre-allocated storage.
        return {
            "ptr": self.ptr,
            "size": self.size,
            "info_states": self._info_states[:self.size].clone(),
            "legal_masks": self._legal_masks[:self.size].clone(),
            "regrets": self._regrets[:self.size].clone(),
            "strategies": self._strategies[:self.size].clone(),
        }

    def load_state_dict(self, state: dict) -> None:
        """Restore contents from a checkpoint (current and legacy formats).

        The current format is what ``state_dict`` produces. The legacy format
        has no ``"size"`` key; its ``info_states`` was expected to already be a
        tensor and the other fields are nested lists.

        If the checkpoint does not fit this buffer's ring layout as-is (e.g. it
        came from a buffer with a different ``max_size``), the rows are unrolled
        oldest-to-newest and only the most recent ``max_size`` rows are kept.
        """
        if "size" in state:
            # Current tensor ring format.
            self._restore(
                int(state["size"]),
                int(state["ptr"]),
                state["info_states"],
                state["legal_masks"],
                state["regrets"],
                state["strategies"],
            )
        else:
            # Legacy list format (backward-compatibility path).
            print("[Buffer] 检测到旧版本检查点,正在将其转换为高速连续 Tensor 格式...")
            old_legal = torch.as_tensor(state["legal_masks"], dtype=torch.float32)
            old_size = old_legal.shape[0]
            # Legacy data is treated as un-wrapped (oldest first), so the
            # cursor conceptually sits right after the last row.
            self._restore(
                old_size,
                old_size,
                state["info_states"],
                old_legal,
                state["regrets"],
                state["strategies"],
            )
            print(f"[Buffer] 成功灌入 {self.size} 条历史经验。")

    def _restore(
        self,
        size: int,
        ptr: int,
        info_states,
        legal_masks,
        regrets,
        strategies,
    ) -> None:
        """Copy checkpointed rows into storage and set ``size``/``ptr`` consistently."""
        data = [
            torch.as_tensor(field, dtype=torch.float32)
            for field in (info_states, legal_masks, regrets, strategies)
        ]

        # The saved layout can be reused verbatim only if it is a valid ring
        # state for *this* capacity; otherwise the next add() would either
        # clobber live rows or index past the end.
        fits_as_is = (size < self.max_size and ptr == size) or (
            size == self.max_size and 0 <= ptr < self.max_size
        )
        if not fits_as_is:
            if 0 < ptr < size:
                # The saved ring had wrapped: its oldest row sits at ptr.
                data = [torch.cat((d[ptr:size], d[:ptr])) for d in data]
            keep = min(size, self.max_size)
            data = [d[size - keep:size] for d in data]  # keep the newest rows
            size = keep
            ptr = size % self.max_size

        self.size = size
        self.ptr = ptr
        if size > 0:
            targets = (self._info_states, self._legal_masks, self._regrets, self._strategies)
            for dst, src in zip(targets, data):
                dst[:size] = src[:size]

    def clear(self) -> None:
        """Logically empty the buffer; storage is kept and later overwritten."""
        self.ptr = 0
        self.size = 0


# ─────────────────────── Self-test ───────────────────────
if __name__ == "__main__":
    print("=" * 60)
    print("CFRBuffer (Tensor Ring Buffer) 单元测试")
    print("=" * 60)

    # The samples below carry 6 actions, so the buffer must be built with
    # num_actions=6 (the default of 5 makes add() fail on a shape mismatch).
    buf = CFRBuffer(max_size=100, num_actions=6)
    for i in range(120):
        buf.add([float(i)] * 55, [1] * 6, [0.0] * 6, [0.1] * 6)

    assert len(buf) == 100, f"长度应为 100,实际 {len(buf)}"
    assert buf.ptr == 20, f"游标应循环至 20,实际 {buf.ptr}"
    # Row 0 originally held 0 and must now be overwritten by sample 100.
    assert buf._info_states[0][0].item() == 100.0, "环形覆盖逻辑错误"

    sd = buf.state_dict()
    assert sd["info_states"].shape[0] == 100, "导出形状错误"

    # Restoring into a larger buffer unrolls the ring oldest-first.
    big = CFRBuffer(max_size=200, num_actions=6)
    big.load_state_dict(sd)
    assert len(big) == 100 and big.ptr == 100, "checkpoint restore into larger buffer failed"
    assert big._info_states[0][0].item() == 20.0, "checkpoint restore order wrong"

    print("✓ 核心逻辑完全正确!")