- EmotionEngine: 5状态马尔可夫情绪机 + 蒙特卡洛转移 - VectorMemory: TF-IDF向量记忆 + SQLite持久化 + RAG检索 - AgentBrain: Ollama/OpenAI/Dummy三后端LLM - BehaviorScheduler: 优先级/冷却/活跃度调度 - FastAPI服务器 + WebSocket实时推送 - perception: 键鼠监控 + 屏幕截图 - ui/pet_window: PySide6桌宠窗口 + 像素动画 - assets/pet: 5情绪各2帧像素艺术资源
614 lines
19 KiB
Python
614 lines
19 KiB
Python
"""
|
||
EzVibe API Server
|
||
=================
|
||
设计文档对应章节:前后端通信接口 + 系统时序交互图
|
||
|
||
核心职责
|
||
• HTTP REST 接口:聊天、状态查询、记忆管理
|
||
• WebSocket 实时推送:主动行为通知、情绪状态变化
|
||
• 提供 Agent 与 Qt 前端的异步非阻塞通信通道
|
||
|
||
设计决策
|
||
• 开发模式:由 run_server() 以编程方式启动 uvicorn,直接运行(python -m api.server)
|
||
• 生产模式:uvicorn api.server:app --host 0.0.0.0 --port 8765
|
||
• 前端通过 HTTP POST /chat 发送消息
|
||
• 主动行为通过 WebSocket /ws 推送(pet → 用户)
|
||
|
||
通信协议
|
||
• REST: JSON over HTTP(本地明文 HTTP,默认 http://127.0.0.1:8765)
|
||
• WebSocket: JSON 消息帧,格式 {type, payload, timestamp}
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import time
|
||
from typing import Any
|
||
|
||
import numpy as np
|
||
|
||
# Module-level logger named after this module; used by the WebSocket
# connection manager, server startup, and the proactive push loop.
logger = logging.getLogger(__name__)
|
||
|
||
# ================================================================
|
||
# 尝试导入 FastAPI(可选依赖)
|
||
# ================================================================
|
||
|
||
try:
    from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel, Field
    from sse_starlette.sse import EventSourceResponse
    import starlette.responses

    _HAS_FASTAPI = True
except ImportError:
    # The web stack is an optional dependency. Install minimal placeholders so
    # the module-level definitions below can still be evaluated: the request/
    # response models call ``Field(...)`` inside their class bodies, and
    # ``AppState.brain`` raises ``HTTPException``. Without these stubs merely
    # importing this module fails with NameError, instead of create_app()
    # reporting the missing dependency.
    _HAS_FASTAPI = False
    FastAPI = None  # type: ignore[misc, assignment]
    BaseModel = object  # type: ignore[misc, assignment]

    class HTTPException(Exception):  # type: ignore[no-redef]
        """Fallback exposing the same ``status_code``/``detail`` attributes."""

        def __init__(self, status_code: int, detail: Any = None) -> None:
            super().__init__(status_code, detail)
            self.status_code = status_code
            self.detail = detail

    def Field(default: Any = None, **_kwargs: Any) -> Any:  # type: ignore[no-redef]
        """Fallback that just returns *default* so model class bodies evaluate."""
        return default
|
||
|
||
|
||
# ================================================================
|
||
# API 模型(Pydantic)
|
||
# ================================================================
|
||
|
||
class ChatRequest(BaseModel):
    """Request body of ``POST /chat``.

    Fields
    ------
    message : the user's text (1 to 4000 characters).
    emotion_state : optional emotion state chosen by the client; when absent
        the server falls back to the emotion engine's state.
    context : optional extra context forwarded to the agent.
    """

    message: str = Field(default=..., min_length=1, max_length=4000)
    emotion_state: str | None = Field(None, description="当前情绪状态")
    context: dict[str, Any] | None = Field(None, description="额外上下文")
|
||
|
||
|
||
class ChatResponse(BaseModel):
    """Response body of ``POST /chat``.

    ``text`` is the agent's reply and ``emotion_state`` the state reported
    with it. ``action`` and ``memory_id`` are only filled when the agent
    produced them. ``latency_ms`` is the server-side handling time in
    milliseconds.
    """

    text: str
    emotion_state: str
    action: dict | None = Field(default=None)
    memory_id: str | None = Field(default=None)
    latency_ms: float
|
||
|
||
|
||
class MemoryAddRequest(BaseModel):
    """Request body of ``POST /memory/add``: a non-empty text plus optional tags and metadata."""

    text: str = Field(default=..., min_length=1)
    tags: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||
|
||
|
||
class MemorySearchRequest(BaseModel):
    """Request body of ``POST /memory/search``.

    ``query`` must be non-empty. ``top_k`` is limited to 1..50 and
    ``min_similarity`` to 0.0..1.0.
    """

    query: str = Field(default=..., min_length=1)
    top_k: int = Field(5, ge=1, le=50)
    min_similarity: float = Field(0.0, ge=0.0, le=1.0)
|
||
|
||
|
||
class EmotionUpdateRequest(BaseModel):
    """Request body of ``POST /emotion/update``.

    ``event`` names the transition trigger. When ``force_state`` is set, the
    handler jumps straight to that state instead of applying the event.
    """

    event: str = Field(default=..., description="事件名称(如 user_praise)")
    force_state: str | None = Field(None, description="强制状态")
|
||
|
||
|
||
# ================================================================
|
||
# WebSocket 连接管理器
|
||
# ================================================================
|
||
|
||
class ConnectionManager:
    """
    Registry of live WebSocket connections.

    Supports:
        • broadcast - push one message to every registered connection
        • unicast   - push a message to a single connection (``send_json``)

    A connection whose send fails is dropped from the registry automatically.

    Message format
    --------------
    Callers in this module send JSON objects shaped like:

        {
            "type": "event_type",   # e.g. "action", "emotion_change", "heartbeat"
            "payload": {...},       # actual data
            "timestamp": 1698765432.123,
        }

    The manager itself does not validate that shape.
    """

    def __init__(self) -> None:
        self._connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake and register *websocket*."""
        await websocket.accept()
        self._connections.append(websocket)
        logger.info("[WS] Connected: total=%d", len(self._connections))

    def disconnect(self, websocket: WebSocket) -> None:
        """Unregister *websocket*; a no-op if it is not registered.

        Logs only when a connection was actually removed. A failed send
        unregisters the socket, and the endpoint's cleanup then calls this
        again, so this avoids a duplicate log line.
        """
        if websocket in self._connections:
            self._connections.remove(websocket)
            logger.info("[WS] Disconnected: total=%d", len(self._connections))

    async def send_json(self, websocket: WebSocket, data: dict) -> bool:
        """Send *data* to a single connection.

        Returns True on success. On any send error the error is logged, the
        connection is unregistered, and False is returned (nothing is raised).
        """
        try:
            await websocket.send_json(data)
        except Exception as exc:
            logger.warning("[WS] Send failed: %s", exc)
            self.disconnect(websocket)
            return False
        return True

    async def broadcast(self, data: dict) -> int:
        """Send *data* to every registered connection concurrently.

        Sends run in parallel, so one slow or stalled client does not delay
        delivery to the others. Returns the number of successful sends.
        """
        # Snapshot, because send_json() may unregister failing connections
        # while the broadcast is in progress.
        targets = list(self._connections)
        if not targets:
            return 0
        results = await asyncio.gather(*(self.send_json(conn, data) for conn in targets))
        return sum(results)

    @property
    def connection_count(self) -> int:
        """Number of currently registered connections."""
        return len(self._connections)
