- EmotionEngine: 5状态马尔可夫情绪机 + 蒙特卡洛转移 - VectorMemory: TF-IDF向量记忆 + SQLite持久化 + RAG检索 - AgentBrain: Ollama/OpenAI/Dummy三后端LLM - BehaviorScheduler: 优先级/冷却/活跃度调度 - FastAPI服务器 + WebSocket实时推送 - perception: 键鼠监控 + 屏幕截图 - ui/pet_window: PySide6桌宠窗口 + 像素动画 - assets/pet: 5情绪各2帧像素艺术资源
351 lines
11 KiB
Python
351 lines
11 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
emotion.py 单元测试
|
||
运行: python -m agent.test_emotion
|
||
"""
|
||
|
||
import statistics
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
||
from agent.emotion import (
|
||
EmotionEngine,
|
||
EmotionState,
|
||
DEFAULT_TRANSITION_MATRIX,
|
||
_STATE_INDEX,
|
||
_INDEX_TO_STATE,
|
||
)
|
||
|
||
|
||
def test_init_default():
    """A freshly constructed engine must start in the ``idle`` state."""
    engine = EmotionEngine(seed=42)
    initial = engine.get_state()
    assert initial == "idle", f"Expected 'idle', got {initial}"
    print("[PASS] test_init_default")
|
||
|
||
|
||
def test_init_custom_matrix():
    """A caller-supplied uniform 5x5 transition matrix is accepted as-is."""
    uniform_matrix = [[0.2 for _ in range(5)] for _ in range(5)]
    engine = EmotionEngine(transition_matrix=uniform_matrix, seed=42)

    assert engine.get_state() == "idle"
    assert engine._P.shape == (5, 5)
    print("[PASS] test_init_custom_matrix")
|
||
|
||
|
||
def test_init_invalid_matrix():
    """A transition matrix with the wrong shape must raise ``ValueError``.

    The error message is expected to mention "shape". The "no exception
    raised" path raises ``AssertionError`` explicitly from the ``else``
    clause. An ``assert False`` would be stripped under ``python -O`` and
    let the test pass silently.
    """
    try:
        EmotionEngine(transition_matrix=[[1, 2], [3, 4]])
    except ValueError as e:
        assert "shape" in str(e).lower(), f"Unexpected error message: {e}"
    else:
        raise AssertionError("Should have raised ValueError")
    print("[PASS] test_init_invalid_matrix")
|
||
|
||
|
||
def test_state_enum_roundtrip():
    """Each state's value parses back to the same member, case-insensitively."""
    for member in EmotionState:
        # Check the canonical value first, then its upper-cased spelling.
        for spelling in (member.value, member.value.upper()):
            assert EmotionState.from_string(spelling) == member
    print("[PASS] test_state_enum_roundtrip")
|
||
|
||
|
||
def test_get_display_name():
    """Display names pair the Chinese label with the English state name."""
    engine = EmotionEngine(seed=42)

    # Called without an argument: the result is expected to mention the
    # default "idle" state.
    assert "idle" in engine.get_display_name().lower()

    expected_labels = {
        "happy": "开心 (Happy)",
        "focused": "专注 (Focused)",
        "annoyed": "烦躁 (Annoyed)",
        "sleepy": "困倦 (Sleepy)",
    }
    for state_name, label in expected_labels.items():
        assert engine.get_display_name(state_name) == label
    print("[PASS] test_get_display_name")
|
||
|
||
|
||
def test_residence_time_blocks_transition():
    """An update arriving within ``min_residence_seconds`` must be ignored.

    The engine is built with a 10 s minimum residence time and updated
    straight away, so the state must remain ``idle``.

    The elapsed-time check only requires the residence time to still be
    below the configured minimum, which is the condition that explains the
    block. A tight wall-clock bound (such as 10 ms) makes the test flaky on
    slow or heavily loaded machines without verifying anything more.
    """
    min_residence = 10.0
    eng = EmotionEngine(min_residence_seconds=min_residence, seed=42)

    # Called immediately: should be blocked by the residence time.
    result = eng.update("user_praise")
    assert result == "idle", f"Should be blocked by residence time, got {result}"

    residence = eng.get_residence_time()
    assert residence < min_residence, (
        f"Residence time {residence:.3f}s should still be below {min_residence}s"
    )
    print("[PASS] test_residence_time_blocks_transition")
