109 lines
3.4 KiB
Python
109 lines
3.4 KiB
Python
|
|
"""
|
|||
|
|
会话记忆管理器
|
|||
|
|
|
|||
|
|
管理每个用户的对话上下文,支持多轮对话状态机。
|
|||
|
|
内存缓存,10 分钟 TTL 自动清除。
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import time
|
|||
|
|
from typing import Dict, List, Optional
|
|||
|
|
|
|||
|
|
from app.utils.logger import logger
|
|||
|
|
|
|||
|
|
SESSION_TTL = 600 # 10 分钟
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UserSession:
|
|||
|
|
"""单个用户的会话上下文"""
|
|||
|
|
|
|||
|
|
def __init__(self, user_id: str):
|
|||
|
|
self.user_id = user_id
|
|||
|
|
self.state = "idle" # idle / waiting_location / waiting_confirm / waiting_close_photo
|
|||
|
|
self.pending_image_url = "" # 暂存的图片 COS key
|
|||
|
|
self.pending_analysis = "" # VLM 分析结果描述
|
|||
|
|
self.pending_alarm_type = "" # VLM 识别的告警类型
|
|||
|
|
self.pending_order_id = "" # 待结单的工单 ID
|
|||
|
|
self.pending_alarm_id = "" # 关联的告警 ID
|
|||
|
|
self.history: List[Dict] = [] # 对话历史 [{"role": "user/assistant", "content": ...}]
|
|||
|
|
self.updated_at = time.time()
|
|||
|
|
|
|||
|
|
def is_expired(self) -> bool:
|
|||
|
|
return time.time() - self.updated_at > SESSION_TTL
|
|||
|
|
|
|||
|
|
def touch(self):
|
|||
|
|
self.updated_at = time.time()
|
|||
|
|
|
|||
|
|
def reset(self):
|
|||
|
|
"""重置状态机(保留 history)"""
|
|||
|
|
self.state = "idle"
|
|||
|
|
self.pending_image_url = ""
|
|||
|
|
self.pending_analysis = ""
|
|||
|
|
self.pending_alarm_type = ""
|
|||
|
|
self.pending_order_id = ""
|
|||
|
|
self.pending_alarm_id = ""
|
|||
|
|
|
|||
|
|
def add_history(self, role: str, content):
|
|||
|
|
"""添加对话记录,content 可以是 str 或 list(多模态)"""
|
|||
|
|
self.history.append({"role": role, "content": content})
|
|||
|
|
# 保留最近 10 轮(20 条消息)
|
|||
|
|
if len(self.history) > 20:
|
|||
|
|
self.history = self.history[-20:]
|
|||
|
|
self.touch()
|
|||
|
|
|
|||
|
|
def get_history_for_vlm(self) -> List[Dict]:
|
|||
|
|
"""获取用于 VLM 调用的 history(过滤掉过大的图片内容)"""
|
|||
|
|
return self.history.copy()
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SessionManager:
|
|||
|
|
"""全局会话管理器(单例)"""
|
|||
|
|
|
|||
|
|
def __init__(self):
|
|||
|
|
self._sessions: Dict[str, UserSession] = {}
|
|||
|
|
self._last_cleanup = time.time()
|
|||
|
|
|
|||
|
|
def get(self, user_id: str) -> UserSession:
|
|||
|
|
"""获取或创建用户会话"""
|
|||
|
|
# 定期清理过期会话(每 5 分钟)
|
|||
|
|
if time.time() - self._last_cleanup > 300:
|
|||
|
|
self._cleanup_expired()
|
|||
|
|
|
|||
|
|
session = self._sessions.get(user_id)
|
|||
|
|
if session and not session.is_expired():
|
|||
|
|
session.touch()
|
|||
|
|
return session
|
|||
|
|
|
|||
|
|
# 过期或不存在,创建新会话
|
|||
|
|
session = UserSession(user_id)
|
|||
|
|
self._sessions[user_id] = session
|
|||
|
|
return session
|
|||
|
|
|
|||
|
|
def clear(self, user_id: str):
|
|||
|
|
"""清除用户会话"""
|
|||
|
|
if user_id in self._sessions:
|
|||
|
|
self._sessions[user_id].reset()
|
|||
|
|
self._sessions[user_id].history.clear()
|
|||
|
|
|
|||
|
|
def _cleanup_expired(self):
|
|||
|
|
"""清理过期会话"""
|
|||
|
|
expired = [uid for uid, s in self._sessions.items() if s.is_expired()]
|
|||
|
|
for uid in expired:
|
|||
|
|
del self._sessions[uid]
|
|||
|
|
if expired:
|
|||
|
|
logger.debug(f"清理 {len(expired)} 个过期会话")
|
|||
|
|
self._last_cleanup = time.time()
|
|||
|
|
|
|||
|
|
@property
|
|||
|
|
def active_count(self) -> int:
|
|||
|
|
return sum(1 for s in self._sessions.values() if not s.is_expired())
|
|||
|
|
|
|||
|
|
|
|||
|
|
_session_manager: Optional[SessionManager] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_session_manager() -> SessionManager:
|
|||
|
|
global _session_manager
|
|||
|
|
if _session_manager is None:
|
|||
|
|
_session_manager = SessionManager()
|
|||
|
|
return _session_manager
|