#!/usr/bin/env python3
"""
Unit tests for memory.py.

Run with: python -m agent.test_memory
"""

import asyncio
import math
import os
import shutil
import sys
import tempfile
from pathlib import Path

# Make the package root importable when this file is executed directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from agent.memory import (
    MemoryEntry,
    MemoryStore,
    VectorEngine,
    VectorMemory,
    make_embedder,
    DummyEmbedder,
    OllamaEmbedder,
    OpenAIEmbedder,
)


# ================================================================
# Test utilities
# ================================================================

# Temporary directory owned by the test currently running (None when idle).
_TEST_DB_DIR: Path | None = None


def _setup_test_dir() -> str:
    """Create a fresh temporary directory and return a database path inside it.

    The directory is remembered in ``_TEST_DB_DIR`` so that
    ``_teardown_test_dir`` can remove it afterwards.
    """
    global _TEST_DB_DIR
    _TEST_DB_DIR = Path(tempfile.mkdtemp(prefix="ezvibe_test_"))
    return str(_TEST_DB_DIR / "memory.db")


def _teardown_test_dir() -> None:
    """Remove the current temporary directory, if any, and forget it.

    Safe to call repeatedly: once the directory has been removed the global
    is reset, so later calls are no-ops.
    """
    global _TEST_DB_DIR
    if _TEST_DB_DIR is not None and _TEST_DB_DIR.exists():
        shutil.rmtree(_TEST_DB_DIR, ignore_errors=True)
    # Reset so a stale path is never reused by a later teardown.
    _TEST_DB_DIR = None


def _run_with_dummy_memory(scenario) -> None:
    """Run ``scenario(mem)`` against a fresh dummy-embedder ``VectorMemory``.

    ``scenario`` is an async callable taking the initialized memory object.
    The memory is closed and the temporary directory removed even when the
    scenario raises, so a failing test cannot leak resources into the next.
    """

    async def _runner():
        mem = VectorMemory(
            storage_path=_setup_test_dir(),
            embedder_backend="dummy",
            embedder_kwargs={"dimension": 128},
        )
        await mem.initialize()
        try:
            await scenario(mem)
        finally:
            mem.close()

    try:
        asyncio.run(_runner())
    finally:
        _teardown_test_dir()


# ================================================================
# MemoryEntry
# ================================================================

def test_memory_entry_to_from_dict():
    """Check that a MemoryEntry survives a to_dict / from_dict round trip."""
    entry = MemoryEntry(
        text="用户今天说想喝奶茶",
        embedding=[0.1, 0.2, 0.3, 0.4],
        tags=["偏好", "饮食"],
        metadata={"source": "feishu"},
        created_at=1700000000.0,
    )
    d = entry.to_dict()
    restored = MemoryEntry.from_dict(d)
    assert restored.text == entry.text
    assert restored.embedding == entry.embedding
    assert restored.tags == entry.tags
    assert restored.metadata == entry.metadata
    assert restored.created_at == entry.created_at
    print("[PASS] test_memory_entry_to_from_dict")


def test_memory_entry_default_fields():
    """Check the default values of the optional MemoryEntry fields."""
    entry = MemoryEntry(text="test", embedding=[0.1, 0.2, 0.3])
    assert entry.tags == []
    assert entry.metadata == {}
    assert entry.created_at > 0
    print("[PASS] test_memory_entry_default_fields")


# ================================================================
# MemoryStore
# ================================================================

def test_memory_store_insert_and_get_all():
    """Check inserting one entry and reading everything back."""
    db = _setup_test_dir()
    try:
        store = MemoryStore(db)
        try:
            entry = MemoryEntry(
                text="第一次测试",
                embedding=[0.1] * 128,
                tags=["test"],
                metadata={"x": 1},
            )
            store.insert(entry)
            all_entries = store.get_all()
            assert len(all_entries) == 1
            assert all_entries[0].text == "第一次测试"
            assert all_entries[0].tags == ["test"]
        finally:
            # Close before the directory is removed, even on failure.
            store.close()
        print("[PASS] test_memory_store_insert_and_get_all")
    finally:
        _teardown_test_dir()