|
||
|
||
|
||
# ================================================================
|
||
# 全局状态(供 FastAPI 注入依赖)
|
||
# ================================================================
|
||
|
||
class AppState:
    """
    Shared application state read by the HTTP / WebSocket routes.

    Holds the core modules (agent brain, emotion engine, memory, behavior
    scheduler) and the WebSocket connection manager. Core modules are
    attached with the ``link_*`` methods. The routes built by create_app()
    read them through the properties below, via the ``state`` object they
    close over.
    Pattern: Application State.
    """

    def __init__(self) -> None:
        # Core modules; None means "not linked yet".
        self._agent_brain: Any = None
        self._emotion_engine: Any = None
        self._memory: Any = None
        self._scheduler: Any = None
        # Registry of open WebSocket connections.
        self._manager = ConnectionManager()
        # Slot for a background scheduler task, and the run flag polled by
        # the proactive push loop.
        self._scheduler_task: asyncio.Task | None = None
        self._running = False

    def link_agent(self, brain: Any) -> None:
        """Attach the agent brain."""
        self._agent_brain = brain

    def link_emotion(self, emotion: Any) -> None:
        """Attach the emotion engine."""
        self._emotion_engine = emotion

    def link_memory(self, memory: Any) -> None:
        """Attach the memory store."""
        self._memory = memory

    def link_scheduler(self, scheduler: Any) -> None:
        """Attach the behavior scheduler."""
        self._scheduler = scheduler

    @property
    def brain(self):
        """The linked agent brain.

        Raises HTTPException(503) when no brain has been linked, so route
        handlers can use it without their own None check.
        """
        linked = self._agent_brain
        if linked is not None:
            return linked
        raise HTTPException(503, "Agent 未初始化,请先启动应用")

    @property
    def emotion(self):
        """The linked emotion engine, or None."""
        return self._emotion_engine

    @property
    def memory(self):
        """The linked memory store, or None."""
        return self._memory

    @property
    def scheduler(self):
        """The linked behavior scheduler, or None."""
        return self._scheduler

    @property
    def ws_manager(self) -> ConnectionManager:
        """The WebSocket connection manager."""
        return self._manager