|
||
|
||
|
||
def test_force_state():
    """``force_state`` accepts both plain state names and enum members."""
    engine = EmotionEngine(seed=42)
    cases = (
        ("happy", "happy"),
        (EmotionState.SLEEPY, "sleepy"),
    )
    for requested, expected in cases:
        engine.force_state(requested)
        assert engine.get_state() == expected
    print("[PASS] test_force_state")
|
||
|
||
|
||
def test_update_single_event():
    """With zero residence time, one event yields a valid state and one history entry."""
    engine = EmotionEngine(min_residence_seconds=0.0, seed=0)
    engine.force_state("idle")

    new_state = engine.update("user_praise")

    valid_states = {member.value for member in EmotionState}
    assert new_state in valid_states
    assert len(engine.get_history()) == 1
    print("[PASS] test_update_single_event")
|
||
|
||
|
||
def test_update_history():
    """Each processed event appends one history entry with event/state keys.

    With zero residence time every update is processed. The history must
    hold exactly one entry per event, and each entry must carry the
    ``event``, ``prev_state`` and ``curr_state`` keys.
    """
    eng = EmotionEngine(min_residence_seconds=0.0, seed=99)
    eng.force_state("idle")

    events = ["time_passes", "user_interact", "user_praise"]
    for event in events:
        eng.update(event)

    hist = eng.get_history(last_n=10)
    assert len(hist) == len(events)
    for key in ("event", "prev_state", "curr_state"):
        assert all(key in h for h in hist), f"History entry missing '{key}'"
    print("[PASS] test_update_history")
|
||
|
||
|
||
def test_transition_matrix_row_sum():
    """Every row of the default transition matrix must sum to 1."""
    engine = EmotionEngine()
    row_totals = [sum(row) for row in engine._P]
    for index, row_total in enumerate(row_totals):
        assert abs(row_total - 1.0) < 1e-9, (
            f"Row {index} sum = {row_total}, expected 1.0"
        )
    print("[PASS] test_transition_matrix_row_sum")
|
||
|
||
|
||
def test_softmax_stability():
    """``_softmax`` stays normalised for very large and for all-zero inputs."""
    # Logits around 1000: exp() of these overflows float64 unless the
    # implementation shifts the inputs first.
    big = EmotionEngine._softmax([1000.0, 1001.0, 999.0])
    assert abs(sum(big) - 1.0) < 1e-9
    assert big[1] > big[0] > big[2]

    # All-zero logits.
    flat = EmotionEngine._softmax([0.0, 0.0, 0.0])
    assert abs(sum(flat) - 1.0) < 1e-9
    print("[PASS] test_softmax_stability")
|
||
|
||
|
||
def test_softmax_zeros():
    """An all-zero vector must map to the uniform distribution.

    The result length is checked first: ``zip`` stops at the shorter input,
    so a result with too few entries would otherwise pass unnoticed.
    """
    result = EmotionEngine._softmax([0.0, 0.0, 0.0, 0.0, 0.0])
    expected = [0.2] * 5
    assert len(result) == len(expected), (
        f"Expected {len(expected)} values, got {len(result)}"
    )
    for a, b in zip(result, expected):
        assert abs(a - b) < 1e-9, f"Expected {b}, got {a}"
    print("[PASS] test_softmax_zeros")
|
||
|
||
|
||
def test_sampling_distribution():
    """
    Monte Carlo check of ``_sample`` against the transition matrix.

    Draws many samples from the ``idle`` row of the matrix and checks that
    the observed frequency of every state is within an absolute tolerance
    of its matrix probability.
    """
    eng = EmotionEngine(min_residence_seconds=0.0, seed=42)
    n_samples = 50_000

    eng.force_state("idle")
    tally = dict.fromkeys((s.value for s in EmotionState), 0)
    idle_row = eng._P[_STATE_INDEX[EmotionState.IDLE]]

    for _ in range(n_samples):
        tally[eng._sample(idle_row).value] += 1

    # 2% absolute tolerance per state. This is well above the sampling noise
    # at this sample size.
    tolerance = 0.02
    for idx, state in enumerate(_INDEX_TO_STATE):
        want = idle_row[idx]
        got = tally[state.value] / n_samples
        gap = abs(got - want)
        assert gap < tolerance, (
            f"State {state.value}: expected {want:.4f}, "
            f"observed {got:.4f}, diff {gap:.4f} > {tolerance}"
        )
        print(f" {state.value}: expected={want:.3f}, observed={got:.3f}, diff={gap:.3f}")

    print("[PASS] test_sampling_distribution")
