- EmotionEngine: 5状态马尔可夫情绪机 + 蒙特卡洛转移 - VectorMemory: TF-IDF向量记忆 + SQLite持久化 + RAG检索 - AgentBrain: Ollama/OpenAI/Dummy三后端LLM - BehaviorScheduler: 优先级/冷却/活跃度调度 - FastAPI服务器 + WebSocket实时推送 - perception: 键鼠监控 + 屏幕截图 - ui/pet_window: PySide6桌宠窗口 + 像素动画 - assets/pet: 5情绪各2帧像素艺术资源
1009 lines
32 KiB
Python
1009 lines
32 KiB
Python
"""
|
||
EzVibe 向量记忆系统
|
||
====================
|
||
设计文档对应章节:核心模块结构 - 向量记忆系统
|
||
|
||
核心功能
|
||
• 长期记忆存取(文本 + 向量嵌入 + 元数据)
|
||
• 语义相似度检索(Cosine Similarity)
|
||
• RAG 工作流:Query → Embedding → Top-k 检索 → 拼入上下文
|
||
|
||
存储方案
|
||
初期:SQLite + JSON(配合 NumPy 计算)
|
||
进阶:FAISS / ChromaDB(预留接口)
|
||
|
||
嵌入函数
|
||
f: text → R^n
|
||
默认适配器:Ollama(本地 LLM 服务),可选 OpenAI / sklearn
|
||
|
||
RAG 工作流
|
||
1. 接收输入 Query
|
||
2. 将 Query 转换为 Embedding
|
||
3. 在 Memory 库中检索 Top-k 记忆片段
|
||
4. 将检索结果拼接入 System Prompt 上下文
|
||
5. 调用 LLM 进行推理并生成最终回复
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
import math
|
||
import time
|
||
import uuid
|
||
from dataclasses import dataclass, field, asdict
|
||
from pathlib import Path
|
||
from typing import Any, Optional
|
||
|
||
import numpy as np
|
||
|
||
# sklearn 是 DummyEmbedder 的可选依赖
|
||
try:
|
||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||
except ImportError: # pragma: no cover
|
||
TfidfVectorizer = None # type: ignore[assignment, misc]
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ================================================================
|
||
# 1. 数据模型
|
||
# ================================================================
|
||
|
||
@dataclass
|
||
class MemoryEntry:
|
||
"""
|
||
单条记忆条目。
|
||
|
||
参数
|
||
----
|
||
text : str
|
||
记忆文本内容。
|
||
embedding : list[float]
|
||
对应的向量嵌入(由 embedder 生成)。
|
||
tags : list[str]
|
||
可选标签,用于分类过滤。
|
||
metadata : dict
|
||
可选附加元数据(如来源、关联用户行为等)。
|
||
created_at : float
|
||
创建时间戳(Unix timestamp)。
|
||
"""
|
||
text: str
|
||
embedding: list[float]
|
||
tags: list[str] = field(default_factory=list)
|
||
metadata: dict = field(default_factory=dict)
|
||
created_at: float = field(default_factory=lambda: time.time())
|
||
# id 由存储层(MemoryStore)分配,创建时为 None,insert 后填充
|
||
id: str | None = field(default=None, repr=False)
|
||
|
||
def to_dict(self) -> dict:
|
||
return asdict(self)
|
||
|
||
@classmethod
|
||
def from_dict(cls, d: dict) -> "MemoryEntry":
|
||
d = dict(d)
|
||
d.setdefault("tags", [])
|
||
d.setdefault("metadata", {})
|
||
d.setdefault("created_at", time.time())
|
||
return cls(**d)
|
||
|
||
|
||
# ================================================================
|
||
# 2. 嵌入函数适配器
|
||
# ================================================================
|
||
|
||
class EmbedderBase:
|
||
"""嵌入函数抽象基类。"""
|
||
|
||
@property
|
||
def dimension(self) -> int:
|
||
"""返回嵌入向量维度。子类实现。"""
|
||
raise NotImplementedError
|
||
|
||
async def embed(self, text: str) -> list[float]:
|
||
"""将文本转为嵌入向量。子类实现。"""
|
||
raise NotImplementedError
|
||
|
||
def embed_sync(self, text: str) -> list[float]:
|
||
"""同步版本(部分 embedder 不支持异步)。"""
|
||
raise NotImplementedError
|
||
|
||
|
||
class OllamaEmbedder(EmbedderBase):
|
||
"""
|
||
Ollama 本地嵌入适配器。
|
||
|
||
参数
|
||
----
|
||
base_url : str
|
||
Ollama 服务地址,默认 http://localhost:11434。
|
||
model : str
|
||
嵌入模型名,默认 nomic-embed-text。
|
||
其他可选:mxbai-embed-large, bge-m3, qwen2.5 等等。
|
||
"""
|
||
|
||
DEFAULT_URL = "http://localhost:11434"
|
||
DEFAULT_MODEL = "nomic-embed-text"
|
||
|
||
def __init__(
|
||
self,
|
||
base_url: str = DEFAULT_URL,
|
||
model: str = DEFAULT_MODEL,
|
||
) -> None:
|
||
self.base_url = base_url.rstrip("/")
|
||
self.model = model
|
||
self._client = None # lazy import
|
||
self._dim: int | None = None
|
||
|
||
@property
|
||
def dimension(self) -> int:
|
||
if self._dim is not None:
|
||
return self._dim
|
||
# 探测维度(首次 embed 时更新缓存)
|
||
import httpx
|
||
try:
|
||
resp = httpx.post(
|
||
f"{self.base_url}/api/embeddings",
|
||
json={"model": self.model, "prompt": "dimension_probe"},
|
||
timeout=5.0,
|
||
)
|
||
if resp.status_code == 200:
|
||
vec = resp.json().get("embedding", [])
|
||
self._dim = len(vec)
|
||
return self._dim
|
||
except Exception:
|
||
pass
|
||
# fallback
|
||
self._dim = 768
|
||
return self._dim
|
||
|
||
def embed_sync(self, text: str) -> list[float]:
|
||
import httpx
|
||
resp = httpx.post(
|
||
f"{self.base_url}/api/embeddings",
|
||
json={"model": self.model, "prompt": text},
|
||
timeout=30.0,
|
||
)
|
||
resp.raise_for_status()
|
||
return resp.json()["embedding"]
|
||
|
||
async def embed(self, text: str) -> list[float]:
|
||
"""同步封装(httpx 默认同步,asyncio 线程池托底)。"""
|
||
import asyncio
|
||
loop = asyncio.get_running_loop()
|
||
return await loop.run_in_executor(None, self.embed_sync, text)
|
||
|
||
|
||
class OpenAIEmbedder(EmbedderBase):
|
||
"""
|
||
OpenAI 兼容嵌入适配器(如 OpenAI、DeepSeek、本地兼容 API)。
|
||
|
||
参数
|
||
----
|
||
api_key : str
|
||
API 密钥。设为 "local" 时使用本地兼容模式。
|
||
base_url : str
|
||
API 端点,默认 https://api.openai.com/v1。
|
||
model : str
|
||
嵌入模型,默认 text-embedding-3-small。
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
api_key: str = "local",
|
||
base_url: str = "https://api.openai.com/v1",
|
||
model: str = "text-embedding-3-small",
|
||
) -> None:
|
||
self.api_key = api_key
|
||
self.base_url = base_url.rstrip("/")
|
||
self.model = model
|
||
self._dim: int | None = None
|
||
|
||
@property
|
||
def dimension(self) -> int:
|
||
if self._dim is not None:
|
||
return self._dim
|
||
# text-embedding-3-small 默认 1536
|
||
self._dim = 1536
|
||
return self._dim
|
||
|
||
def embed_sync(self, text: str) -> list[float]:
|
||
import httpx
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json",
|
||
}
|
||
resp = httpx.post(
|
||
f"{self.base_url}/embeddings",
|
||
headers=headers,
|
||
json={"model": self.model, "input": text},
|
||
timeout=30.0,
|
||
)
|
||
resp.raise_for_status()
|
||
data = resp.json()
|
||
return data["data"][0]["embedding"]
|
||
|
||
async def embed(self, text: str) -> list[float]:
|
||
import asyncio
|
||
loop = asyncio.get_running_loop()
|
||
return await loop.run_in_executor(None, self.embed_sync, text)
|
||
|
||
|
||
class DummyEmbedder(EmbedderBase):
|
||
"""
|
||
基于 TF-IDF 的轻量嵌入器(用于开发/测试阶段)。
|
||
|
||
支持语义级别的相似度检索,相同/近义文本向量相似。
|
||
不依赖外部 API,不消耗远程资源。
|
||
注意:不应用于生产环境(请换 Ollama / OpenAI 等真实 Embedder)。
|
||
"""
|
||
|
||
DIM = 128 # 固定维度,TF-IDF 输出与 vocabulary size 绑定
|
||
|
||
def __init__(self, dimension: int = DIM) -> None:
|
||
if TfidfVectorizer is None:
|
||
raise ImportError(
|
||
"scikit-learn is required for DummyEmbedder. "
|
||
"Install it with: pip install scikit-learn"
|
||
)
|
||
self._dim = dimension
|
||
self._tfidf = TfidfVectorizer(
|
||
max_features=min(dimension * 4, 512),
|
||
analyzer="char_wb", # 字符级(带边界),适合无空格的中文
|
||
ngram_range=(1, 3), # 1-3 字 n-gram,捕获"奶茶"="奶"+"茶"+"奶茶"
|
||
lowercase=False, # 中文无需 lowercase
|
||
)
|
||
self._fitted = False
|
||
|
||
@property
|
||
def dimension(self) -> int:
|
||
return self._dim
|
||
|
||
def embed_sync(self, text: str) -> list[float]:
|
||
"""
|
||
基于 TF-IDF 生成向量,动态适配 vocabulary。
|
||
|
||
首次调用时会 fit_transform 整个语料库,
|
||
之后每次 add/search 调用会增量 transform。
|
||
"""
|
||
if not self._fitted:
|
||
# 没有语料,无法初始化 → 返回零向量(兼容初始化阶段)
|
||
return [0.0] * self._dim
|
||
|
||
vec = self._tfidf.transform([text]).toarray()[0].tolist()
|
||
# padding/truncate to fixed dim
|
||
if len(vec) < self._dim:
|
||
vec += [0.0] * (self._dim - len(vec))
|
||
else:
|
||
vec = vec[: self._dim]
|
||
|
||
norm = math.sqrt(sum(x * x for x in vec))
|
||
return [x / norm for x in vec] if norm > 0 else vec
|
||
|
||
def fit(self, texts: list[str]) -> None:
|
||
"""用文本列表拟合 TF-IDF 模型。"""
|
||
if not texts:
|
||
return
|
||
try:
|
||
self._tfidf.fit(texts)
|
||
except ValueError:
|
||
# TfidfVectorizer 在只有一个文档时可能报 ValueError,降级处理
|
||
self._tfidf.fit([t for t in texts if t])
|
||
self._fitted = True
|
||
|
||
def transform(self, text: str) -> list[float]:
|
||
"""对外暴露的 transform(调用 embed_sync)。"""
|
||
return self.embed_sync(text)
|
||
|
||
async def embed(self, text: str) -> list[float]:
|
||
return self.embed_sync(text)
|
||
|
||
|
||
def make_embedder(backend: str = "ollama", **kwargs) -> EmbedderBase:
|
||
"""
|
||
工厂函数:根据配置字符串创建嵌入适配器。
|
||
|
||
参数
|
||
----
|
||
backend : str
|
||
ollama | openai | dummy
|
||
**kwargs
|
||
透传给具体适配器构造函数。
|
||
|
||
返回
|
||
----
|
||
EmbedderBase 实例。
|
||
"""
|
||
backends = {
|
||
"ollama": OllamaEmbedder,
|
||
"openai": OpenAIEmbedder,
|
||
"dummy": DummyEmbedder,
|
||
}
|
||
cls = backends.get(backend.lower())
|
||
if cls is None:
|
||
raise ValueError(
|
||
f"Unknown embedder backend {backend!r}. "
|
||
f"Available: {list(backends.keys())}"
|
||
)
|
||
# DummyEmbedder 不接受 seed 参数(改为确定性哈希嵌入)
|
||
if cls is DummyEmbedder:
|
||
return cls(dimension=kwargs.pop("dimension", DummyEmbedder.DIM))
|
||
return cls(**kwargs)
|
||
|
||
|
||
# ================================================================
|
||
# 3. SQLite 存储层
|
||
# ================================================================
|
||
|
||
class MemoryStore:
|
||
"""
|
||
SQLite 持久化存储层。
|
||
|
||
表结构
|
||
------
|
||
memories (
|
||
id TEXT PRIMARY KEY,
|
||
text TEXT NOT NULL,
|
||
embedding BLOB NOT NULL, -- numpy 序列化 bytes
|
||
tags TEXT NOT NULL, -- JSON 列表
|
||
metadata TEXT NOT NULL, -- JSON 对象
|
||
created_at REAL NOT NULL
|
||
)
|
||
"""
|
||
|
||
_SCHEMA = """
|
||
CREATE TABLE IF NOT EXISTS memories (
|
||
id TEXT PRIMARY KEY,
|
||
text TEXT NOT NULL,
|
||
embedding BLOB NOT NULL,
|
||
tags TEXT NOT NULL DEFAULT '[]',
|
||
metadata TEXT NOT NULL DEFAULT '{}',
|
||
created_at REAL NOT NULL
|
||
);
|
||
|
||
CREATE INDEX IF NOT EXISTS idx_created_at ON memories(created_at DESC);
|
||
"""
|
||
|
||
def __init__(self, db_path: str = "data/MEMORY.db") -> None:
|
||
import sqlite3
|
||
self._db_path = db_path
|
||
self._conn: sqlite3.Connection | None = None
|
||
self._ensure_db()
|
||
|
||
def _ensure_db(self) -> None:
|
||
"""确保数据库和表已初始化。"""
|
||
import sqlite3
|
||
Path(self._db_path).parent.mkdir(parents=True, exist_ok=True)
|
||
# check_same_thread=False: 允许跨线程使用(配合 run_in_executor 场景)
|
||
conn = sqlite3.connect(self._db_path, check_same_thread=False)
|
||
conn.executescript(self._SCHEMA)
|
||
conn.commit()
|
||
self._conn = conn
|
||
logger.info("MemoryStore initialized: %s", self._db_path)
|
||
|
||
@property
|
||
def conn(self) -> sqlite3.Connection:
|
||
if self._conn is None:
|
||
self._ensure_db()
|
||
assert self._conn is not None
|
||
return self._conn
|
||
|
||
def insert(self, entry: MemoryEntry) -> MemoryEntry:
|
||
"""插入一条记忆,返回带 ID 的 MemoryEntry。"""
|
||
import sqlite3
|
||
entry_id = str(uuid.uuid4())
|
||
embedding_bytes = np.array(entry.embedding, dtype=np.float32).tobytes()
|
||
self.conn.execute(
|
||
"""
|
||
INSERT INTO memories
|
||
(id, text, embedding, tags, metadata, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
""",
|
||
(
|
||
entry_id,
|
||
entry.text,
|
||
embedding_bytes,
|
||
json.dumps(entry.tags, ensure_ascii=False),
|
||
json.dumps(entry.metadata, ensure_ascii=False),
|
||
entry.created_at,
|
||
),
|
||
)
|
||
self.conn.commit()
|
||
entry.id = entry_id
|
||
return entry
|
||
|
||
def upsert(self, entry: MemoryEntry, entry_id: str) -> None:
|
||
"""插入或更新一条记忆(按 id)。"""
|
||
embedding_bytes = np.array(entry.embedding, dtype=np.float32).tobytes()
|
||
self.conn.execute(
|
||
"""
|
||
INSERT OR REPLACE INTO memories
|
||
(id, text, embedding, tags, metadata, created_at)
|
||
VALUES (?, ?, ?, ?, ?, ?)
|
||
""",
|
||
(
|
||
entry_id,
|
||
entry.text,
|
||
embedding_bytes,
|
||
json.dumps(entry.tags, ensure_ascii=False),
|
||
json.dumps(entry.metadata, ensure_ascii=False),
|
||
entry.created_at,
|
||
),
|
||
)
|
||
self.conn.commit()
|
||
|
||
def get_all(self) -> list[MemoryEntry]:
|
||
"""读取所有记忆。"""
|
||
import sqlite3
|
||
rows = self.conn.execute(
|
||
"SELECT id, text, embedding, tags, metadata, created_at "
|
||
"FROM memories ORDER BY created_at DESC"
|
||
).fetchall()
|
||
return [self._row_to_entry(row) for row in rows]
|
||
|
||
def get_all_ids_and_vectors(self) -> list[tuple[str, np.ndarray]]:
|
||
"""
|
||
读取所有记忆的 id 和 embedding(向量化检索专用)。
|
||
返回: list[(id, embedding_vector)]
|
||
"""
|
||
rows = self.conn.execute(
|
||
"SELECT id, embedding FROM memories"
|
||
).fetchall()
|
||
result = []
|
||
for row_id, emb_bytes in rows:
|
||
vec = np.frombuffer(emb_bytes, dtype=np.float32)
|
||
result.append((row_id, vec))
|
||
return result
|
||
|
||
def delete(self, entry_id: str) -> bool:
|
||
"""删除指定 id 的记忆。返回是否实际删除。"""
|
||
cur = self.conn.execute(
|
||
"DELETE FROM memories WHERE id = ?", (entry_id,)
|
||
)
|
||
self.conn.commit()
|
||
return cur.rowcount > 0
|
||
|
||
def count(self) -> int:
|
||
"""返回记忆总数。"""
|
||
cur = self.conn.execute("SELECT COUNT(*) FROM memories")
|
||
return cur.fetchone()[0]
|
||
|
||
@staticmethod
|
||
def _row_to_entry(row: tuple) -> MemoryEntry:
|
||
entry_id, text, emb_bytes, tags_str, meta_str, created_at = row
|
||
embedding = np.frombuffer(emb_bytes, dtype=np.float32).tolist()
|
||
entry = MemoryEntry(
|
||
text=text,
|
||
embedding=embedding,
|
||
tags=json.loads(tags_str),
|
||
metadata=json.loads(meta_str),
|
||
created_at=created_at,
|
||
)
|
||
entry.id = entry_id
|
||
return entry
|
||
|
||
def close(self) -> None:
|
||
if self._conn:
|
||
self._conn.close()
|
||
self._conn = None
|
||
|
||
|
||
# ================================================================
|
||
# 4. 向量检索引擎
|
||
# ================================================================
|
||
|
||
class VectorEngine:
|
||
"""
|
||
向量相似度检索引擎。
|
||
|
||
支持两种模式
|
||
numpy : 纯 NumPy 计算(小规模,< 10k 条)
|
||
faiss : Facebook FAISS(大规模,预留接口)
|
||
"""
|
||
|
||
def __init__(self, mode: str = "numpy") -> None:
|
||
self._mode = mode
|
||
self._ids: list[str] = []
|
||
self._matrix: np.ndarray | None = None # shape: (N, dim)
|
||
|
||
def build(self, entries: list[tuple[str, list[float]]]) -> None:
|
||
"""
|
||
从 id+embedding 对构建索引。
|
||
|
||
参数
|
||
----
|
||
entries : list[(id, embedding)]
|
||
id 列表和对应的 embedding 向量列表。
|
||
"""
|
||
if not entries:
|
||
self._ids = []
|
||
self._matrix = None
|
||
return
|
||
|
||
self._ids = [e[0] for e in entries]
|
||
vectors = np.array([e[1] for e in entries], dtype=np.float32)
|
||
|
||
if self._mode == "numpy":
|
||
# 归一化(NumPy 模式使用余弦相似度,归一化等价于余弦)
|
||
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
|
||
norms[norms == 0] = 1.0
|
||
self._matrix = vectors / norms
|
||
elif self._mode == "faiss":
|
||
self._build_faiss(vectors)
|
||
else:
|
||
raise ValueError(f"Unknown vector engine mode: {self._mode!r}")
|
||
|
||
logger.info(
|
||
"VectorEngine built: mode=%s, entries=%d", self._mode, len(self._ids)
|
||
)
|
||
|
||
def update(self, entry: MemoryEntry) -> None:
|
||
"""
|
||
增量追加一条 entry 到索引(append-only)。
|
||
|
||
参数
|
||
----
|
||
entry : MemoryEntry
|
||
必须包含 id 和 embedding 向量(任意维度/归一化状态均可)。
|
||
"""
|
||
if self._matrix is None:
|
||
vec = np.array(entry.embedding, dtype=np.float32)
|
||
norm = np.linalg.norm(vec)
|
||
self._ids = [entry.id]
|
||
self._matrix = (vec / (norm if norm > 0 else 1.0)).reshape(1, -1)
|
||
else:
|
||
vec = np.array(entry.embedding, dtype=np.float32)
|
||
norm = np.linalg.norm(vec)
|
||
self._ids.append(entry.id)
|
||
self._matrix = np.vstack([
|
||
self._matrix,
|
||
(vec / (norm if norm > 0 else 1.0)).reshape(1, -1),
|
||
])
|
||
|
||
def _build_faiss(self, vectors: np.ndarray) -> None:
|
||
"""FAISS 索引构建。"""
|
||
try:
|
||
import faiss
|
||
except ImportError as exc:
|
||
raise ImportError(
|
||
"FAISS not installed. Run: pip install faiss-cpu"
|
||
) from exc
|
||
dim = vectors.shape[1]
|
||
self._matrix = vectors
|
||
# 简单的暴力索引(可选升级为 IVF、HNSW)
|
||
self._faiss_index = faiss.IndexFlatIP(dim)
|
||
# FAISS 要求 L2 内积需要归一化,Inner Product = Cosine(已归一化)
|
||
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
|
||
norms[norms == 0] = 1.0
|
||
normalized = vectors / norms
|
||
self._faiss_index.add(normalized.astype(np.float32))
|
||
|
||
def search(
|
||
self,
|
||
query: list[float],
|
||
top_k: int = 5,
|
||
) -> list[tuple[str, float]]:
|
||
"""
|
||
余弦相似度检索。
|
||
|
||
参数
|
||
----
|
||
query : list[float]
|
||
查询向量。
|
||
top_k : int
|
||
返回前 k 条。
|
||
|
||
返回
|
||
----
|
||
list[(id, similarity_score)] 按相似度降序排列。
|
||
"""
|
||
if self._matrix is None or len(self._ids) == 0:
|
||
return []
|
||
|
||
q = np.array(query, dtype=np.float32)
|
||
q_norm = q / (np.linalg.norm(q) + 1e-9)
|
||
scores = self._matrix @ q_norm # shape: (N,)
|
||
|
||
if self._mode == "faiss":
|
||
assert hasattr(self, "_faiss_index"), "FAISS index not built"
|
||
import faiss
|
||
_, indices = self._faiss_index.search(
|
||
q_norm.reshape(1, -1).astype(np.float32), top_k
|
||
)
|
||
results = [(self._ids[i], float(scores[i])) for i in indices[0] if i >= 0]
|
||
else:
|
||
top_indices = np.argsort(scores)[::-1][:top_k]
|
||
results = [(self._ids[i], float(scores[i])) for i in top_indices]
|
||
|
||
return results
|
||
|
||
def cosine(self, a: list[float], b: list[float]) -> float:
|
||
"""
|
||
计算两条向量的余弦相似度。
|
||
|
||
sim(a,b) = (a·b) / (||a|| ||b||)
|
||
"""
|
||
return VectorEngine._cosine_impl(a, b)
|
||
|
||
@staticmethod
|
||
def _cosine_impl(a: list[float], b: list[float]) -> float:
|
||
"""纯 Python 实现(不依赖 numpy)。"""
|
||
dot = sum(x * y for x, y in zip(a, b))
|
||
norm_a = math.sqrt(sum(x * x for x in a))
|
||
norm_b = math.sqrt(sum(x * x for x in b))
|
||
if norm_a == 0 or norm_b == 0:
|
||
return 0.0
|
||
return dot / (norm_a * norm_b)
|
||
|
||
|
||
# ================================================================
|
||
# 5. 主记忆系统
|
||
# ================================================================
|
||
|
||
class VectorMemory:
|
||
"""
|
||
向量记忆系统 — RAG 工作流核心。
|
||
|
||
参数
|
||
----
|
||
storage_path : str
|
||
SQLite 数据库路径,默认 data/MEMORY.db。
|
||
embedder_backend : str
|
||
嵌入后端:ollama | openai | dummy。
|
||
embedder_kwargs : dict
|
||
透传给嵌入适配器的参数。
|
||
vector_mode : str
|
||
向量检索模式:numpy | faiss。
|
||
|
||
RAG 工作流
|
||
----------
|
||
1. 接收输入 Query
|
||
2. 将 Query 转换为 Embedding(通过 embedder)
|
||
3. 在 Memory 库中检索 Top-k 记忆片段
|
||
4. 将检索结果拼接入 System Prompt 上下文
|
||
5. 返回拼装好的上下文字符串
|
||
|
||
示例
|
||
----
|
||
>>> mem = VectorMemory(embedder_backend="dummy", seed=42)
|
||
>>> await mem.add("用户今天说想喝奶茶", tags=["偏好", "饮食"])
|
||
>>> results = await mem.search("用户最近聊过什么喝的", top_k=3)
|
||
>>> context = await mem.retrieve_context("用户对奶茶的态度")
|
||
"""
|
||
|
||
def __init__(
|
||
self,
|
||
storage_path: str = "data/MEMORY.db",
|
||
embedder_backend: str = "ollama",
|
||
embedder_kwargs: dict | None = None,
|
||
vector_mode: str = "numpy",
|
||
) -> None:
|
||
self._store = MemoryStore(storage_path)
|
||
self._vector = VectorEngine(mode=vector_mode)
|
||
self._embedder = make_embedder(embedder_backend, **(embedder_kwargs or {}))
|
||
self._index_ready = False
|
||
|
||
logger.info(
|
||
"VectorMemory initialized | embedder=%s | storage=%s | vector=%s",
|
||
embedder_backend,
|
||
storage_path,
|
||
vector_mode,
|
||
)
|
||
|
||
# ----------------------------------------------------------------
|
||
# 生命周期
|
||
# ----------------------------------------------------------------
|
||
|
||
async def initialize(self) -> None:
|
||
"""
|
||
异步初始化:构建向量索引。
|
||
应用启动时调用一次。
|
||
"""
|
||
import asyncio
|
||
loop = asyncio.get_running_loop()
|
||
await loop.run_in_executor(None, self._rebuild_index)
|
||
|
||
def _rebuild_index(self) -> None:
|
||
"""从数据库加载所有 embedding 构建向量索引(同步,在线程池执行)。"""
|
||
import logging
|
||
_logger = logging.getLogger(__name__)
|
||
entries = self._store.get_all_ids_and_vectors()
|
||
self._vector.build(entries)
|
||
self._index_ready = True
|
||
_logger.info("Vector index rebuilt: %d entries", len(entries))
|
||
|
||
# ----------------------------------------------------------------
|
||
# 存取操作
|
||
# ----------------------------------------------------------------
|
||
|
||
async def add(
|
||
self,
|
||
text: str,
|
||
tags: list[str] | None = None,
|
||
metadata: dict | None = None,
|
||
) -> str:
|
||
"""
|
||
添加一条新记忆。
|
||
|
||
参数
|
||
----
|
||
text : str
|
||
记忆文本。
|
||
tags : list[str] | None
|
||
标签列表(如 ["偏好", "饮食"])。
|
||
metadata : dict | None
|
||
附加元数据(如 {"source": "user_input", "channel": "feishu"})。
|
||
|
||
返回
|
||
----
|
||
str
|
||
新记忆的 UUID id。
|
||
"""
|
||
# 生成向量(自动处理 TF-IDF lazy fit)
|
||
embedding = await self._embed_text(text)
|
||
entry = MemoryEntry(
|
||
id=None, # 先占位,存储后再填充
|
||
text=text,
|
||
embedding=embedding,
|
||
tags=tags or [],
|
||
metadata=metadata or {},
|
||
)
|
||
# 存储(返回带 ID 的 entry)
|
||
entry = self._store.insert(entry)
|
||
entry_id = entry.id
|
||
|
||
# 更新向量索引
|
||
if not self._index_ready:
|
||
# 在当前线程同步执行(add 已在子线程/async 上下文中)
|
||
self._rebuild_index()
|
||
else:
|
||
self._vector.update(entry)
|
||
|
||
logger.debug("[Memory] Added: id=%s — %s", entry_id, text[:40])
|
||
return entry_id
|
||
|
||
async def _embed_text(self, text: str) -> list[float]:
|
||
"""生成文本向量;若使用 DummyEmbedder 且未 fit,先 fit 再 embed。"""
|
||
embedder = self._embedder
|
||
if hasattr(embedder, "_fitted") and not getattr(embedder, "_fitted", True):
|
||
# 懒拟合:用现有语料库 fit TF-IDF
|
||
texts = [row[1] for row in self._store.get_all()]
|
||
if texts:
|
||
embedder.fit(texts + [text])
|
||
else:
|
||
embedder.fit([text])
|
||
return await embedder.embed(text)
|
||
|
||
async def add_sync(
|
||
self,
|
||
text: str,
|
||
tags: list[str] | None = None,
|
||
metadata: dict | None = None,
|
||
) -> str:
|
||
"""同步版本的 add(供非 async 上下文调用)。"""
|
||
import asyncio
|
||
return await self.add(text, tags, metadata)
|
||
|
||
def search(
|
||
self,
|
||
query: str,
|
||
top_k: int = 5,
|
||
min_similarity: float = 0.0,
|
||
) -> list[dict]:
|
||
"""
|
||
语义检索。
|
||
|
||
参数
|
||
----
|
||
query : str
|
||
查询文本。
|
||
top_k : int
|
||
返回前 k 条结果。
|
||
min_similarity : float
|
||
最小相似度阈值(0.0 ~ 1.0)。
|
||
|
||
返回
|
||
----
|
||
list[dict] 每项包含 id, text, similarity, tags, metadata, created_at。
|
||
"""
|
||
import asyncio
|
||
import concurrent.futures
|
||
# 统一在子线程运行(避免 loop 冲突)
|
||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||
return pool.submit(
|
||
asyncio.run, self._search_async(query, top_k, min_similarity)
|
||
).result()
|
||
|
||
async def _search_async(
|
||
self,
|
||
query: str,
|
||
top_k: int = 5,
|
||
min_similarity: float = 0.0,
|
||
) -> list[dict]:
|
||
if not self._index_ready:
|
||
await self.initialize()
|
||
|
||
# 1. Query → Embedding
|
||
query_vec = await self._embedder.embed(query)
|
||
|
||
# 2. 向量检索
|
||
top_ids = self._vector.search(query_vec, top_k=top_k * 3) # 多取一些,后面过滤
|
||
|
||
# 3. 补充详情(从 SQLite 读取完整数据)
|
||
results = []
|
||
seen_texts: set[str] = set()
|
||
for entry_id, score in top_ids:
|
||
if score < min_similarity:
|
||
continue
|
||
rows = self._store.conn.execute(
|
||
"SELECT id, text, tags, metadata, created_at FROM memories WHERE id = ?",
|
||
(entry_id,),
|
||
).fetchall()
|
||
for row in rows:
|
||
eid, text, tags_str, meta_str, created_at = row
|
||
text_normalized = text.strip()
|
||
if text_normalized in seen_texts:
|
||
continue
|
||
seen_texts.add(text_normalized)
|
||
results.append({
|
||
"id": eid,
|
||
"text": text,
|
||
"similarity": round(score, 4),
|
||
"tags": json.loads(tags_str),
|
||
"metadata": json.loads(meta_str),
|
||
"created_at": created_at,
|
||
})
|
||
if len(results) >= top_k:
|
||
break
|
||
|
||
# 4. 按相似度降序
|
||
results.sort(key=lambda x: x["similarity"], reverse=True)
|
||
return results[:top_k]
|
||
|
||
# ----------------------------------------------------------------
|
||
# RAG 核心
|
||
# ----------------------------------------------------------------
|
||
|
||
async def retrieve_context(
|
||
self,
|
||
query: str,
|
||
top_k: int = 3,
|
||
min_similarity: float = 0.3,
|
||
include_metadata: bool = False,
|
||
) -> str:
|
||
"""
|
||
RAG 上下文检索。
|
||
|
||
将检索结果格式化为字符串,拼入 LLM 的 System Prompt。
|
||
|
||
返回格式
|
||
--------
|
||
```
|
||
[相关记忆 #1] (相似度: 0.87, 标签: 偏好/饮食, 时间: 2026-05-02)
|
||
用户今天下午说想喝奶茶。
|
||
---
|
||
[相关记忆 #2] (相似度: 0.72, 标签: 习惯, 时间: 2026-04-30)
|
||
用户通常在下午3点喝咖啡。
|
||
---
|
||
```
|
||
"""
|
||
results = await self._search_async(query, top_k, min_similarity)
|
||
if not results:
|
||
return ""
|
||
|
||
import datetime
|
||
lines = []
|
||
for i, item in enumerate(results, 1):
|
||
ts = datetime.datetime.fromtimestamp(item["created_at"])
|
||
ts_str = ts.strftime("%Y-%m-%d %H:%M")
|
||
tags_str = "/".join(item["tags"]) if item["tags"] else "无标签"
|
||
meta_str = ""
|
||
if include_metadata and item["metadata"]:
|
||
meta_str = f", 来源: {item['metadata'].get('source', 'unknown')}"
|
||
|
||
lines.append(
|
||
f"[相关记忆 #{i}] "
|
||
f"(相似度: {item['similarity']:.2f}, "
|
||
f"标签: {tags_str}, "
|
||
f"时间: {ts_str}{meta_str})\n"
|
||
f"{item['text']}"
|
||
)
|
||
|
||
return "---\n".join(lines)
|
||
|
||
# ----------------------------------------------------------------
|
||
# 工具方法
|
||
# ----------------------------------------------------------------
|
||
|
||
def cosine_similarity(self, a: list[float], b: list[float]) -> float:
|
||
"""公开的 Cosine 相似度工具方法。"""
|
||
return VectorEngine._cosine_impl(a, b)
|
||
|
||
def get_user_profile(self) -> dict:
|
||
"""
|
||
从记忆中构建用户画像摘要。
|
||
策略:按标签聚合统计 + 最近 N 条记录。
|
||
"""
|
||
import datetime
|
||
all_entries = self._store.get_all()
|
||
if not all_entries:
|
||
return {"summary": "暂无记忆数据", "total": 0}
|
||
|
||
# 标签统计
|
||
tag_counts: dict[str, int] = {}
|
||
for e in all_entries:
|
||
for tag in e.tags:
|
||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||
|
||
# 最近 5 条
|
||
recent = [
|
||
{
|
||
"text": e.text[:100],
|
||
"tags": e.tags,
|
||
"time": datetime.datetime.fromtimestamp(e.created_at).strftime(
|
||
"%Y-%m-%d %H:%M"
|
||
),
|
||
}
|
||
for e in all_entries[:5]
|
||
]
|
||
|
||
return {
|
||
"total": len(all_entries),
|
||
"tag_counts": dict(sorted(tag_counts.items(), key=lambda x: -x[1])[:10]),
|
||
"recent": recent,
|
||
}
|
||
|
||
async def get_recent_memories(
|
||
self,
|
||
limit: int = 10,
|
||
tags_filter: list[str] | None = None,
|
||
) -> list[dict]:
|
||
"""返回最近的记忆条目。"""
|
||
import datetime
|
||
if tags_filter:
|
||
tag_filter_sql = " OR ".join(
|
||
f"tags LIKE ?" for _ in tags_filter
|
||
)
|
||
like_args = [f'%"{t}"%' for t in tags_filter]
|
||
rows = self._store.conn.execute(
|
||
f"SELECT id, text, tags, metadata, created_at "
|
||
f"FROM memories WHERE {tag_filter_sql} "
|
||
f"ORDER BY created_at DESC LIMIT ?",
|
||
like_args + [limit],
|
||
).fetchall()
|
||
else:
|
||
rows = self._store.conn.execute(
|
||
"SELECT id, text, tags, metadata, created_at "
|
||
"FROM memories ORDER BY created_at DESC LIMIT ?",
|
||
(limit,),
|
||
).fetchall()
|
||
|
||
results = []
|
||
for row in rows:
|
||
eid, text, tags_str, meta_str, created_at = row
|
||
results.append({
|
||
"id": eid,
|
||
"text": text,
|
||
"tags": json.loads(tags_str),
|
||
"metadata": json.loads(meta_str),
|
||
"created_at": datetime.datetime.fromtimestamp(created_at).strftime(
|
||
"%Y-%m-%d %H:%M:%S"
|
||
),
|
||
})
|
||
return results
|
||
|
||
def delete(self, entry_id: str) -> bool:
|
||
"""删除指定记忆。"""
|
||
ok = self._store.delete(entry_id)
|
||
if ok:
|
||
self._rebuild_index()
|
||
return ok
|
||
|
||
def count(self) -> int:
|
||
"""返回记忆总数。"""
|
||
return self._store.count()
|
||
|
||
def close(self) -> None:
|
||
"""关闭数据库连接。"""
|
||
self._store.close()
|