Files
EzVibe/api/server.py
e2hang 2a844e83a8 Initial commit: EzVibe AI 桌宠系统
- EmotionEngine: 5状态马尔可夫情绪机 + 蒙特卡洛转移
- VectorMemory: TF-IDF向量记忆 + SQLite持久化 + RAG检索
- AgentBrain: Ollama/OpenAI/Dummy三后端LLM
- BehaviorScheduler: 优先级/冷却/活跃度调度
- FastAPI服务器 + WebSocket实时推送
- perception: 键鼠监控 + 屏幕截图
- ui/pet_window: PySide6桌宠窗口 + 像素动画
- assets/pet: 5情绪各2帧像素艺术资源
2026-05-01 23:26:43 +08:00

614 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
EzVibe API Server
=================
设计文档对应章节:前后端通信接口 + 系统时序交互图
核心职责
• HTTP REST 接口:聊天、状态查询、记忆管理
• WebSocket 实时推送:主动行为通知、情绪状态变化
• 提供 Agent 与 Qt 前端的异步非阻塞通信通道
设计决策
• 开发模式FastAPI 内置 uvicorn 直接运行python -m api.server
• 生产模式uvicorn api.server:app --host 0.0.0.0 --port 8765
• 前端通过 HTTP POST /chat 发送消息
• 主动行为通过 WebSocket /ws 推送pet → 用户)
通信协议
• REST: JSON over HTTPS
• WebSocket: JSON 消息帧,格式 {type, payload, timestamp}
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any
import numpy as np
logger = logging.getLogger(__name__)
# ================================================================
# 尝试导入 FastAPI (可选依赖)
# ================================================================
try:
    from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Request
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.responses import JSONResponse
    from pydantic import BaseModel, Field
    from sse_starlette.sse import EventSourceResponse
    import starlette.responses
    _HAS_FASTAPI = True
except ImportError:
    # Any missing web dependency (including sse-starlette) disables the API.
    # The stand-ins below keep this module importable: the request models
    # call Field(...) in their class bodies at import time, and AppState.brain
    # raises HTTPException.  create_app() then reports the missing packages.
    _HAS_FASTAPI = False
    FastAPI = None  # type: ignore[misc, assignment]
    BaseModel = object  # type: ignore[misc, assignment]

    def Field(  # type: ignore[no-redef]
        default: Any = None,
        *,
        default_factory: Any = None,
        **_kwargs: Any,
    ) -> Any:
        """Fallback for pydantic.Field: return the plain default value.

        ``default_factory`` (when given) is called to produce the value;
        validation keywords such as ``min_length`` are ignored.
        """
        if default_factory is not None:
            return default_factory()
        return default

    class HTTPException(Exception):  # type: ignore[no-redef]
        """Fallback for fastapi.HTTPException carrying status code and detail."""

        def __init__(self, status_code: int, detail: Any = None) -> None:
            super().__init__(status_code, detail)
            self.status_code = status_code
            self.detail = detail
# ================================================================
# API 模型 (Pydantic)
# ================================================================
class ChatRequest(BaseModel):
    """Body of ``POST /chat``: the user's message plus optional hints."""

    # Required message text, 1 to 4000 characters.
    message: str = Field(default=..., min_length=1, max_length=4000)
    # Client-supplied emotion state; when omitted, /chat asks the emotion engine.
    emotion_state: str | None = Field(None, description="当前情绪状态")
    # Extra context passed through to the agent brain.
    context: dict[str, Any] | None = Field(None, description="额外上下文")
class ChatResponse(BaseModel):
    """Response body of ``POST /chat``."""
    # Reply text taken from the agent brain's result.
    text: str
    # Emotion state reported in the agent brain's result.
    emotion_state: str
    # Proactive action from the brain's result, if any.
    action: dict | None = None
    # The "memory_id" value from the brain's result, if present.
    memory_id: str | None = None
    # Handler latency in milliseconds, rounded to one decimal place.
    latency_ms: float
class MemoryAddRequest(BaseModel):
    """Body of ``POST /memory/add``: one memory entry to store."""

    # Memory text; must be non-empty.
    text: str = Field(default=..., min_length=1)
    # Optional tags; every request gets its own fresh list.
    tags: list[str] = Field(default_factory=list)
    # Optional free-form metadata handed to the memory store.
    metadata: dict[str, Any] = Field(default_factory=dict)
