Files
EzVibe/agent/test_memory.py
e2hang 2a844e83a8 Initial commit: EzVibe AI 桌宠系统
- EmotionEngine: 5状态马尔可夫情绪机 + 蒙特卡洛转移
- VectorMemory: TF-IDF向量记忆 + SQLite持久化 + RAG检索
- AgentBrain: Ollama/OpenAI/Dummy三后端LLM
- BehaviorScheduler: 优先级/冷却/活跃度调度
- FastAPI服务器 + WebSocket实时推送
- perception: 键鼠监控 + 屏幕截图
- ui/pet_window: PySide6桌宠窗口 + 像素动画
- assets/pet: 5情绪各2帧像素艺术资源
2026-05-01 23:26:43 +08:00

564 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
memory.py 单元测试
运行: python -m agent.test_memory
"""
import asyncio
import math
import os
import shutil
import sys
import tempfile
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from agent.memory import (
MemoryEntry,
MemoryStore,
VectorEngine,
VectorMemory,
make_embedder,
DummyEmbedder,
OllamaEmbedder,
OpenAIEmbedder,
)
# ================================================================
# 测试工具
# ================================================================
_TEST_DB_DIR: Path | None = None
def _setup_test_dir():
global _TEST_DB_DIR
_TEST_DB_DIR = Path(tempfile.mkdtemp(prefix="ezvibe_test_"))
return str(_TEST_DB_DIR / "memory.db")
def _teardown_test_dir():
global _TEST_DB_DIR
if _TEST_DB_DIR and _TEST_DB_DIR.exists():
shutil.rmtree(_TEST_DB_DIR, ignore_errors=True)
# ================================================================
# MemoryEntry
# ================================================================
def test_memory_entry_to_from_dict():
    """Round-trip a fully populated MemoryEntry through to_dict()/from_dict().

    Every field of the restored entry must equal the corresponding field of
    the original.
    """
    original = MemoryEntry(
        text="用户今天说想喝奶茶",
        embedding=[0.1, 0.2, 0.3, 0.4],
        tags=["偏好", "饮食"],
        metadata={"source": "feishu"},
        created_at=1700000000.0,
    )
    roundtrip = MemoryEntry.from_dict(original.to_dict())
    for field_name in ("text", "embedding", "tags", "metadata", "created_at"):
        assert getattr(roundtrip, field_name) == getattr(original, field_name)
    print("[PASS] test_memory_entry_to_from_dict")
def test_memory_entry_default_fields():
    """Check the defaults of a MemoryEntry built from only text and embedding.

    Such an entry gets empty tags, empty metadata and a positive timestamp.
    """
    minimal = MemoryEntry(text="test", embedding=[0.1, 0.2, 0.3])
    assert minimal.tags == [] and minimal.metadata == {}
    assert minimal.created_at > 0
    print("[PASS] test_memory_entry_default_fields")
# ================================================================
# MemoryStore
# ================================================================
def test_memory_store_insert_and_get_all():
    """Insert one entry and read it back with get_all().

    The store is closed and the temp directory removed in ``finally`` so that
    a failing assertion does not leave an open database or stray files behind.
    """
    db = _setup_test_dir()
    store = None
    try:
        store = MemoryStore(db)
        entry = MemoryEntry(
            text="第一次测试",
            embedding=[0.1] * 128,
            tags=["test"],
            metadata={"x": 1},
        )
        store.insert(entry)
        all_entries = store.get_all()
        assert len(all_entries) == 1
        assert all_entries[0].text == "第一次测试"
        assert all_entries[0].tags == ["test"]
    finally:
        # Close before deleting the directory: teardown ignores removal
        # errors, so a still-open handle could leave files behind silently.
        if store is not None:
            store.close()
        _teardown_test_dir()
    print("[PASS] test_memory_store_insert_and_get_all")
def test_memory_store_count():
    """count() starts at 0 and reflects each inserted entry.

    Cleanup runs in ``finally`` so a failing assertion does not leak the open
    store or the temp directory.
    """
    db = _setup_test_dir()
    store = None
    try:
        store = MemoryStore(db)
        assert store.count() == 0
        for i in range(5):
            store.insert(MemoryEntry(text=f"测试 {i}", embedding=[0.1] * 128))
        assert store.count() == 5
    finally:
        if store is not None:
            store.close()
        _teardown_test_dir()
    print("[PASS] test_memory_store_count")
def test_memory_store_delete():
    """delete() removes an existing entry and returns False for a missing one.

    Cleanup runs in ``finally`` so a failing assertion does not leak the open
    store or the temp directory.
    """
    db = _setup_test_dir()
    store = None
    try:
        store = MemoryStore(db)
        store.insert(MemoryEntry(text="待删除", embedding=[0.1] * 128))
        rows = store.get_all_ids_and_vectors()
        assert len(rows) == 1
        # The entry id is the first element of each returned row.
        eid = rows[0][0]
        assert store.delete(eid) is True
        assert store.count() == 0
        # Deleting the same id a second time must report False.
        assert store.delete(eid) is False
    finally:
        if store is not None:
            store.close()
        _teardown_test_dir()
    print("[PASS] test_memory_store_delete")
def test_memory_store_upsert():
    """A second upsert() with the same id replaces the row instead of adding one.

    Cleanup runs in ``finally`` so a failing assertion does not leak the open
    store or the temp directory.
    """
    db = _setup_test_dir()
    store = None
    try:
        store = MemoryStore(db)
        entry = MemoryEntry(
            text="原始文本",
            embedding=[0.5] * 128,
            tags=["v1"],
        )
        store.upsert(entry, "fixed-id-001")
        # Update under the same fixed id.
        entry2 = MemoryEntry(
            text="更新文本",
            embedding=[0.9] * 128,
            tags=["v2"],
        )
        store.upsert(entry2, "fixed-id-001")
        all_entries = store.get_all()
        assert len(all_entries) == 1
        assert all_entries[0].text == "更新文本"
    finally:
        if store is not None:
            store.close()
        _teardown_test_dir()
    print("[PASS] test_memory_store_upsert")
# ================================================================
# VectorEngine
# ================================================================
def test_vector_engine_build_empty():
    """An engine built from an empty entry list has no ids, no matrix and no hits."""
    engine = VectorEngine(mode="numpy")
    engine.build([])
    # Internal state after building from nothing.
    assert engine._ids == []
    assert engine._matrix is None
    # Searching an empty index yields an empty result list.
    assert engine.search([0.1] * 128) == []
    print("[PASS] test_vector_engine_build_empty")
def test_vector_engine_build_and_search():
    """Build an index over three axis-aligned unit vectors and check search ranking."""
    engine = VectorEngine(mode="numpy")
    engine.build([
        ("a", [1.0, 0.0, 0.0]),  # along the x axis
        ("b", [0.0, 1.0, 0.0]),  # along the y axis
        ("c", [0.0, 0.0, 1.0]),  # along the z axis
    ])

    # A query along x must hit "a" first with similarity ~1.
    best = engine.search([1.0, 0.0, 0.0], top_k=3)[0]
    assert best[0] == "a"
    assert abs(best[1] - 1.0) < 1e-6

    # (1, 1, 0) is equally close to a and b and orthogonal to c,
    # so the top two hits must be a and b.
    top_two = {hit[0] for hit in engine.search([1.0, 1.0, 0.0], top_k=2)}
    assert {"a", "b"} <= top_two
    print("[PASS] test_vector_engine_build_and_search")
def test_cosine_similarity():
    """Check VectorEngine._cosine_impl on four known angle cases.

    The cases are identical, orthogonal, opposite and 45-degree vector pairs.
    """
    cases = [
        # (u, v, expected similarity, absolute tolerance)
        ([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], 1.0, 1e-9),         # identical -> 1
        ([0.0, 1.0, 0.0], [1.0, 0.0, 0.0], 0.0, 1e-9),         # orthogonal -> 0
        ([1.0, 0.0], [-1.0, 0.0], -1.0, 1e-9),                 # opposite -> -1
        ([1.0, 0.0], [1.0, 1.0], 1.0 / math.sqrt(2), 1e-6),    # 45 degrees -> 1/sqrt(2) ~ 0.707
    ]
    for u, v, expected, tol in cases:
        assert abs(VectorEngine._cosine_impl(u, v) - expected) < tol
    print("[PASS] test_cosine_similarity")
def test_cosine_similarity_zero_vector():
    """_cosine_impl returns exactly 0.0 when the first operand (or both) is a zero vector."""
    origin = [0.0] * 5
    unit_x = [1.0] + [0.0] * 4
    for other in (unit_x, origin):
        assert VectorEngine._cosine_impl(origin, other) == 0.0
    print("[PASS] test_cosine_similarity_zero_vector")
# ================================================================
# DummyEmbedder
# ================================================================
def test_dummy_embedder():
    """Exercise DummyEmbedder.

    Before fit() it yields zero vectors. After fit() it yields deterministic,
    unit-norm vectors that differ between documents.
    """
    embedder = DummyEmbedder(dimension=64)
    assert embedder.dimension == 64

    # Before fit(): the embedder returns an all-zero vector.
    unfitted = embedder.embed_sync("hello world")
    assert len(unfitted) == 64
    assert not any(value != 0.0 for value in unfitted)

    # After fit(): the same text always maps to the same vector.
    embedder.fit(["apple banana cherry", "dog elephant fruit"])
    first = embedder.embed_sync("apple banana cherry")
    again = embedder.embed_sync("apple banana cherry")
    assert len(first) == 64
    assert first == again

    # Fitted vectors have unit L2 norm.
    assert abs(math.sqrt(sum(value * value for value in first)) - 1.0) < 1e-6

    # A different document maps to a different vector.
    assert first != embedder.embed_sync("dog elephant fruit")

    # Out-of-vocabulary text still yields a well-formed vector
    # (TF-IDF gives unknown terms a weight of 0).
    oov = embedder.embed_sync("totally unknown xyz123")
    assert len(oov) == 64
    assert isinstance(oov[0], float)
    print("[PASS] test_dummy_embedder")
def test_dummy_embedder_async():
    """Check the awaitable embed() of an embedder built with no arguments.

    It must return a vector of length DummyEmbedder.DIM.
    """

    async def _embed_once():
        embedder = DummyEmbedder()
        embedder.fit(["async test corpus"])
        vector = await embedder.embed("async test")
        assert len(vector) == DummyEmbedder.DIM
        return True

    outcome = asyncio.run(_embed_once())
    assert outcome
    print("[PASS] test_dummy_embedder_async")
# ================================================================
# make_embedder
# ================================================================
def test_make_embedder_dummy():
    """The factory maps the "dummy" backend name to a DummyEmbedder instance."""
    assert isinstance(make_embedder("dummy"), DummyEmbedder)
    print("[PASS] test_make_embedder_dummy")
def test_make_embedder_unknown():
    """make_embedder rejects an unknown backend with a ValueError naming it."""
    try:
        make_embedder("nonexistent_backend")
    except ValueError as e:
        assert "nonexistent_backend" in str(e)
    else:
        # Explicit raise instead of `assert False` inside the try body: it is
        # not stripped under `python -O` and keeps the try body to the one
        # call that is expected to raise.
        raise AssertionError("Should raise ValueError")
    print("[PASS] test_make_embedder_unknown")
# ================================================================
# VectorMemory — 集成测试
# ================================================================
def test_vector_memory_add_and_search():
    """Store three tagged memories and check that semantic search surfaces the relevant one.

    The VectorMemory handle and temp directory are released in ``finally`` so
    a failing assertion does not leak them.
    """

    async def _run():
        db_path = _setup_test_dir()
        mem = None
        try:
            mem = VectorMemory(
                storage_path=db_path,
                embedder_backend="dummy",
                embedder_kwargs={"dimension": 128},
            )
            await mem.initialize()
            await mem.add(
                "用户今天说想喝奶茶",
                tags=["饮食", "偏好"],
                metadata={"source": "feishu"},
            )
            await mem.add(
                "用户最近在学习 Python 编程",
                tags=["学习", "编程"],
                metadata={"source": "feishu"},
            )
            await mem.add(
                "用户表示最近工作压力很大",
                tags=["情绪", "工作"],
                metadata={"source": "feishu"},
            )
            results = await mem._search_async("喝奶茶", top_k=3, min_similarity=0.0)
            assert len(results) >= 1
            # The query shares "奶茶" (milk tea) with the first memory, so that memory should rank first.
            assert "奶茶" in results[0]["text"]
            results2 = await mem._search_async("编程学习", top_k=2)
            assert any("Python" in r["text"] for r in results2)
            results3 = await mem._search_async("心情怎么样", top_k=2)
            assert any("压力" in r["text"] or "情绪" in str(r["tags"]) for r in results3)
        finally:
            # Close before deleting the directory: teardown ignores removal
            # errors, so a still-open handle could leave files behind silently.
            if mem is not None:
                mem.close()
            _teardown_test_dir()

    asyncio.run(_run())
    print("[PASS] test_vector_memory_add_and_search")
def test_vector_memory_retrieve_context():
    """retrieve_context() returns a formatted RAG context string containing a drink-related memory.

    The VectorMemory handle and temp directory are released in ``finally`` so
    a failing assertion does not leak them.
    """

    async def _run():
        db_path = _setup_test_dir()
        mem = None
        try:
            mem = VectorMemory(
                storage_path=db_path,
                embedder_backend="dummy",
                embedder_kwargs={"dimension": 128},
            )
            await mem.initialize()
            await mem.add("用户今天下午说想喝奶茶,不加糖", tags=["饮食"])
            await mem.add("用户今天早上喝了美式咖啡", tags=["饮食"])
            await mem.add("用户最近在学习 Rust 编程语言", tags=["学习"])
            context = await mem.retrieve_context(
                "用户对喝的东西有什么偏好",
                top_k=2,
                min_similarity=0.1,
            )
            assert "奶茶" in context or "咖啡" in context
            # The formatted context must contain its header and similarity labels.
            assert "相关记忆" in context
            assert "相似度" in context
        finally:
            if mem is not None:
                mem.close()
            _teardown_test_dir()

    asyncio.run(_run())
    print("[PASS] test_vector_memory_retrieve_context")
def test_vector_memory_get_user_profile():
    """get_user_profile() reports the total memory count and per-tag frequencies.

    The VectorMemory handle and temp directory are released in ``finally`` so
    a failing assertion does not leak them.
    """

    async def _run():
        db_path = _setup_test_dir()
        mem = None
        try:
            mem = VectorMemory(
                storage_path=db_path,
                embedder_backend="dummy",
                embedder_kwargs={"dimension": 128},
            )
            await mem.initialize()
            await mem.add("文本1", tags=["A", "B"])
            await mem.add("文本2", tags=["A", "C"])
            await mem.add("文本3", tags=["B", "D"])
            profile = mem.get_user_profile()
            assert profile["total"] == 3
            assert profile["tag_counts"]["A"] == 2
            assert profile["tag_counts"]["B"] == 2
        finally:
            if mem is not None:
                mem.close()
            _teardown_test_dir()

    asyncio.run(_run())
    print("[PASS] test_vector_memory_get_user_profile")
def test_vector_memory_get_recent_memories():
    """get_recent_memories() returns newest-first results, honoring limit and tags_filter.

    The VectorMemory handle and temp directory are released in ``finally`` so
    a failing assertion does not leak them.
    """

    async def _run():
        db_path = _setup_test_dir()
        mem = None
        try:
            mem = VectorMemory(
                storage_path=db_path,
                embedder_backend="dummy",
                embedder_kwargs={"dimension": 128},
            )
            await mem.initialize()
            for i in range(5):
                await mem.add(f"记忆条目 {i}", tags=[f"标签{i}"])
            recent = await mem.get_recent_memories(limit=3)
            assert len(recent) == 3
            # The last one added comes first.
            assert recent[0]["text"] == "记忆条目 4"
            recent_filtered = await mem.get_recent_memories(
                limit=10, tags_filter=["标签0", "标签1"]
            )
            assert len(recent_filtered) == 2
        finally:
            if mem is not None:
                mem.close()
            _teardown_test_dir()

    asyncio.run(_run())
    print("[PASS] test_vector_memory_get_recent_memories")
def test_vector_memory_delete():
    """Deleting a memory by the id that add() returned removes it from count().

    The VectorMemory handle and temp directory are released in ``finally`` so
    a failing assertion does not leak them.
    """

    async def _run():
        db_path = _setup_test_dir()
        mem = None
        try:
            mem = VectorMemory(
                storage_path=db_path,
                embedder_backend="dummy",
                embedder_kwargs={"dimension": 128},
            )
            await mem.initialize()
            eid = await mem.add("待删除记忆")
            assert mem.count() == 1
            ok = mem.delete(eid)
            assert ok is True
            assert mem.count() == 0
        finally:
            if mem is not None:
                mem.close()
            _teardown_test_dir()

    asyncio.run(_run())
    print("[PASS] test_vector_memory_delete")
def test_vector_memory_sync_search():
    """The synchronous search() works when called from inside a running event loop.

    The VectorMemory handle and temp directory are released in ``finally`` so
    a failing assertion does not leak them.
    """

    async def _run():
        db_path = _setup_test_dir()
        mem = None
        try:
            mem = VectorMemory(
                storage_path=db_path,
                embedder_backend="dummy",
                embedder_kwargs={"dimension": 128},
            )
            await mem.initialize()
            await mem.add("测试文本", tags=["test"])
            # Synchronous search(); presumably it detects the running loop
            # internally (per the original comment) — verify in memory.py.
            results = mem.search("测试", top_k=3)
            assert len(results) == 1
            assert results[0]["text"] == "测试文本"
        finally:
            if mem is not None:
                mem.close()
            _teardown_test_dir()

    asyncio.run(_run())
    print("[PASS] test_vector_memory_sync_search")
# ================================================================
# 运行所有测试
# ================================================================
def run_all():
    """Run every test below in order and print a pass/fail summary.

    An AssertionError is reported as [FAIL]. Any other Exception is reported
    as [ERROR] and also counts as a failure.

    Returns:
        True if no test failed, otherwise False.
    """
    suite = (
        # MemoryEntry
        test_memory_entry_to_from_dict,
        test_memory_entry_default_fields,
        # MemoryStore
        test_memory_store_insert_and_get_all,
        test_memory_store_count,
        test_memory_store_delete,
        test_memory_store_upsert,
        # VectorEngine
        test_vector_engine_build_empty,
        test_vector_engine_build_and_search,
        test_cosine_similarity,
        test_cosine_similarity_zero_vector,
        # Embedder
        test_dummy_embedder,
        test_dummy_embedder_async,
        test_make_embedder_dummy,
        test_make_embedder_unknown,
        # VectorMemory integration tests
        test_vector_memory_add_and_search,
        test_vector_memory_retrieve_context,
        test_vector_memory_get_user_profile,
        test_vector_memory_get_recent_memories,
        test_vector_memory_delete,
        test_vector_memory_sync_search,
    )
    separator = "=" * 50
    n_pass = n_fail = 0
    for case in suite:
        name = case.__name__
        try:
            case()
        except AssertionError as err:
            print(f"[FAIL] {name}: {err}")
            n_fail += 1
        except Exception as err:
            print(f"[ERROR] {name}: {type(err).__name__}: {err}")
            n_fail += 1
        else:
            n_pass += 1

    total = n_pass + n_fail
    print(f"\n{separator}")
    print(f"测试结果: {n_pass}/{total} 通过", end="")
    print(f", {n_fail} 失败" if n_fail else ", 全部通过!")
    print(separator)
    return n_fail == 0
if __name__ == "__main__":
    all_passed = run_all()
    # Best-effort removal of any temp directory still registered after the run.
    _teardown_test_dir()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)