|
||
|
||
|
||
def test_event_boosts_happy():
    """A ``user_praise`` event must raise the idle -> happy probability."""
    eng = EmotionEngine(min_residence_seconds=0.0, seed=0)
    eng.force_state("idle")

    idle_row = eng._P[_STATE_INDEX[EmotionState.IDLE]]
    adjusted = idle_row.copy()

    # Mirror the boost-then-softmax adjustment for the event (presumably the
    # engine's own update rule; confirm against emotion.py).
    from agent.emotion import ContextBoost
    for target_state, gain in ContextBoost.get_boosts("user_praise").items():
        adjusted[_STATE_INDEX[target_state]] += gain
    adjusted = EmotionEngine._softmax(adjusted)

    # The re-normalised happy probability must exceed the unboosted one.
    happy = _STATE_INDEX[EmotionState.HAPPY]
    assert adjusted[happy] > idle_row[happy], (
        f"Expected happy prob to increase: raw={idle_row[happy]:.3f}, "
        f"boosted={adjusted[happy]:.3f}"
    )
    print("[PASS] test_event_boosts_happy")
|
||
|
||
|
||
def test_event_boosts_annoyed():
    """A ``reminder_ignored`` event must raise the idle -> annoyed probability."""
    eng = EmotionEngine(min_residence_seconds=0.0, seed=0)
    eng.force_state("idle")

    idle_row = eng._P[_STATE_INDEX[EmotionState.IDLE]]
    adjusted = idle_row.copy()

    # Mirror the boost-then-softmax adjustment for the event (presumably the
    # engine's own update rule; confirm against emotion.py).
    from agent.emotion import ContextBoost
    for target_state, gain in ContextBoost.get_boosts("reminder_ignored").items():
        adjusted[_STATE_INDEX[target_state]] += gain
    adjusted = EmotionEngine._softmax(adjusted)

    annoyed = _STATE_INDEX[EmotionState.ANNOYED]
    assert adjusted[annoyed] > idle_row[annoyed]
    print("[PASS] test_event_boosts_annoyed")
|
||
|
||
|
||
def test_event_boosts_sleepy():
    """A ``long_work_session`` event must raise the focused -> sleepy probability."""
    eng = EmotionEngine(min_residence_seconds=0.0, seed=0)
    eng.force_state("focused")

    from agent.emotion import ContextBoost
    gains = ContextBoost.get_boosts("long_work_session")

    focused_row = eng._P[_STATE_INDEX[EmotionState.FOCUSED]]
    adjusted = focused_row.copy()

    # Mirror the boost-then-softmax adjustment for the event (presumably the
    # engine's own update rule; confirm against emotion.py).
    for target_state, gain in gains.items():
        adjusted[_STATE_INDEX[target_state]] += gain
    adjusted = EmotionEngine._softmax(adjusted)

    sleepy = _STATE_INDEX[EmotionState.SLEEPY]
    assert adjusted[sleepy] > focused_row[sleepy]
    print("[PASS] test_event_boosts_sleepy")
|
||
|
||
|
||
def test_consecutive_updates_residence():
    """Rapid consecutive updates must not make the state jitter.

    Ten ``time_passes`` updates are issued back to back with a 2 s minimum
    residence time, starting from a forced ``happy`` state. At most one
    state change may occur across the whole burst, whichever way the
    residence timer is seeded by ``force_state``:

    - every call is blocked and the state stays ``happy``; or
    - the first call transitions and the residence time then blocks the rest.

    The results are asserted, not merely printed, so jitter fails the test.
    """
    eng = EmotionEngine(min_residence_seconds=2.0, seed=0)
    eng.force_state("happy")

    results = [eng.update("time_passes") for _ in range(10)]
    print(f" Consecutive results: {results}")

    # Count state changes along the sequence, including the initial state.
    sequence = ["happy"] + results
    changes = sum(1 for prev, curr in zip(sequence, sequence[1:]) if prev != curr)
    assert changes <= 1, (
        f"State changed {changes} times within the residence window: {results}"
    )
    print("[PASS] test_consecutive_updates_residence")