def test_memory_store_count():
    """Check that count() tracks the number of inserted entries."""
    db = _setup_test_dir()
    try:
        store = MemoryStore(db)
        try:
            assert store.count() == 0
            for i in range(5):
                entry = MemoryEntry(
                    text=f"测试 {i}",
                    embedding=[0.1] * 128,
                )
                store.insert(entry)
            assert store.count() == 5
        finally:
            store.close()
        print("[PASS] test_memory_store_count")
    finally:
        _teardown_test_dir()


def test_memory_store_delete():
    """Check deleting an entry, and that deleting it twice reports False."""
    db = _setup_test_dir()
    try:
        store = MemoryStore(db)
        try:
            entry = MemoryEntry(text="待删除", embedding=[0.1] * 128)
            store.insert(entry)
            all_ids = store.get_all_ids_and_vectors()
            assert len(all_ids) == 1
            eid = all_ids[0][0]
            ok = store.delete(eid)
            assert ok is True
            assert store.count() == 0
            # Deleting the same id again should return False.
            ok2 = store.delete(eid)
            assert ok2 is False
        finally:
            store.close()
        print("[PASS] test_memory_store_delete")
    finally:
        _teardown_test_dir()


def test_memory_store_upsert():
    """Check that upsert with an existing id replaces the stored entry."""
    db = _setup_test_dir()
    try:
        store = MemoryStore(db)
        try:
            entry = MemoryEntry(
                text="原始文本",
                embedding=[0.5] * 128,
                tags=["v1"],
            )
            store.upsert(entry, "fixed-id-001")
            # Upsert again under the same id: this should update, not add.
            entry2 = MemoryEntry(
                text="更新文本",
                embedding=[0.9] * 128,
                tags=["v2"],
            )
            store.upsert(entry2, "fixed-id-001")
            all_entries = store.get_all()
            assert len(all_entries) == 1
            assert all_entries[0].text == "更新文本"
        finally:
            store.close()
        print("[PASS] test_memory_store_upsert")
    finally:
        _teardown_test_dir()


# ================================================================
# VectorEngine
# ================================================================

def test_vector_engine_build_empty():
    """Check that building from no entries yields an empty, searchable index."""
    ve = VectorEngine(mode="numpy")
    ve.build([])
    assert ve._ids == []
    assert ve._matrix is None
    results = ve.search([0.1] * 128)
    assert results == []
    print("[PASS] test_vector_engine_build_empty")


def test_vector_engine_build_and_search():
    """Check index construction and nearest-neighbour search."""
    ve = VectorEngine(mode="numpy")
    entries = [
        ("a", [1.0, 0.0, 0.0]),  # along the x axis
        ("b", [0.0, 1.0, 0.0]),  # along the y axis
        ("c", [0.0, 0.0, 1.0]),  # along the z axis
    ]
    ve.build(entries)

    # Query along the x axis: "a" should be the top hit with similarity 1.
    results = ve.search([1.0, 0.0, 0.0], top_k=3)
    assert results[0][0] == "a"
    assert abs(results[0][1] - 1.0) < 1e-6

    # Query (1, 1, 0): "a" and "b" should both rank ahead of "c".
    results = ve.search([1.0, 1.0, 0.0], top_k=2)
    ids = [r[0] for r in results]
    assert "a" in ids and "b" in ids
    print("[PASS] test_vector_engine_build_and_search")


def test_cosine_similarity():
    """Check cosine similarity on identical, orthogonal, opposite and 45° vectors."""
    # Identical vectors -> 1.0
    a = [1.0, 2.0, 3.0]
    b = [1.0, 2.0, 3.0]
    assert abs(VectorEngine._cosine_impl(a, b) - 1.0) < 1e-9

    # Orthogonal vectors -> 0.0
    c = [0.0, 1.0, 0.0]
    d = [1.0, 0.0, 0.0]
    assert abs(VectorEngine._cosine_impl(c, d)) < 1e-9

    # Opposite directions -> -1.0
    e = [1.0, 0.0]
    f = [-1.0, 0.0]
    assert abs(VectorEngine._cosine_impl(e, f) + 1.0) < 1e-9

    # 45 degrees apart (1/sqrt(2) ~= 0.707)
    g = [1.0, 0.0]
    h = [1.0, 1.0]
    sim = VectorEngine._cosine_impl(g, h)
    expected = 1.0 / math.sqrt(2)
    assert abs(sim - expected) < 1e-6
    print("[PASS] test_cosine_similarity")