class MemorySearchRequest(BaseModel):
    """Body of ``POST /memory/search``: a semantic retrieval query."""

    # Query text; must be non-empty.
    query: str = Field(default=..., min_length=1)
    # Maximum number of results, between 1 and 50 (default 5).
    top_k: int = Field(5, ge=1, le=50)
    # Similarity threshold forwarded to the search, within [0, 1] (default 0.0).
    min_similarity: float = Field(0.0, ge=0.0, le=1.0)
class EmotionUpdateRequest(BaseModel):
    """Body of ``POST /emotion/update``."""

    # Event name passed to the emotion engine's update(); still required
    # when force_state is given.
    event: str = Field(..., description="事件名称(如 user_praise)")
    # When set, the engine is forced into this state and `event` is not applied.
    force_state: str | None = Field(default=None, description="强制状态")
# ================================================================
# WebSocket 连接管理器
# ================================================================
class ConnectionManager:
    """
    Registry of live WebSocket connections.

    Supports broadcasting a message to every registered connection and
    sending to a single connection.  There is no tag-based filtering.
    A connection whose send fails is removed from the registry.

    Message format
    --------------
    Every message is a JSON object::

        {
            "type": "event_type",   # e.g. "action", "emotion_change", "heartbeat"
            "payload": {...},       # event data
            "timestamp": 1698765432.123,
        }
    """

    def __init__(self) -> None:
        self._connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake for *websocket* and register it."""
        await websocket.accept()
        self._connections.append(websocket)
        logger.info("[WS] Connected: total=%d", len(self._connections))

    def disconnect(self, websocket: WebSocket) -> None:
        """
        Unregister *websocket*; a no-op if it is not registered.

        May be called more than once for the same connection (a failed send
        and the endpoint's cleanup both call it), so only an actual removal
        is logged.
        """
        if websocket in self._connections:
            self._connections.remove(websocket)
            logger.info("[WS] Disconnected: total=%d", len(self._connections))

    async def send_json(self, websocket: WebSocket, data: dict) -> bool:
        """
        Send *data* as JSON to one connection.

        Returns True on success.  On any exception the failure is logged,
        the connection is unregistered, and False is returned.
        """
        try:
            await websocket.send_json(data)
        except Exception as exc:
            logger.warning("[WS] Send failed: %s", exc)
            self.disconnect(websocket)
            return False
        return True

    async def broadcast(self, data: dict) -> int:
        """
        Send *data* to every registered connection, one after another.

        Iterates over a snapshot because failed sends unregister connections
        during the loop.  Returns the number of successful sends.
        """
        delivered = 0
        for conn in list(self._connections):
            if await self.send_json(conn, data):
                delivered += 1
        return delivered

    @property
    def connection_count(self) -> int:
        """Number of currently registered connections."""
        return len(self._connections)
# ================================================================
# 全局状态 (create_app 构建的路由通过闭包访问)
# ================================================================
class AppState:
    """
    Process-wide state shared by the API routes.

    Holds the core modules (agent brain, emotion engine, memory, scheduler)
    and the WebSocket ConnectionManager.  Routes built by ``create_app()``
    reach the instance through a closure; there is no ``Depends()``
    injection.  Core modules are attached with the ``link_*`` methods, as
    ``run_server()`` does.  Every module slot starts as ``None``.
    """

    def __init__(self) -> None:
        self._agent_brain: Any = None
        self._emotion_engine: Any = None
        self._memory: Any = None
        self._scheduler: Any = None
        self._manager = ConnectionManager()
        # Slot for the handle of the background proactive-push task.
        self._scheduler_task: asyncio.Task | None = None
        # Run flag checked by the proactive-push loop.
        self._running = False

    def link_agent(self, brain: Any) -> None:
        """Attach the agent brain."""
        self._agent_brain = brain

    def link_emotion(self, emotion: Any) -> None:
        """Attach the emotion engine."""
        self._emotion_engine = emotion

    def link_memory(self, memory: Any) -> None:
        """Attach the memory store."""
        self._memory = memory

    def link_scheduler(self, scheduler: Any) -> None:
        """Attach the behavior scheduler."""
        self._scheduler = scheduler

    @property
    def brain(self) -> Any:
        """
        The linked agent brain.

        Raises HTTPException(503) when no brain has been linked, so routes
        that require it answer "service unavailable".
        """
        if self._agent_brain is None:
            raise HTTPException(503, "Agent 未初始化,请先启动应用")
        return self._agent_brain

    @property
    def emotion(self) -> Any:
        """The linked emotion engine, or None."""
        return self._emotion_engine

    @property
    def memory(self) -> Any:
        """The linked memory store, or None."""
        return self._memory

    @property
    def scheduler(self) -> Any:
        """The linked scheduler, or None."""
        return self._scheduler

    @property
    def ws_manager(self) -> ConnectionManager:
        """The WebSocket connection manager owned by this state."""
        return self._manager
# Module-level singleton: the default ``state`` of create_app() and the
# instance that run_server() links the core modules into.
_state = AppState()
# ================================================================
# FastAPI 应用构建器
# ================================================================
def create_app(state: AppState = _state) -> Any:
    """
    Build the FastAPI application whose routes operate on *state*.

    Every route closes over *state* (no ``Depends`` injection).  Because
    *state* has a default, this also works as a uvicorn app factory:
    ``uvicorn --factory api.server:create_app``.

    Raises
    ------
    ImportError
        If FastAPI or one of its companion packages could not be imported.
    """
    if not _HAS_FASTAPI:
        raise ImportError(
            "FastAPI 未安装。运行: pip install fastapi 'uvicorn[standard]' "
            "sse-starlette starlette"
        )

    app = FastAPI(
        title="EzVibe API",
        description="EzVibe AI 桌宠后端 API",
        version="0.1.0",
    )

    # CORS, so browser-based frontends (Qt WebEngine / localhost pages) can
    # call the API.
    # NOTE(review): the "*" origin is meant for local development only; it
    # lets any web page open in the user's browser call this server and read
    # responses such as /memory/all.  Restrict the origins outside development.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # ============================================================
    # Helpers
    # ============================================================
    def _ts() -> float:
        """Current Unix time in seconds."""
        return time.time()

    def _make_ws_message(msg_type: str, payload: dict) -> dict:
        """Wrap *payload* in the ``{type, payload, timestamp}`` WS frame."""
        return {"type": msg_type, "payload": payload, "timestamp": _ts()}

    # ============================================================
    # REST routes
    # ============================================================
    @app.get("/health")
    async def health():
        """Health check: status, server time and WebSocket connection count."""
        return {
            "status": "ok",
            "timestamp": _ts(),
            "connections": state.ws_manager.connection_count,
        }

    @app.post("/chat", response_model=ChatResponse)
    async def chat(req: ChatRequest) -> ChatResponse:
        """
        Send a message to the agent and return its reply.

        Design doc sequence diagram: step 1 (user input) -> step 2 (response).
        Answers 503 (raised by ``state.brain``) when no agent is linked.
        """
        start = _ts()
        brain = state.brain
        emotion = state.emotion

        # 1. Emotion state handed to the brain: client-supplied, else the engine's.
        emotion_state = req.emotion_state or (
            emotion.get_state() if emotion else "idle"
        )

        # 2. Ask the agent brain.
        result = await brain.think(
            user_input=req.message,
            emotion_state=emotion_state,
            context=req.context,
        )

        # 3. Apply the emotion trigger, if any, and notify WebSocket clients.
        #    "old" is the engine's own state before the transition, not the
        #    (possibly client-supplied) state passed to the brain.
        trigger = result.get("emotion_trigger")
        if emotion and trigger:
            old_state = emotion.get_state()
            emotion.update(trigger)
            await state.ws_manager.broadcast(_make_ws_message(
                "emotion_change",
                {"old": old_state, "new": emotion.get_state()},
            ))

        # 4. Push any proactive action to WebSocket clients.
        action = result.get("action")
        if action:
            await state.ws_manager.broadcast(_make_ws_message("action", action))

        # 5. Build the response.
        latency = (_ts() - start) * 1000
        return ChatResponse(
            text=result["text"],
            emotion_state=result["emotion_state"],
            action=action,
            memory_id=result.get("memory_id"),
            latency_ms=round(latency, 1),
        )

    @app.get("/state")
    async def get_state():
        """
        Return a snapshot of the agent for frontend sync.

        Sections for modules that are not linked are ``None``; this endpoint
        does not require a linked brain.
        """
        emotion = state.emotion
        memory = state.memory
        scheduler = state.scheduler
        # Read the slot directly: the ``brain`` property raises 503 when
        # unset, which would fail this endpoint (and the WS "get_state"
        # message) despite the None-guard below.
        brain = state._agent_brain
        return {
            "timestamp": _ts(),
            "emotion": {
                "state": emotion.get_state(),
                "display_name": emotion.get_display_name(),
                "transition_probs": emotion.get_transition_probabilities(),
            } if emotion else None,
            "memory": {
                "count": memory._store.count(),
            } if memory else None,
            "scheduler": scheduler.get_status() if scheduler else None,
            "brain": brain.get_status() if brain else None,
        }

    @app.post("/emotion/update")
    async def update_emotion(req: EmotionUpdateRequest):
        """Apply an emotion event (or force a state) and report old/new states."""
        engine = state.emotion
        if not engine:
            raise HTTPException(503, "情绪引擎未初始化")
        old_state = engine.get_state()
        if req.force_state:
            engine.force_state(req.force_state)
        else:
            engine.update(req.event)
        return {
            "old": old_state,
            "new": engine.get_state(),
            "triggered_by": req.force_state or req.event,
        }

    @app.get("/emotion/state")
    async def get_emotion_state():
        """Return the current emotion state and its display name."""
        engine = state.emotion
        if not engine:
            raise HTTPException(503, "情绪引擎未初始化")
        return {
            "state": engine.get_state(),
            "display_name": engine.get_display_name(),
        }

    @app.post("/memory/add")
    async def add_memory(req: MemoryAddRequest):
        """Store a memory entry manually and return its id."""
        if not state.memory:
            raise HTTPException(503, "记忆系统未初始化")
        memory_id = await state.memory.add(
            text=req.text,
            tags=req.tags,
            metadata=req.metadata,
        )
        return {"id": memory_id, "text": req.text}

    @app.post("/memory/search")
    async def search_memory(req: MemorySearchRequest):
        """Semantic search over stored memories."""
        if not state.memory:
            raise HTTPException(503, "记忆系统未初始化")
        results = await state.memory._search_async(
            req.query, top_k=req.top_k, min_similarity=req.min_similarity
        )
        return {"results": results, "count": len(results)}

    @app.get("/memory/all")
    async def get_all_memories(limit: int = 50):
        """
        Return up to *limit* memories plus the total count.

        Order is whatever ``_store.get_all()`` yields (presumably newest
        first -- confirm against the store).
        """
        if not state.memory:
            raise HTTPException(503, "记忆系统未初始化")
        # Clamp: a negative limit would otherwise slice from the end
        # ("all but the last N").
        entries = state.memory._store.get_all()[:max(limit, 0)]
        return {
            "memories": [
                {
                    "id": e.id,
                    "text": e.text,
                    "tags": e.tags,
                    "created_at": e.created_at,
                }
                for e in entries
            ],
            "total": state.memory._store.count(),
        }

    @app.delete("/memory/{memory_id}")
    async def delete_memory(memory_id: str):
        """Delete one memory; 404 if it does not exist."""
        if not state.memory:
            raise HTTPException(503, "记忆系统未初始化")
        deleted = state.memory.delete(memory_id)
        if not deleted:
            raise HTTPException(404, "记忆不存在")
        return {"deleted": memory_id}

    # ============================================================
    # WebSocket route
    # ============================================================
    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket):
        """
        Real-time channel used to push proactive behaviour to the frontend.

        Server -> client message types:
            connected      : sent once right after the handshake
            action         : proactive behaviour (reminder, chat trigger)
            emotion_change : emotion state changed
            heartbeat      : keep-alive, every 30 s
            pong           : reply to a client "ping"
            state_sync     : reply to "get_state" (same body as GET /state)

        Client -> server: JSON objects with "type" in
        {"ping", "get_state", "dismiss_action"}; other frames are ignored.
        """
        await state.ws_manager.connect(websocket)
        # Initialised before the try so the finally clause is safe even when
        # sending the greeting fails before the task is created.
        hb_task: asyncio.Task | None = None
        try:
            await websocket.send_json(_make_ws_message(
                "connected",
                {"message": "EzVibe 已连接!"},
            ))

            async def heartbeat():
                """Send a heartbeat frame every 30 s until a send fails."""
                while True:
                    await asyncio.sleep(30)
                    try:
                        await websocket.send_json(_make_ws_message(
                            "heartbeat",
                            {"t": time.time()},
                        ))
                    except Exception:
                        break

            hb_task = asyncio.create_task(heartbeat())

            # Serve client messages until the client leaves or an error occurs.
            while True:
                try:
                    data = await websocket.receive_json()
                    # Valid JSON that is not an object carries no "type": skip it.
                    if not isinstance(data, dict):
                        continue
                    msg_type = data.get("type", "")
                    if msg_type == "ping":
                        await websocket.send_json(_make_ws_message(
                            "pong",
                            {"t": data.get("t")},
                        ))
                    elif msg_type == "get_state":
                        resp = await get_state()
                        await websocket.send_json(_make_ws_message(
                            "state_sync",
                            resp,
                        ))
                    elif msg_type == "dismiss_action":
                        action_id = data.get("action_id")
                        logger.info("[WS] Action dismissed: %s", action_id)
                except WebSocketDisconnect:
                    # Normal client close: not a warning-worthy event.
                    break
                except Exception as exc:
                    logger.warning("[WS] Receive error: %s", exc)
                    break
        except WebSocketDisconnect:
            pass
        finally:
            if hb_task is not None:
                hb_task.cancel()
            state.ws_manager.disconnect(websocket)

    # ============================================================
    # SSE stream (alternative for simple frontends)
    # ============================================================
    @app.get("/events")
    async def sse_events():
        """
        Server-Sent Events stream, a lightweight alternative to /ws for the
        browser EventSource API.  Currently emits a heartbeat every 5 s.
        """
        async def event_generator():
            while True:
                await asyncio.sleep(5)
                yield {
                    "event": "heartbeat",
                    "data": json.dumps({"t": time.time()}),
                }

        return EventSourceResponse(event_generator())

    return app
# ================================================================
# 独立运行入口
# ================================================================
async def run_server(
    host: str = "127.0.0.1",
    port: int = 8765,
    brain: Any = None,
    emotion: Any = None,
    memory: Any = None,
    scheduler: Any = None,
) -> None:
    """
    Start the API server in the current event loop and serve until it exits.

    Parameters
    ----------
    host, port : address the API listens on.
    brain, emotion, memory, scheduler : core module instances to link into
        the global state; arguments left as None are not linked (with all
        None only an empty API starts, and /chat answers 503).

    When a scheduler is given, the proactive-push loop runs as a background
    task for the lifetime of the server and is cancelled on shutdown.
    """
    import uvicorn

    # Link the provided core modules into the global state.
    if brain is not None:
        _state.link_agent(brain)
    if emotion is not None:
        _state.link_emotion(emotion)
    if memory is not None:
        _state.link_memory(memory)
    if scheduler is not None:
        _state.link_scheduler(scheduler)

    app = create_app(_state)

    # Start the proactive push loop.  Keep the task handle: the event loop
    # only holds weak references to tasks, so an unreferenced task can be
    # garbage-collected mid-execution.
    if scheduler is not None:
        _state._running = True
        _state._scheduler_task = asyncio.create_task(_proactive_loop(_state))

    config = uvicorn.Config(
        app,
        host=host,
        port=port,
        log_level="info",
    )
    server = uvicorn.Server(config)
    logger.info("[API] EzVibe API Server starting on http://%s:%d", host, port)
    try:
        await server.serve()
    finally:
        # Stop the proactive loop together with the server.
        _state._running = False
        task = _state._scheduler_task
        if task is not None:
            task.cancel()
            _state._scheduler_task = None
async def _proactive_loop(state: AppState, interval: float = 10.0) -> None:
    """
    Background loop: poll the scheduler and push triggered actions.

    Every *interval* seconds (default 10) it awaits
    ``state.scheduler.check_and_trigger()`` and broadcasts each returned
    action to all WebSocket clients as an ``"action"`` frame.  It runs while
    ``state._running`` is true.  Any exception raised while checking or
    pushing is logged, and polling continues on the next tick.

    Parameters
    ----------
    state : application state providing the scheduler and WS manager.
    interval : seconds to sleep between polls.
    """
    while state._running:
        await asyncio.sleep(interval)
        # The run flag may have been cleared while sleeping.
        if not state._running:
            break
        scheduler = state.scheduler
        if scheduler is None:
            continue
        try:
            triggered = await scheduler.check_and_trigger()
            for action in triggered:
                await state.ws_manager.broadcast({
                    "type": "action",
                    "payload": action,
                    "timestamp": time.time(),
                })
                logger.info("[Proactive] Action triggered: %s", action.get("type"))
        except Exception as exc:
            logger.warning("[Proactive] Scheduler error: %s", exc)
# ================================================================
# 入口点: python -m api.server
# ================================================================
if __name__ == "__main__":
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(name)s %(levelname)s %(message)s",
)
if not _HAS_FASTAPI:
print("ERROR: FastAPI 未安装。")
print("运行: pip install fastapi 'uvicorn[standard]' sse-starlette starlette")
import sys
sys.exit(1)
print("启动 EzVibe API Server独立模式...")
print("API 文档: http://127.0.0.1:8765/docs")
asyncio.run(run_server())