""" cfr_buffer.py — CFR 经验回放池 (极致内存优化版) 使用 PyTorch Tensor 预分配连续内存,实现 Ring Buffer(环形缓冲区)。 彻底抛弃 Python 原生 List 嵌套,解决上百万条经验导致的几十 GB 内存占用爆炸问题。 内存占用物理估算(1000万条): - info_states: 10M * 55 * 4B = 2.2 GB - legal_masks: 10M * 5 * 4B = 200 MB - regrets: 10M * 5 * 4B = 200 MB - strategies: 10M * 5 * 4B = 200 MB 总计固定占用不足 3 GB 系统内存 (RAM),极其高效! """ from typing import List, Tuple import torch class CFRBuffer: """CFR 经验回放池 (Ring Buffer)""" def __init__(self, max_size: int = 10_000_000, feat_dim: int = 55, num_actions: int = 5) -> None: self.max_size = max_size self.feat_dim = feat_dim self.num_actions = num_actions self.ptr = 0 # 写入指针 self.size = 0 # 当前实际数据量 # 核心:初始化时直接向系统申请连续的 C++ 底层内存块 self._info_states = torch.zeros((max_size, feat_dim), dtype=torch.float32) self._legal_masks = torch.zeros((max_size, num_actions), dtype=torch.float32) self._regrets = torch.zeros((max_size, num_actions), dtype=torch.float32) self._strategies = torch.zeros((max_size, num_actions), dtype=torch.float32) def __len__(self) -> int: return self.size def add( self, info_state: List[float], legal_mask: List[int], regrets: List[float], strategy: List[float], ) -> None: """向缓冲区添加一条经验。游标 ptr 循环覆盖旧数据。""" # 直接将 python list 转换为 tensor 并写入预分配的内存对应行 self._info_states[self.ptr] = torch.tensor(info_state, dtype=torch.float32) self._legal_masks[self.ptr] = torch.tensor(legal_mask, dtype=torch.float32) self._regrets[self.ptr] = torch.tensor(regrets, dtype=torch.float32) self._strategies[self.ptr] = torch.tensor(strategy, dtype=torch.float32) # 环形指针推进 self.ptr = (self.ptr + 1) % self.max_size self.size = min(self.size + 1, self.max_size) def add_batch( self, info_states: List[List[float]], legal_masks: List[List[int]], regrets: List[List[float]], strategies: List[List[float]], ) -> None: """批量写入,消除 Python 循环和微型 Tensor 创建的巨大 CPU 开销""" batch_size = len(info_states) if batch_size == 0: return t_info = torch.tensor(info_states, dtype=torch.float32) t_legal = torch.tensor(legal_masks, dtype=torch.float32) t_regrets = torch.tensor(regrets, dtype=torch.float32) t_strats = torch.tensor(strategies, dtype=torch.float32) end_ptr = self.ptr + batch_size if end_ptr <= self.max_size: self._info_states[self.ptr : end_ptr] = t_info self._legal_masks[self.ptr : end_ptr] = t_legal self._regrets[self.ptr : end_ptr] = t_regrets self._strategies[self.ptr : end_ptr] = t_strats else: first_part_len = self.max_size - self.ptr self._info_states[self.ptr : self.max_size] = t_info[:first_part_len] self._legal_masks[self.ptr : self.max_size] = t_legal[:first_part_len] self._regrets[self.ptr : self.max_size] = t_regrets[:first_part_len] self._strategies[self.ptr : self.max_size] = t_strats[:first_part_len] second_part_len = batch_size - first_part_len self._info_states[0 : second_part_len] = t_info[first_part_len:] self._legal_masks[0 : second_part_len] = t_legal[first_part_len:] self._regrets[0 : second_part_len] = t_regrets[first_part_len:] self._strategies[0 : second_part_len] = t_strats[first_part_len:] self.ptr = end_ptr % self.max_size self.size = min(self.size + batch_size, self.max_size) def add_batch_tensors( self, t_info: torch.Tensor, t_legal: torch.Tensor, t_regrets: torch.Tensor, t_strats: torch.Tensor, ) -> None: """极速写入:直接接收 C++ Tensor 进行 Zero-Copy 内存赋值,并包含越界截断保护""" batch_size = t_info.shape[0] if batch_size == 0: return # 极端情况越界保护 if batch_size > self.max_size: t_info = t_info[-self.max_size:] t_legal = t_legal[-self.max_size:] t_regrets = t_regrets[-self.max_size:] t_strats = t_strats[-self.max_size:] batch_size = self.max_size end_ptr = self.ptr + batch_size if end_ptr <= self.max_size: self._info_states[self.ptr : end_ptr] = t_info self._legal_masks[self.ptr : end_ptr] = t_legal self._regrets[self.ptr : end_ptr] = t_regrets self._strategies[self.ptr : end_ptr] = t_strats else: first_part_len = self.max_size - self.ptr self._info_states[self.ptr : self.max_size] = t_info[:first_part_len] self._legal_masks[self.ptr : self.max_size] = t_legal[:first_part_len] self._regrets[self.ptr : self.max_size] = t_regrets[:first_part_len] self._strategies[self.ptr : self.max_size] = t_strats[:first_part_len] second_part_len = batch_size - first_part_len self._info_states[0 : second_part_len] = t_info[first_part_len:] self._legal_masks[0 : second_part_len] = t_legal[first_part_len:] self._regrets[0 : second_part_len] = t_regrets[first_part_len:] self._strategies[0 : second_part_len] = t_strats[first_part_len:] self.ptr = end_ptr % self.max_size self.size = min(self.size + batch_size, self.max_size) def sample( self, batch_size: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """从连续内存中极速切片采样""" if self.size == 0: raise ValueError("CFRBuffer 为空,无法采样") actual_bs = min(batch_size, self.size) # 极速生成随机索引 indices = torch.randint(0, self.size, (actual_bs,)) # 直接通过高级索引切片,零额外对象创建开销 return ( self._info_states[indices], self._legal_masks[indices], self._regrets[indices], self._strategies[indices] ) def state_dict(self) -> dict: """导出当前有效数据用于检查点保存。 注意:为了避免检查点文件无意义的庞大,只保存 [0 : self.size] 的有效切片。 """ return { "ptr": self.ptr, "size": self.size, "info_states": self._info_states[:self.size].clone(), "legal_masks": self._legal_masks[:self.size].clone(), "regrets": self._regrets[:self.size].clone(), "strategies": self._strategies[:self.size].clone(), } def load_state_dict(self, state: dict) -> None: """从检查点恢复数据(支持新旧版本兼容,含 max_size 缩小后的安全截断)""" if "size" in state: # ── 新版本 Tensor Ring 格式 ── load_size = min(state["size"], self.max_size) self.size = load_size self.ptr = state["ptr"] % self.max_size # 将保存的有效切片灌回预分配的大内存中(双侧安全截断) self._info_states[:load_size] = state["info_states"][:load_size] self._legal_masks[:load_size] = state["legal_masks"][:load_size] self._regrets[:load_size] = state["regrets"][:load_size] self._strategies[:load_size] = state["strategies"][:load_size] if load_size < state["size"]: print(f"[Buffer] 警告: 检查点含 {state['size']} 条数据,当前 max_size={self.max_size},截断加载 {load_size} 条") else: # ── 旧版本纯 List 格式(向下兼容补丁) ── print("[Buffer] 检测到旧版本检查点,正在将其转换为高速连续 Tensor 格式...") old_info = state["info_states"] # 旧版已经是 Tensor 了 old_legal = torch.tensor(state["legal_masks"], dtype=torch.float32) old_regrets = torch.tensor(state["regrets"], dtype=torch.float32) old_strats = torch.tensor(state["strategies"], dtype=torch.float32) load_size = min(old_legal.shape[0], self.max_size) self.size = load_size # 如果之前 buffer 没满,游标就在末尾;如果满了,游标回绕到 0 self.ptr = self.size % self.max_size if self.size > 0: self._info_states[:load_size] = old_info[:load_size] self._legal_masks[:load_size] = old_legal[:load_size] self._regrets[:load_size] = old_regrets[:load_size] self._strategies[:load_size] = old_strats[:load_size] if load_size < old_legal.shape[0]: print(f"[Buffer] 警告: 旧检查点含 {old_legal.shape[0]} 条数据,当前 max_size={self.max_size},截断加载 {load_size} 条") print(f"[Buffer] 成功灌入 {self.size} 条历史经验。") def clear(self) -> None: self.ptr = 0 self.size = 0 # ─────────────────────── 实例化测试 ─────────────────────── if __name__ == "__main__": print("=" * 60) print("CFRBuffer (Tensor Ring Buffer) 单元测试") print("=" * 60) buf = CFRBuffer(max_size=100) for i in range(120): buf.add([float(i)] * 55, [1]*5, [0.0]*5, [0.2]*5) assert len(buf) == 100, f"长度应为 100,实际 {len(buf)}" assert buf.ptr == 20, f"游标应循环至 20,实际 {buf.ptr}" # 第 0 行原本是 0,现在应该被覆盖为 100 assert buf._info_states[0][0].item() == 100.0, "环形覆盖逻辑错误" sd = buf.state_dict() assert sd["info_states"].shape[0] == 100, "导出形状错误" print("✓ 核心逻辑完全正确!")