def test_cosine_similarity_zero_vector():
    """Check that a zero vector yields similarity 0.0 instead of an error."""
    zero = [0.0] * 5
    nonzero = [1.0, 0.0, 0.0, 0.0, 0.0]
    assert VectorEngine._cosine_impl(zero, nonzero) == 0.0
    assert VectorEngine._cosine_impl(zero, zero) == 0.0
    print("[PASS] test_cosine_similarity_zero_vector")


# ================================================================
# DummyEmbedder
# ================================================================

def test_dummy_embedder():
    """Check the dummy embedder before and after fitting."""
    emb = DummyEmbedder(dimension=64)
    assert emb.dimension == 64

    # Before fit: an unfitted embedder returns a zero vector.
    v_zero = emb.embed_sync("hello world")
    assert len(v_zero) == 64
    assert all(x == 0.0 for x in v_zero)

    # After fit: the same text maps to the same vector.
    emb.fit(["apple banana cherry", "dog elephant fruit"])
    v1 = emb.embed_sync("apple banana cherry")
    v2 = emb.embed_sync("apple banana cherry")
    assert len(v1) == 64
    assert v1 == v2

    # The vector is expected to be L2-normalized.
    norm = math.sqrt(sum(x * x for x in v1))
    assert abs(norm - 1.0) < 1e-6

    # Different text maps to a different vector.
    v3 = emb.embed_sync("dog elephant fruit")
    assert v1 != v3

    # Text outside the fitted vocabulary still yields a well-formed vector
    # (per the original author's note, TF-IDF gives 0 for unknown terms).
    v4 = emb.embed_sync("totally unknown xyz123")
    assert len(v4) == 64
    assert isinstance(v4[0], float)
    print("[PASS] test_dummy_embedder")


def test_dummy_embedder_async():
    """Check the dummy embedder's async ``embed`` interface."""

    async def _run():
        emb = DummyEmbedder()
        emb.fit(["async test corpus"])
        v = await emb.embed("async test")
        assert len(v) == DummyEmbedder.DIM
        return True

    assert asyncio.run(_run())
    print("[PASS] test_dummy_embedder_async")


# ================================================================
# make_embedder
# ================================================================

def test_make_embedder_dummy():
    """Check that the factory builds a DummyEmbedder for "dummy"."""
    emb = make_embedder("dummy")
    assert isinstance(emb, DummyEmbedder)
    print("[PASS] test_make_embedder_dummy")


def test_make_embedder_unknown():
    """Check that an unknown backend raises ValueError naming the backend."""
    try:
        make_embedder("nonexistent_backend")
    except ValueError as e:
        assert "nonexistent_backend" in str(e)
    else:
        # Reaching here means no exception was raised at all.
        raise AssertionError("Should raise ValueError")
    print("[PASS] test_make_embedder_unknown")


# ================================================================
# VectorMemory — integration tests
# ================================================================

def test_vector_memory_add_and_search():
    """Check adding memories and retrieving them by semantic search."""

    async def _scenario(mem):
        await mem.add(
            "用户今天说想喝奶茶",
            tags=["饮食", "偏好"],
            metadata={"source": "feishu"},
        )
        await mem.add(
            "用户最近在学习 Python 编程",
            tags=["学习", "编程"],
            metadata={"source": "feishu"},
        )
        await mem.add(
            "用户表示最近工作压力很大",
            tags=["情绪", "工作"],
            metadata={"source": "feishu"},
        )

        results = await mem._search_async("喝奶茶", top_k=3, min_similarity=0.0)
        assert len(results) >= 1
        # The query "喝奶茶" should surface the memory containing "奶茶".
        assert "奶茶" in results[0]["text"]

        results2 = await mem._search_async("编程学习", top_k=2)
        assert any("Python" in r["text"] for r in results2)

        results3 = await mem._search_async("心情怎么样", top_k=2)
        assert any("压力" in r["text"] or "情绪" in str(r["tags"]) for r in results3)

    _run_with_dummy_memory(_scenario)
    print("[PASS] test_vector_memory_add_and_search")


def test_vector_memory_retrieve_context():
    """Check RAG-style context retrieval."""

    async def _scenario(mem):
        await mem.add("用户今天下午说想喝奶茶,不加糖", tags=["饮食"])
        await mem.add("用户今天早上喝了美式咖啡", tags=["饮食"])
        await mem.add("用户最近在学习 Rust 编程语言", tags=["学习"])

        context = await mem.retrieve_context(
            "用户对喝的东西有什么偏好",
            top_k=2,
            min_similarity=0.1,
        )
        assert "奶茶" in context or "咖啡" in context
        assert "相关记忆" in context
        assert "相似度" in context

    _run_with_dummy_memory(_scenario)
    print("[PASS] test_vector_memory_retrieve_context")