|
||
|
||
|
||
# Module-level singleton: the default ``state`` argument of create_app() and
# the instance that run_server() links the core modules into.
_state = AppState()
|
||
|
||
|
||
# ================================================================
|
||
# FastAPI 应用构建器
|
||
# ================================================================
|
||
|
||
def create_app(
    state: AppState = _state,
    *,
    cors_origins: list[str] | None = None,
) -> Any:
    """
    Build the FastAPI application bound to *state*.

    Parameters
    ----------
    state : AppState
        Shared application state read by every route (defaults to the module
        singleton ``_state``).
    cors_origins : list[str] | None
        Origins allowed by the CORS middleware. ``None`` keeps the permissive
        default ``["*"]``. Credentialed CORS is only enabled for an explicit
        origin list: a wildcard combined with credentials would let any web
        page make credentialed requests to this local API.

    Returns
    -------
    FastAPI
        The configured application.

    Raises
    ------
    ImportError
        If FastAPI or one of its companion packages is not installed.
    """
    if not _HAS_FASTAPI:
        raise ImportError(
            "FastAPI 未安装。运行: pip install fastapi 'uvicorn[standard]' "
            "sse-starlette starlette"
        )

    app = FastAPI(
        title="EzVibe API",
        description="EzVibe AI 桌宠后端 API",
        version="0.1.0",
    )

    # CORS: allow the frontend (Qt WebEngine / localhost) to call the API.
    # CORS only affects browsers; non-browser HTTP clients ignore it.
    origins = ["*"] if cors_origins is None else list(cors_origins)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=origins,
        allow_credentials="*" not in origins,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # ============================================================
    # Helpers
    # ============================================================

    def _ts() -> float:
        return time.time()

    def _make_ws_message(msg_type: str, payload: dict) -> dict:
        """Wrap *payload* in the ``{type, payload, timestamp}`` WebSocket frame."""
        return {"type": msg_type, "payload": payload, "timestamp": _ts()}

    async def _broadcast_emotion_change(old: Any, new: Any) -> None:
        """Push an ``emotion_change`` frame to every WebSocket client."""
        await state.ws_manager.broadcast(
            _make_ws_message("emotion_change", {"old": old, "new": new})
        )

    # ============================================================
    # REST routes
    # NOTE: route docstrings are kept verbatim because FastAPI publishes
    # them as OpenAPI descriptions.
    # ============================================================

    @app.get("/health")
    async def health():
        """健康检查。"""
        # Liveness probe: status, server time and open WebSocket count.
        return {
            "status": "ok",
            "timestamp": _ts(),
            "connections": state.ws_manager.connection_count,
        }

    @app.post("/chat", response_model=ChatResponse)
    async def chat(req: ChatRequest) -> ChatResponse:
        """
        发送消息给 Agent,获取回复。

        设计文档:时序图 Step 1(用户输入) → Step 2(返回响应)。
        """
        # Send the user's message to the agent and return its reply.
        start = _ts()
        brain = state.brain  # raises HTTP 503 when no brain is linked
        emotion = state.emotion

        # 1. Emotion state passed to the agent: the client's value wins,
        #    otherwise the engine's current state, otherwise "idle".
        emotion_state = req.emotion_state or (
            emotion.get_state() if emotion is not None else "idle"
        )

        # 2. Ask the agent.
        result = await brain.think(
            user_input=req.message,
            emotion_state=emotion_state,
            context=req.context,
        )

        # 3. Apply an emotion transition if the agent requested one.
        if emotion is not None and result.get("emotion_trigger"):
            emotion.update(result["emotion_trigger"])
            await _broadcast_emotion_change(emotion_state, emotion.get_state())

        # 4. Push any proactive action over WebSocket.
        if result.get("action"):
            await state.ws_manager.broadcast(_make_ws_message("action", result["action"]))

        # 5. Build the response. If the agent omitted "emotion_state", fall back
        #    to the engine's view (or the state we sent) instead of failing
        #    the request with a KeyError.
        final_state = result.get("emotion_state")
        if final_state is None:
            final_state = emotion.get_state() if emotion is not None else emotion_state
        latency = (_ts() - start) * 1000
        return ChatResponse(
            text=result["text"],
            emotion_state=final_state,
            action=result.get("action"),
            memory_id=result.get("memory_id"),
            latency_ms=round(latency, 1),
        )

    @app.get("/state")
    async def get_state():
        """
        获取 Agent 完整状态(用于前端同步)。

        返回:情绪状态、记忆数量、调度器状态、活跃度等。
        """
        # Full snapshot for frontend sync. Each missing module is reported
        # as None rather than failing the request.
        emotion = state.emotion
        memory = state.memory
        scheduler = state.scheduler
        # Read the slot directly: the ``state.brain`` property raises HTTP 503
        # when no brain is linked. That would make this endpoint (and the
        # WebSocket "get_state" request) fail instead of reporting None.
        brain = state._agent_brain

        return {
            "timestamp": _ts(),
            "emotion": {
                "state": emotion.get_state(),
                "display_name": emotion.get_display_name(),
                "transition_probs": emotion.get_transition_probabilities(),
            } if emotion is not None else None,
            "memory": {
                "count": memory._store.count(),
            } if memory is not None else None,
            "scheduler": scheduler.get_status() if scheduler is not None else None,
            "brain": brain.get_status() if brain is not None else None,
        }

    @app.post("/emotion/update")
    async def update_emotion(req: EmotionUpdateRequest):
        """触发情绪状态转移。"""
        # Apply an event (or force a state) and report old/new states.
        emotion = state.emotion
        if emotion is None:
            raise HTTPException(503, "情绪引擎未初始化")
        old_state = emotion.get_state()
        if req.force_state:
            emotion.force_state(req.force_state)
        else:
            emotion.update(req.event)
        new_state = emotion.get_state()
        # Keep WebSocket clients in sync, as /chat does for agent-triggered
        # transitions.
        await _broadcast_emotion_change(old_state, new_state)
        return {
            "old": old_state,
            "new": new_state,
            "triggered_by": req.force_state or req.event,
        }

    @app.get("/emotion/state")
    async def get_emotion_state():
        """获取当前情绪状态。"""
        emotion = state.emotion
        if emotion is None:
            raise HTTPException(503, "情绪引擎未初始化")
        return {
            "state": emotion.get_state(),
            "display_name": emotion.get_display_name(),
        }

    @app.post("/memory/add")
    async def add_memory(req: MemoryAddRequest):
        """手动添加记忆。"""
        if state.memory is None:
            raise HTTPException(503, "记忆系统未初始化")
        memory_id = await state.memory.add(
            text=req.text,
            tags=req.tags,
            metadata=req.metadata,
        )
        return {"id": memory_id, "text": req.text}

    @app.post("/memory/search")
    async def search_memory(req: MemorySearchRequest):
        """语义检索记忆。"""
        if state.memory is None:
            raise HTTPException(503, "记忆系统未初始化")
        results = await state.memory._search_async(
            req.query, top_k=req.top_k, min_similarity=req.min_similarity
        )
        return {"results": results, "count": len(results)}

    @app.get("/memory/all")
    async def get_all_memories(limit: int = 50):
        """获取所有记忆(倒序)。"""
        if state.memory is None:
            raise HTTPException(503, "记忆系统未初始化")
        store = state.memory._store
        # A negative limit would slice from the end ("all but the last N");
        # clamp it so the parameter always means "at most N entries".
        entries = store.get_all()[:max(limit, 0)]
        return {
            "memories": [
                {
                    "id": e.id,
                    "text": e.text,
                    "tags": e.tags,
                    "created_at": e.created_at,
                }
                for e in entries
            ],
            "total": store.count(),
        }

    @app.delete("/memory/{memory_id}")
    async def delete_memory(memory_id: str):
        """删除指定记忆。"""
        if state.memory is None:
            raise HTTPException(503, "记忆系统未初始化")
        if not state.memory.delete(memory_id):
            raise HTTPException(404, "记忆不存在")
        return {"deleted": memory_id}

    # ============================================================
    # WebSocket route
    # ============================================================

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """
        Real-time WebSocket endpoint; the frontend receives proactive pushes here.

        Server → client message types
        -----------------------------
        connected      : sent once after the handshake
        action         : proactive behavior (reminder, chat trigger)
        emotion_change : emotion state changed
        heartbeat      : keep-alive (every 30 s)
        pong / state_sync : replies to client "ping" / "get_state"
        """
        await state.ws_manager.connect(websocket)
        # Bound before ``try`` so ``finally`` can always reference it, even if
        # sending the initial "connected" frame fails.
        hb_task: asyncio.Task | None = None

        async def heartbeat() -> None:
            """Send a keep-alive frame every 30 s until a send fails."""
            while True:
                await asyncio.sleep(30)
                try:
                    await websocket.send_json(_make_ws_message(
                        "heartbeat",
                        {"t": time.time()},
                    ))
                except Exception:
                    break

        try:
            await websocket.send_json(_make_ws_message(
                "connected",
                {"message": "EzVibe 已连接!"},
            ))
            hb_task = asyncio.create_task(heartbeat())

            # Handle client messages (mainly state-sync requests).
            while True:
                try:
                    data = await websocket.receive_json()
                except json.JSONDecodeError as exc:
                    # The frame has already been consumed; one malformed
                    # message should not drop the connection.
                    logger.warning("[WS] Ignoring malformed JSON: %s", exc)
                    continue
                if not isinstance(data, dict):
                    logger.warning("[WS] Ignoring non-object message: %r", data)
                    continue

                msg_type = data.get("type", "")
                if msg_type == "ping":
                    await websocket.send_json(_make_ws_message(
                        "pong",
                        {"t": data.get("t")},
                    ))
                elif msg_type == "get_state":
                    # The frontend asks for a full state snapshot.
                    resp = await get_state()
                    await websocket.send_json(_make_ws_message("state_sync", resp))
                elif msg_type == "dismiss_action":
                    # The frontend dismissed a proactive action.
                    logger.info("[WS] Action dismissed: %s", data.get("action_id"))
        except WebSocketDisconnect:
            pass  # normal client disconnect; nothing to report
        except Exception as exc:
            logger.warning("[WS] Receive error: %s", exc)
        finally:
            if hb_task is not None:
                hb_task.cancel()
            state.ws_manager.disconnect(websocket)

    # ============================================================
    # SSE streaming endpoint (alternative, for simple frontends)
    # ============================================================

    @app.get("/events")
    async def sse_events():
        """
        Server-Sent Events 流。

        替代 WebSocket 的轻量方案,适合 EventSource API。
        """
        # NOTE: this stream currently only emits a "heartbeat" event every
        # 5 s. Actions and emotion changes are pushed over /ws only.

        async def event_generator():
            while True:
                await asyncio.sleep(5)
                yield {
                    "event": "heartbeat",
                    "data": json.dumps({"t": time.time()}),
                }

        return EventSourceResponse(event_generator())

    return app
