""" 会话记忆管理器 管理每个用户的对话上下文,支持多轮对话状态机。 内存缓存,10 分钟 TTL 自动清除。 """ import time from typing import Dict, List, Optional from app.utils.logger import logger SESSION_TTL = 600 # 10 分钟 class UserSession: """单个用户的会话上下文""" def __init__(self, user_id: str): self.user_id = user_id self.state = "idle" # idle / waiting_location / waiting_confirm / waiting_close_photo self.pending_image_url = "" # 暂存的图片 COS key self.pending_analysis = "" # VLM 分析结果描述 self.pending_alarm_type = "" # VLM 识别的告警类型 self.pending_order_id = "" # 待结单的工单 ID self.pending_alarm_id = "" # 关联的告警 ID self.history: List[Dict] = [] # 对话历史 [{"role": "user/assistant", "content": ...}] self.updated_at = time.time() def is_expired(self) -> bool: return time.time() - self.updated_at > SESSION_TTL def touch(self): self.updated_at = time.time() def reset(self): """重置状态机(保留 history)""" self.state = "idle" self.pending_image_url = "" self.pending_analysis = "" self.pending_alarm_type = "" self.pending_order_id = "" self.pending_alarm_id = "" def add_history(self, role: str, content): """添加对话记录,content 可以是 str 或 list(多模态)""" self.history.append({"role": role, "content": content}) # 保留最近 10 轮(20 条消息) if len(self.history) > 20: self.history = self.history[-20:] self.touch() def get_history_for_vlm(self) -> List[Dict]: """获取用于 VLM 调用的 history(过滤掉过大的图片内容)""" return self.history.copy() class SessionManager: """全局会话管理器(单例)""" def __init__(self): self._sessions: Dict[str, UserSession] = {} self._last_cleanup = time.time() def get(self, user_id: str) -> UserSession: """获取或创建用户会话""" # 定期清理过期会话(每 5 分钟) if time.time() - self._last_cleanup > 300: self._cleanup_expired() session = self._sessions.get(user_id) if session and not session.is_expired(): session.touch() return session # 过期或不存在,创建新会话 session = UserSession(user_id) self._sessions[user_id] = session return session def clear(self, user_id: str): """清除用户会话""" if user_id in self._sessions: self._sessions[user_id].reset() self._sessions[user_id].history.clear() def _cleanup_expired(self): """清理过期会话""" expired = [uid for uid, s in self._sessions.items() if s.is_expired()] for uid in expired: del self._sessions[uid] if expired: logger.debug(f"清理 {len(expired)} 个过期会话") self._last_cleanup = time.time() @property def active_count(self) -> int: return sum(1 for s in self._sessions.values() if not s.is_expired()) _session_manager: Optional[SessionManager] = None def get_session_manager() -> SessionManager: global _session_manager if _session_manager is None: _session_manager = SessionManager() return _session_manager