def test_vector_memory_get_user_profile():
    """Check user-profile generation (total count and per-tag counts)."""

    async def _scenario(mem):
        await mem.add("文本1", tags=["A", "B"])
        await mem.add("文本2", tags=["A", "C"])
        await mem.add("文本3", tags=["B", "D"])

        profile = mem.get_user_profile()
        assert profile["total"] == 3
        assert profile["tag_counts"]["A"] == 2
        assert profile["tag_counts"]["B"] == 2

    _run_with_dummy_memory(_scenario)
    print("[PASS] test_vector_memory_get_user_profile")


def test_vector_memory_get_recent_memories():
    """Check fetching the most recent memories, with and without a tag filter."""

    async def _scenario(mem):
        for i in range(5):
            await mem.add(f"记忆条目 {i}", tags=[f"标签{i}"])

        recent = await mem.get_recent_memories(limit=3)
        assert len(recent) == 3
        # The newest entry is expected first.
        assert recent[0]["text"] == "记忆条目 4"

        recent_filtered = await mem.get_recent_memories(
            limit=10, tags_filter=["标签0", "标签1"]
        )
        assert len(recent_filtered) == 2

    _run_with_dummy_memory(_scenario)
    print("[PASS] test_vector_memory_get_recent_memories")


def test_vector_memory_delete():
    """Check deleting a memory by the id returned from add()."""

    async def _scenario(mem):
        eid = await mem.add("待删除记忆")
        assert mem.count() == 1
        ok = mem.delete(eid)
        assert ok is True
        assert mem.count() == 0

    _run_with_dummy_memory(_scenario)
    print("[PASS] test_vector_memory_delete")


def test_vector_memory_sync_search():
    """Check the synchronous ``search`` method."""

    async def _scenario(mem):
        await mem.add("测试文本", tags=["test"])
        # Synchronous search, called while an event loop is running (per the
        # original author's note, it detects the loop state internally).
        results = mem.search("测试", top_k=3)
        assert len(results) == 1
        assert results[0]["text"] == "测试文本"

    _run_with_dummy_memory(_scenario)
    print("[PASS] test_vector_memory_sync_search")


# ================================================================
# Run all tests
# ================================================================

def run_all() -> bool:
    """Run every test in order, print a summary, and return True if all passed.

    Assertion failures are reported as ``[FAIL]`` and any other exception as
    ``[ERROR]``; both count as failures and neither stops the run.
    """
    tests = [
        # MemoryEntry
        test_memory_entry_to_from_dict,
        test_memory_entry_default_fields,
        # MemoryStore
        test_memory_store_insert_and_get_all,
        test_memory_store_count,
        test_memory_store_delete,
        test_memory_store_upsert,
        # VectorEngine
        test_vector_engine_build_empty,
        test_vector_engine_build_and_search,
        test_cosine_similarity,
        test_cosine_similarity_zero_vector,
        # Embedder
        test_dummy_embedder,
        test_dummy_embedder_async,
        test_make_embedder_dummy,
        test_make_embedder_unknown,
        # VectorMemory integration tests
        test_vector_memory_add_and_search,
        test_vector_memory_retrieve_context,
        test_vector_memory_get_user_profile,
        test_vector_memory_get_recent_memories,
        test_vector_memory_delete,
        test_vector_memory_sync_search,
    ]

    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
            passed += 1
        except AssertionError as e:
            print(f"[FAIL] {test.__name__}: {e}")
            failed += 1
        except Exception as e:
            # Top-level boundary: report and continue with the next test.
            print(f"[ERROR] {test.__name__}: {type(e).__name__}: {e}")
            failed += 1

    print(f"\n{'='*50}")
    print(f"测试结果: {passed}/{passed+failed} 通过", end="")
    if failed:
        print(f", {failed} 失败")
    else:
        print(", 全部通过!")
    print(f"{'='*50}")
    return failed == 0


if __name__ == "__main__":
    success = run_all()
    # Final cleanup in case anything was left behind.
    _teardown_test_dir()
    sys.exit(0 if success else 1)