|
||
|
||
|
||
def test_serialization():
    """``to_dict`` / ``from_dict`` round-trip preserves state and settings.

    The engine is forced into a non-default state before serialising. With
    the constructor default ``idle``, the state comparison would pass even
    if ``from_dict`` ignored the saved ``current_state``.
    """
    eng = EmotionEngine(min_residence_seconds=3.0, seed=123)
    # These updates fall inside the 3 s residence window (presumably blocked).
    eng.update("time_passes")
    eng.update("user_praise")
    eng.update("user_interact")

    # Use a state other than the default so the round-trip check is meaningful.
    eng.force_state("focused")
    assert eng.get_state() == "focused"

    d = eng.to_dict()
    assert "current_state" in d
    assert "transition_matrix" in d
    assert len(d["transition_matrix"]) == 5

    restored = EmotionEngine.from_dict(d)
    assert restored.get_state() == eng.get_state() == "focused", (
        f"Restored state {restored.get_state()!r} != {eng.get_state()!r}"
    )
    assert restored._min_residence == eng._min_residence
    print("[PASS] test_serialization")
|
||
|
||
|
||
def test_get_transition_probabilities():
    """The current state's outgoing distribution covers every state and sums to 1."""
    engine = EmotionEngine(min_residence_seconds=0.0, seed=42)
    engine.force_state("happy")

    distribution = engine.get_transition_probabilities()
    assert isinstance(distribution, dict)

    all_states = {member.value for member in EmotionState}
    assert set(distribution) == all_states
    assert abs(sum(distribution.values()) - 1.0) < 1e-9
    print("[PASS] test_get_transition_probabilities")
|
||
|
||
|
||
def test_tick_time_passes():
    """``tick()`` (the ``time_passes`` event) must return a valid state name."""
    engine = EmotionEngine(min_residence_seconds=0.0, seed=42)
    engine.force_state("happy")
    outcome = engine.tick()
    assert outcome in {member.value for member in EmotionState}
    print("[PASS] test_tick_time_passes")
|
||
|
||
|
||
def run_all():
    """Run every test in this module in order and print a summary.

    Assertion failures are reported as ``[FAIL]``. Any other exception is
    reported as ``[ERROR]``, and execution continues with the next test.

    Returns:
        True if every test passed, False otherwise.
    """
    tests = (
        test_init_default,
        test_init_custom_matrix,
        test_init_invalid_matrix,
        test_state_enum_roundtrip,
        test_get_display_name,
        test_residence_time_blocks_transition,
        test_force_state,
        test_update_single_event,
        test_update_history,
        test_transition_matrix_row_sum,
        test_softmax_stability,
        test_softmax_zeros,
        test_sampling_distribution,
        test_event_boosts_happy,
        test_event_boosts_annoyed,
        test_event_boosts_sleepy,
        test_consecutive_updates_residence,
        test_serialization,
        test_get_transition_probabilities,
        test_tick_time_passes,
    )

    failed = 0
    for test in tests:
        name = test.__name__
        try:
            test()
        except AssertionError as e:
            print(f"[FAIL] {name}: {e}")
            failed += 1
        except Exception as e:
            # Top-level runner boundary: report unexpected errors and keep going.
            print(f"[ERROR] {name}: {type(e).__name__}: {e}")
            failed += 1
    passed = len(tests) - failed

    rule = "=" * 50
    print(f"\n{rule}")
    print(f"测试结果: {passed}/{len(tests)} 通过", end="")
    print(f", {failed} 失败" if failed else ", 全部通过!")
    print(rule)

    return not failed
|
||
|
||
|
||
if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    sys.exit(0 if run_all() else 1)
|