|
||
|
||
|
||
# ================================================================
|
||
# 独立运行入口
|
||
# ================================================================
|
||
|
||
async def run_server(
    host: str = "127.0.0.1",
    port: int = 8765,
    brain: Any = None,
    emotion: Any = None,
    memory: Any = None,
    scheduler: Any = None,
) -> None:
    """
    Start the API server in the current event loop (standalone mode).

    Parameters
    ----------
    host, port : address the API listens on.
    brain, emotion, memory, scheduler : core module instances. Any left as
        None is simply not linked; routes that need it answer HTTP 503.

    When a scheduler is given, the proactive push loop runs as a background
    task while the server is serving and is stopped when serving ends.
    """
    import uvicorn

    # Attach whichever core modules were provided.
    for module, link in (
        (brain, _state.link_agent),
        (emotion, _state.link_emotion),
        (memory, _state.link_memory),
        (scheduler, _state.link_scheduler),
    ):
        if module is not None:
            link(module)

    # Build the app first, so a missing dependency (ImportError) cannot leave
    # an orphaned background task behind.
    app = create_app(_state)

    if scheduler is not None:
        _state._running = True
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can be garbage-collected mid-run.
        _state._scheduler_task = asyncio.create_task(_proactive_loop(_state))

    config = uvicorn.Config(
        app,
        host=host,
        port=port,
        log_level="info",
    )
    server = uvicorn.Server(config)
    logger.info("[API] EzVibe API Server starting on http://%s:%d", host, port)
    try:
        await server.serve()
    finally:
        # Stop the proactive loop so it does not outlive the server.
        _state._running = False
        task = _state._scheduler_task
        if task is not None:
            task.cancel()
            _state._scheduler_task = None
|
||
|
||
|
||
async def _proactive_loop(state: AppState, *, interval: float = 10.0) -> None:
    """
    Background loop: poll the scheduler and push triggered actions.

    Every *interval* seconds (10 by default) it awaits
    ``state.scheduler.check_and_trigger()``. Each returned action is broadcast
    to all WebSocket clients as an ``action`` frame. The loop runs until
    ``state._running`` becomes False or the task is cancelled. Scheduler
    errors are logged (with traceback) and polling continues.
    """
    while state._running:
        await asyncio.sleep(interval)
        # The flag may have been cleared while we were sleeping.
        if not state._running:
            break
        scheduler = state.scheduler
        if scheduler is None:
            continue
        try:
            # A None result means "nothing triggered".
            triggered = await scheduler.check_and_trigger() or []
            for action in triggered:
                await state.ws_manager.broadcast({
                    "type": "action",
                    "payload": action,
                    "timestamp": time.time(),
                })
                # Tolerate non-dict actions instead of aborting the batch.
                action_type = action.get("type") if isinstance(action, dict) else None
                logger.info("[Proactive] Action triggered: %s", action_type)
        except Exception as exc:
            logger.warning("[Proactive] Scheduler error: %s", exc, exc_info=True)
|
||
|
||
|
||
# ================================================================
|
||
# 入口点(python -m api.server)
|
||
# ================================================================
|
||
|
||
if __name__ == "__main__":
    # Entry point for `python -m api.server`: configure logging, verify the
    # optional web stack is present, then serve with no core modules linked.
    import sys

    logging.basicConfig(
        format="%(asctime)s %(name)s %(levelname)s %(message)s",
        level=logging.INFO,
    )
    if _HAS_FASTAPI:
        print("启动 EzVibe API Server(独立模式)...")
        print("API 文档: http://127.0.0.1:8765/docs")
        asyncio.run(run_server())
    else:
        print("ERROR: FastAPI 未安装。")
        print("运行: pip install fastapi 'uvicorn[standard]' sse-starlette starlette")
        sys.exit(1)
|