From fa0304aa47bfc34b81dac19a705fe2d8be464e24 Mon Sep 17 00:00:00 2001 From: 16337 <1633794139@qq.com> Date: Tue, 3 Feb 2026 14:26:52 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93schema=E5=88=86=E7=A6=BB=E7=A9=BA=E9=97=B4=E4=B8=8E?= =?UTF-8?q?=E4=B8=9A=E5=8A=A1=E9=85=8D=E7=BD=AE=20=20-=20=E6=96=B0?= =?UTF-8?q?=E5=A2=9Ealgorithm=5Fregistry=E5=92=8Croi=5Falgo=5Fbind?= =?UTF-8?q?=E8=A1=A8=20=20-=20roi=5Fconfigs=E7=AE=80=E5=8C=96=E4=B8=BA?= =?UTF-8?q?=E7=BA=AF=E7=A9=BA=E9=97=B4=E9=85=8D=E7=BD=AE=20=20-=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9EAlgorithmInfo/ROIAlgoBind=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- algorithms.py | 313 +++++++++++++++++++++++------- config/config_models.py | 187 +++++++++++++++++- config/database.py | 420 +++++++++++++++++++++++++++++++++------- core/config_sync.py | 152 +++++++++++++-- main.py | 55 +++--- 5 files changed, 942 insertions(+), 185 deletions(-) diff --git a/algorithms.py b/algorithms.py index 150110c..d465cde 100644 --- a/algorithms.py +++ b/algorithms.py @@ -290,6 +290,84 @@ class IntrusionAlgorithm: self.detection_start.clear() +class CrowdDetectionAlgorithm: + """人群聚集检测算法""" + + def __init__( + self, + max_count: int = 10, + cooldown_seconds: int = 300, + target_class: Optional[str] = "person", + ): + self.max_count = max_count + self.cooldown_seconds = cooldown_seconds + self.target_class = target_class + + self.last_alert_time: Dict[str, datetime] = {} + self.alert_triggered: Dict[str, bool] = {} + + def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool: + matched_rois = detection.get("matched_rois", []) + for roi in matched_rois: + if roi.get("roi_id") == roi_id: + return True + return False + + def _check_target_class(self, detection: Dict, target_class: Optional[str]) -> bool: + if not target_class: + return True + return detection.get("class") == target_class + + def _get_bboxes(self, tracks: List[Dict], roi_id: str) -> List[List[float]]: + bboxes = [] + for det in tracks: + if self._check_detection_in_roi(det, roi_id) and self._check_target_class(det, self.target_class): + bboxes.append(det.get("bbox", [])) + return bboxes + + def process( + self, + roi_id: str, + camera_id: str, + tracks: List[Dict], + current_time: Optional[datetime] = None, + ) -> List[Dict]: + current_time = current_time or datetime.now() + key = f"{camera_id}_{roi_id}" + + person_count = 0 + for det in tracks: + if self._check_detection_in_roi(det, roi_id) and self._check_target_class(det, self.target_class): + person_count += 1 + + if person_count <= self.max_count: + self.alert_triggered[key] = False + return [] + + if self.alert_triggered.get(key, False): + elapsed_since_alert = (current_time - self.last_alert_time.get(key, datetime.min)).total_seconds() + if elapsed_since_alert < self.cooldown_seconds: + return [] + self.alert_triggered[key] = False + + bboxes = self._get_bboxes(tracks, roi_id) + self.last_alert_time[key] = current_time + self.alert_triggered[key] = True + + return [{ + "roi_id": roi_id, + "camera_id": camera_id, + "bbox": bboxes[0] if bboxes else [], + "alert_type": "crowd_detection", + "message": f"检测到人群聚集,当前人数: {person_count}", + "count": person_count, + }] + + def reset(self): + self.last_alert_time.clear() + self.alert_triggered.clear() + + class AlgorithmManager: def __init__(self, working_hours: Optional[List[Dict]] = None): self.algorithms: Dict[str, Dict[str, Any]] = {} @@ -308,6 +386,11 @@ class AlgorithmManager: "confirm_seconds": 5, "target_class": None, }, + "crowd_detection": { + "max_count": 10, + "cooldown_seconds": 300, + "target_class": "person", + }, } self._pubsub = None @@ -358,6 +441,13 @@ class AlgorithmManager: self.reload_algorithm(roi_id) else: self.reload_all_algorithms() + elif data.get("type") == "bind": + bind_ids = data.get("ids", []) + if bind_ids: + for bind_id in bind_ids: + self.reload_bind_algorithm(bind_id) + else: + self.reload_all_algorithms() except Exception as e: logger.error(f"处理配置更新消息失败: {e}") except Exception as e: @@ -372,59 +462,95 @@ class AlgorithmManager: self._pubsub_thread.join(timeout=5) logger.info("配置订阅已停止") - def load_from_redis(self, roi_id: str) -> bool: - """从Redis加载单个ROI的算法配置""" + def load_bind_from_redis(self, bind_id: str) -> bool: + """从Redis加载单个绑定配置的算法""" try: from core.config_sync import get_config_sync_manager config_manager = get_config_sync_manager() - roi_config = config_manager.get_roi_from_redis(roi_id) + bind_config = config_manager.get_algo_bind_from_redis(bind_id) - if not roi_config: + if not bind_config: return False with self._update_lock: - algorithm_type = roi_config.get("algorithm_type", "leave_post") + roi_id = bind_config.get("roi_id") + algo_code = bind_config.get("algo_code", "leave_post") + params = bind_config.get("params", {}) - if algorithm_type == "leave_post": - params = { - "working_hours": roi_config.get("working_hours"), - "confirm_on_duty_sec": roi_config.get("confirm_on_duty_sec", 10), - "confirm_leave_sec": roi_config.get("confirm_leave_sec", 10), - "cooldown_sec": roi_config.get("cooldown_sec", 300), - "target_class": roi_config.get("target_class", "person"), + if roi_id not in self.algorithms: + self.algorithms[roi_id] = {} + + key = f"{roi_id}_{bind_id}" + + if algo_code == "leave_post": + algo_params = { + "confirm_on_duty_sec": params.get("confirm_on_duty_sec", 10), + "confirm_leave_sec": params.get("confirm_leave_sec", 10), + "cooldown_sec": params.get("cooldown_sec", 300), + "working_hours": params.get("working_hours", []), + "target_class": params.get("target_class", bind_config.get("target_class", "person")), } - if roi_id in self.algorithms and "leave_post" in self.algorithms[roi_id]: - algo = self.algorithms[roi_id]["leave_post"] - algo.confirm_on_duty_sec = params["confirm_on_duty_sec"] - algo.confirm_leave_sec = params["confirm_leave_sec"] - algo.cooldown_sec = params["cooldown_sec"] - algo.target_class = params["target_class"] - if params["working_hours"]: - algo.working_hours = params["working_hours"] - logger.info(f"已热更新算法参数: {roi_id}") + if key in self.algorithms.get(roi_id, {}) and "leave_post" in self.algorithms[roi_id].get(key, {}): + algo = self.algorithms[roi_id][key]["leave_post"] + algo.confirm_on_duty_sec = algo_params["confirm_on_duty_sec"] + algo.confirm_leave_sec = algo_params["confirm_leave_sec"] + algo.cooldown_sec = algo_params["cooldown_sec"] + algo.target_class = algo_params["target_class"] + if algo_params["working_hours"]: + algo.working_hours = algo_params["working_hours"] + logger.info(f"已热更新算法参数: {key}") else: - self.register_algorithm(roi_id, "leave_post", params) - logger.info(f"已从Redis加载算法: {roi_id}") - else: - params = {} - if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]: - pass - else: - self.register_algorithm(roi_id, algorithm_type, params) - logger.info(f"已从Redis加载算法: {roi_id}") + self.algorithms[roi_id][key] = {} + self.algorithms[roi_id][key]["leave_post"] = LeavePostAlgorithm( + confirm_on_duty_sec=algo_params["confirm_on_duty_sec"], + confirm_leave_sec=algo_params["confirm_leave_sec"], + cooldown_sec=algo_params["cooldown_sec"], + working_hours=algo_params["working_hours"], + target_class=algo_params["target_class"], + ) + logger.info(f"已从Redis加载算法: {key}") + elif algo_code == "intrusion": + algo_params = { + "cooldown_seconds": params.get("cooldown_seconds", 120), + "confirm_seconds": params.get("confirm_seconds", 5), + "target_class": params.get("target_class", bind_config.get("target_class")), + } + self.algorithms[roi_id][key] = {} + self.algorithms[roi_id][key]["intrusion"] = IntrusionAlgorithm( + cooldown_seconds=algo_params["cooldown_seconds"], + confirm_seconds=algo_params["confirm_seconds"], + target_class=algo_params["target_class"], + ) + logger.info(f"已从Redis加载算法: {key}") return True except Exception as e: logger.error(f"从Redis加载算法配置失败: {e}") return False + def reload_bind_algorithm(self, bind_id: str) -> bool: + """重新加载单个绑定的算法配置""" + return self.load_bind_from_redis(bind_id) + def reload_algorithm(self, roi_id: str) -> bool: - """重新加载单个ROI的算法配置""" - if roi_id not in self.algorithms: - return self.load_from_redis(roi_id) - - self.reset_algorithm(roi_id) - return self.load_from_redis(roi_id) + """重新加载单个ROI的所有算法绑定配置""" + try: + from core.config_sync import get_config_sync_manager + config_manager = get_config_sync_manager() + bindings = config_manager.get_bindings_from_redis(roi_id) + + if not bindings: + return False + + for bind in bindings: + bind_id = bind.get("bind_id") + self.reset_algorithm(roi_id, bind_id) + self.load_bind_from_redis(bind_id) + + return True + except Exception as e: + logger.error(f"重新加载ROI算法配置失败: {e}") + return False def reload_all_algorithms(self) -> int: """重新加载所有算法配置""" @@ -432,11 +558,13 @@ class AlgorithmManager: try: from core.config_sync import get_config_sync_manager config_manager = get_config_sync_manager() - roi_configs = config_manager.get_all_roi_configs() + bindings = config_manager.get_bindings_from_redis("") - for roi_config in roi_configs: - roi_id = roi_config.get("roi_id") - if self.reload_algorithm(roi_id): + for bind in bindings: + bind_id = bind.get("bind_id") + roi_id = bind.get("roi_id") + self.reset_algorithm(roi_id, bind_id) + if self.load_bind_from_redis(bind_id): count += 1 logger.info(f"已重新加载 {count} 个算法配置") @@ -448,22 +576,29 @@ class AlgorithmManager: def register_algorithm( self, roi_id: str, + bind_id: str, algorithm_type: str, params: Optional[Dict[str, Any]] = None, ): - if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]: + """注册算法(支持绑定ID)""" + key = f"{roi_id}_{bind_id}" + + if key in self.algorithms and algorithm_type in self.algorithms[key]: return - + if roi_id not in self.algorithms: self.algorithms[roi_id] = {} - + + if key not in self.algorithms[roi_id]: + self.algorithms[roi_id][key] = {} + algo_params = self.default_params.get(algorithm_type, {}).copy() if params: algo_params.update(params) if algorithm_type == "leave_post": roi_working_hours = algo_params.get("working_hours") or self.working_hours - self.algorithms[roi_id]["leave_post"] = LeavePostAlgorithm( + self.algorithms[roi_id][key]["leave_post"] = LeavePostAlgorithm( confirm_on_duty_sec=algo_params.get("confirm_on_duty_sec", 10), confirm_leave_sec=algo_params.get("confirm_leave_sec", 10), cooldown_sec=algo_params.get("cooldown_sec", 300), @@ -471,21 +606,31 @@ class AlgorithmManager: target_class=algo_params.get("target_class", "person"), ) elif algorithm_type == "intrusion": - self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm( + self.algorithms[roi_id][key]["intrusion"] = IntrusionAlgorithm( cooldown_seconds=algo_params.get("cooldown_seconds", 120), confirm_seconds=algo_params.get("confirm_seconds", 5), target_class=algo_params.get("target_class"), ) - + elif algorithm_type == "crowd_detection": + from algorithms import CrowdDetectionAlgorithm + self.algorithms[roi_id][key]["crowd_detection"] = CrowdDetectionAlgorithm( + max_count=algo_params.get("max_count", 10), + cooldown_seconds=algo_params.get("cooldown_seconds", 300), + target_class=algo_params.get("target_class", "person"), + ) + def process( self, roi_id: str, + bind_id: str, camera_id: str, algorithm_type: str, tracks: List[Dict], current_time: Optional[datetime] = None, ) -> List[Dict]: - algo = self.algorithms.get(roi_id, {}).get(algorithm_type) + """处理检测结果(支持绑定ID)""" + key = f"{roi_id}_{bind_id}" + algo = self.algorithms.get(roi_id, {}).get(key, {}).get(algorithm_type) if algo is None: return [] return algo.process(roi_id, camera_id, tracks, current_time) @@ -493,47 +638,69 @@ class AlgorithmManager: def update_roi_params( self, roi_id: str, + bind_id: str, algorithm_type: str, params: Dict[str, Any], ): - if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]: - algo = self.algorithms[roi_id][algorithm_type] + """更新算法参数(支持绑定ID)""" + key = f"{roi_id}_{bind_id}" + if roi_id in self.algorithms and key in self.algorithms[roi_id] and algorithm_type in self.algorithms[roi_id][key]: + algo = self.algorithms[roi_id][key][algorithm_type] for key, value in params.items(): if hasattr(algo, key): setattr(algo, key, value) - - def reset_algorithm(self, roi_id: str, algorithm_type: Optional[str] = None): + + def reset_algorithm(self, roi_id: str, bind_id: Optional[str] = None): + """重置算法状态(支持绑定ID)""" if roi_id not in self.algorithms: return - - if algorithm_type: - if algorithm_type in self.algorithms[roi_id]: - self.algorithms[roi_id][algorithm_type].reset() + + if bind_id: + key = f"{roi_id}_{bind_id}" + if key in self.algorithms[roi_id]: + if algorithm_type in self.algorithms[roi_id][key]: + self.algorithms[roi_id][key][algorithm_type].reset() else: - for algo in self.algorithms[roi_id].values(): - algo.reset() - + for key in self.algorithms[roi_id]: + for algo in self.algorithms[roi_id][key].values(): + algo.reset() + def reset_all(self): + """重置所有算法""" for roi_algorithms in self.algorithms.values(): - for algo in roi_algorithms.values(): - algo.reset() - + for bind_algorithms in roi_algorithms.values(): + for algo in bind_algorithms.values(): + algo.reset() + def remove_roi(self, roi_id: str): + """移除ROI的所有算法""" if roi_id in self.algorithms: - self.reset_algorithm(roi_id) + for key in list(self.algorithms[roi_id].keys()): + self.reset_algorithm(roi_id, key.split("_")[-1] if "_" in key else None) del self.algorithms[roi_id] - + + def remove_bind(self, roi_id: str, bind_id: str): + """移除绑定的算法""" + key = f"{roi_id}_{bind_id}" + if roi_id in self.algorithms and key in self.algorithms[roi_id]: + for algo in self.algorithms[roi_id][key].values(): + algo.reset() + del self.algorithms[roi_id][key] + def get_status(self, roi_id: str) -> Dict[str, Any]: + """获取算法状态""" status = {} if roi_id in self.algorithms: - for algo_type, algo in self.algorithms[roi_id].items(): - if algo_type == "leave_post": - status[algo_type] = { - "state": getattr(algo, "state", "WAITING"), - "alarm_sent": getattr(algo, "alarm_sent", False), - } - else: - status[algo_type] = { - "detection_count": len(getattr(algo, "detection_start", {})), - } + for key, bind_algorithms in self.algorithms[roi_id].items(): + bind_id = key.split("_", 1)[-1] if "_" in key else "" + for algo_type, algo in bind_algorithms.items(): + if algo_type == "leave_post": + status[f"{algo_type}_{bind_id}"] = { + "state": getattr(algo, "state", "WAITING"), + "alarm_sent": getattr(algo, "alarm_sent", False), + } + else: + status[f"{algo_type}_{bind_id}"] = { + "detection_count": len(getattr(algo, "detection_start", {})), + } return status diff --git a/config/config_models.py b/config/config_models.py index 1340258..3cec251 100644 --- a/config/config_models.py +++ b/config/config_models.py @@ -215,7 +215,8 @@ class AlertInfo: alert_id: str camera_id: str roi_id: str - alert_type: str + bind_id: Optional[str] = None + alert_type: str = "detection" target_class: Optional[str] = None confidence: Optional[float] = None bbox: Optional[List[float]] = None @@ -224,6 +225,7 @@ class AlertInfo: level: AlertLevel = AlertLevel.MEDIUM timestamp: Optional[str] = None extra_data: Optional[Dict[str, Any]] = None + detections: Optional[str] = None def to_dict(self) -> Dict[str, Any]: """转换为字典""" @@ -231,6 +233,7 @@ class AlertInfo: "alert_id": self.alert_id, "camera_id": self.camera_id, "roi_id": self.roi_id, + "bind_id": self.bind_id, "alert_type": self.alert_type, "target_class": self.target_class, "confidence": self.confidence, @@ -240,6 +243,7 @@ class AlertInfo: "level": self.level.value if isinstance(self.level, AlertLevel) else self.level, "timestamp": self.timestamp, "extra_data": self.extra_data, + "detections": self.detections, } def to_json(self) -> str: @@ -291,3 +295,184 @@ class ConfigVersion: "description": self.description, "affected_items": self.affected_items, } + + +@dataclass +class AlgorithmInfo: + """算法配置信息数据模型""" + algo_code: str + algo_name: str + target_class: str = "person" + param_schema: Optional[Dict[str, Any]] = None + description: Optional[str] = None + is_active: bool = True + created_at: Optional[str] = None + updated_at: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "algo_code": self.algo_code, + "algo_name": self.algo_name, + "target_class": self.target_class, + "param_schema": self.param_schema, + "description": self.description, + "is_active": self.is_active, + "created_at": self.created_at, + "updated_at": self.updated_at, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'AlgorithmInfo': + """从字典创建实例""" + return cls( + algo_code=data.get("algo_code", ""), + algo_name=data.get("algo_name", ""), + target_class=data.get("target_class", "person"), + param_schema=data.get("param_schema"), + description=data.get("description"), + is_active=data.get("is_active", True), + created_at=data.get("created_at"), + updated_at=data.get("updated_at"), + ) + + +@dataclass +class ROIAlgoBind: + """ROI与算法绑定关系数据模型""" + bind_id: str + roi_id: str + algo_code: str + params: Dict[str, Any] = field(default_factory=dict) + priority: int = 0 + enabled: bool = True + created_at: Optional[str] = None + updated_at: Optional[str] = None + algo_name: Optional[str] = None + target_class: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "bind_id": self.bind_id, + "roi_id": self.roi_id, + "algo_code": self.algo_code, + "params": self.params, + "priority": self.priority, + "enabled": self.enabled, + "created_at": self.created_at, + "updated_at": self.updated_at, + "algo_name": self.algo_name, + "target_class": self.target_class, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ROIAlgoBind': + """从字典创建实例""" + return cls( + bind_id=data.get("bind_id", ""), + roi_id=data.get("roi_id", ""), + algo_code=data.get("algo_code", ""), + params=data.get("params", {}), + priority=data.get("priority", 0), + enabled=data.get("enabled", True), + created_at=data.get("created_at"), + updated_at=data.get("updated_at"), + algo_name=data.get("algo_name"), + target_class=data.get("target_class"), + ) + + +@dataclass +class ROIInfoNew: + """ROI区域信息数据模型(新版本,包含绑定信息)""" + roi_id: str + camera_id: str + roi_type: ROIType + coordinates: List[List[float]] + enabled: bool = True + priority: int = 0 + extra_params: Optional[Dict[str, Any]] = None + bindings: List[ROIAlgoBind] = field(default_factory=list) + updated_at: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "roi_id": self.roi_id, + "camera_id": self.camera_id, + "roi_type": self.roi_type.value if isinstance(self.roi_type, ROIType) else self.roi_type, + "coordinates": self.coordinates, + "enabled": self.enabled, + "priority": self.priority, + "extra_params": self.extra_params, + "bindings": [b.to_dict() for b in self.bindings], + "updated_at": self.updated_at, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ROIInfoNew': + """从字典创建实例""" + roi_type_str = data.get("roi_type", "polygon") + roi_type = ROIType(roi_type_str) if roi_type_str in [e.value for e in ROIType] else ROIType.POLYGON + + bindings = data.get("bindings", []) + if isinstance(bindings, list): + bindings = [ROIAlgoBind.from_dict(b) for b in bindings] + + return cls( + roi_id=data.get("roi_id", ""), + camera_id=data.get("camera_id", ""), + roi_type=roi_type, + coordinates=data.get("coordinates", []), + enabled=data.get("enabled", True), + priority=data.get("priority", 0), + extra_params=data.get("extra_params"), + bindings=bindings, + updated_at=data.get("updated_at"), + ) + + def is_point_inside(self, point: List[float]) -> bool: + """判断点是否在ROI区域内""" + if self.roi_type == ROIType.RECTANGLE: + return self._is_point_in_rectangle(point) + elif self.roi_type == ROIType.POLYGON: + return self._is_point_in_polygon(point) + return False + + def _is_point_in_rectangle(self, point: List[float]) -> bool: + """判断点是否在矩形区域内""" + if len(self.coordinates) < 2: + return False + + x, y = point[0], point[1] + x1, y1 = self.coordinates[0] + x2, y2 = self.coordinates[1] + + left = min(x1, x2) + right = max(x1, x2) + top = min(y1, y2) + bottom = max(y1, y2) + + return left <= x <= right and top <= y <= bottom + + def _is_point_in_polygon(self, point: List[float]) -> bool: + """判断点是否在多边形区域内(射线法)""" + if len(self.coordinates) < 3: + return False + + x, y = point[0], point[1] + n = len(self.coordinates) + inside = False + + j = n - 1 + for i in range(n): + xi, yi = self.coordinates[i] + xj, yj = self.coordinates[j] + + if ((yi > y) != (yj > y)) and (x < (xj - xi) * (y - yi) / (yj - yi) + xi): + inside = not inside + + j = i + + return inside diff --git a/config/database.py b/config/database.py index 3fca865..daa2fc1 100644 --- a/config/database.py +++ b/config/database.py @@ -9,6 +9,7 @@ SQLite 数据库模块 """ import os +import json import sqlite3 import threading import queue @@ -39,7 +40,8 @@ class AlertRecord: alert_id: str camera_id: str roi_id: str - alert_type: str + bind_id: Optional[str] = None # 关联 roi_algo_bind 表 + alert_type: str = "detection" target_class: Optional[str] = None confidence: Optional[float] = None bbox: Optional[List[float]] = None @@ -49,6 +51,7 @@ class AlertRecord: created_at: datetime = field(default_factory=datetime.now) processed_at: Optional[datetime] = None duration_minutes: Optional[float] = None + detections: Optional[str] = None # JSON格式的检测结果 class SQLiteManager: @@ -108,12 +111,73 @@ class SQLiteManager: cursor = self._conn.cursor() + cursor.execute(""" + CREATE TABLE IF NOT EXISTS algorithm_registry ( + algo_code TEXT PRIMARY KEY, + algo_name TEXT NOT NULL, + target_class TEXT DEFAULT 'person', + param_schema TEXT, + description TEXT, + is_active BOOLEAN DEFAULT 1, + created_at TEXT, + updated_at TEXT + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_algo_active + ON algorithm_registry(is_active) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS roi_configs ( + roi_id TEXT PRIMARY KEY, + camera_id TEXT NOT NULL, + roi_type TEXT NOT NULL, + coordinates TEXT NOT NULL, + enabled BOOLEAN DEFAULT 1, + priority INTEGER DEFAULT 0, + extra_params TEXT, + updated_at TEXT + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_roi_camera + ON roi_configs(camera_id) + """) + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS roi_algo_bind ( + bind_id TEXT PRIMARY KEY, + roi_id TEXT NOT NULL, + algo_code TEXT NOT NULL, + params TEXT NOT NULL, + priority INTEGER DEFAULT 0, + enabled BOOLEAN DEFAULT 1, + created_at TEXT, + updated_at TEXT, + FOREIGN KEY (roi_id) REFERENCES roi_configs(roi_id) ON DELETE CASCADE, + FOREIGN KEY (algo_code) REFERENCES algorithm_registry(algo_code) ON DELETE RESTRICT + ) + """) + + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_bind_roi + ON roi_algo_bind(roi_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_bind_algo + ON roi_algo_bind(algo_code) + """) + cursor.execute(""" CREATE TABLE IF NOT EXISTS alert_records ( id INTEGER PRIMARY KEY AUTOINCREMENT, alert_id TEXT UNIQUE NOT NULL, camera_id TEXT NOT NULL, roi_id TEXT NOT NULL, + bind_id TEXT, alert_type TEXT NOT NULL, target_class TEXT, confidence REAL, @@ -123,7 +187,9 @@ class SQLiteManager: status TEXT DEFAULT 'pending', created_at TEXT NOT NULL, processed_at TEXT, - duration_minutes REAL + duration_minutes REAL, + detections TEXT, + FOREIGN KEY (bind_id) REFERENCES roi_algo_bind(bind_id) ON DELETE SET NULL ) """) @@ -139,6 +205,10 @@ class SQLiteManager: CREATE INDEX IF NOT EXISTS idx_alert_status ON alert_records(status) """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_alert_bind + ON alert_records(bind_id) + """) cursor.execute(""" CREATE TABLE IF NOT EXISTS camera_configs ( @@ -148,31 +218,12 @@ class SQLiteManager: status BOOLEAN DEFAULT 1, enabled BOOLEAN DEFAULT 1, location TEXT, + roi_group_id TEXT, extra_params TEXT, updated_at TEXT ) """) - cursor.execute(""" - CREATE TABLE IF NOT EXISTS roi_configs ( - roi_id TEXT PRIMARY KEY, - camera_id TEXT NOT NULL, - roi_type TEXT NOT NULL, - coordinates TEXT NOT NULL, - algorithm_type TEXT NOT NULL, - alert_threshold INTEGER DEFAULT 3, - alert_cooldown INTEGER DEFAULT 300, - enabled BOOLEAN DEFAULT 1, - extra_params TEXT, - working_hours TEXT, - confirm_on_duty_sec INTEGER DEFAULT 10, - confirm_leave_sec INTEGER DEFAULT 10, - cooldown_sec INTEGER DEFAULT 300, - target_class TEXT DEFAULT 'person', - updated_at TEXT - ) - """) - cursor.execute(""" CREATE TABLE IF NOT EXISTS config_update_log ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -186,6 +237,66 @@ class SQLiteManager: """) self._conn.commit() + + self._init_default_algorithms() + + def _init_default_algorithms(self): + """初始化默认算法配置""" + try: + cursor = self._conn.cursor() + + algorithms = [ + { + 'algo_code': 'leave_post', + 'algo_name': '离岗检测', + 'target_class': 'person', + 'param_schema': json.dumps({ + "confirm_on_duty_sec": {"type": "int", "default": 10, "min": 1}, + "confirm_leave_sec": {"type": "int", "default": 10, "min": 1}, + "cooldown_sec": {"type": "int", "default": 300, "min": 0}, + "working_hours": {"type": "list", "default": []}, + }), + 'description': '检测人员是否在岗,支持工作时间段配置' + }, + { + 'algo_code': 'intrusion', + 'algo_name': '周界入侵检测', + 'target_class': 'person', + 'param_schema': json.dumps({ + "cooldown_seconds": {"type": "int", "default": 120, "min": 0}, + "confirm_seconds": {"type": "int", "default": 5, "min": 1}, + }), + 'description': '检测人员进入指定区域,支持确认时间和冷却时间配置' + }, + { + 'algo_code': 'crowd_detection', + 'algo_name': '人群聚集检测', + 'target_class': 'person', + 'param_schema': json.dumps({ + "max_count": {"type": "int", "default": 10, "min": 1}, + "cooldown_seconds": {"type": "int", "default": 300, "min": 0}, + }), + 'description': '检测区域内人员数量是否超过阈值' + }, + ] + + for algo in algorithms: + cursor.execute(""" + INSERT OR IGNORE INTO algorithm_registry ( + algo_code, algo_name, target_class, param_schema, description, + is_active, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, 1, ?, ?) + """, ( + algo['algo_code'], algo['algo_name'], algo['target_class'], + algo['param_schema'], algo['description'], + datetime.now().isoformat(), datetime.now().isoformat() + )) + + self._conn.commit() + logger.info(f"已初始化 {len(algorithms)} 个默认算法配置") + + except Exception as e: + logger.error(f"初始化默认算法失败: {e}") def _start_background_threads(self): """启动后台线程""" @@ -242,15 +353,16 @@ class SQLiteManager: for record in batch: cursor.execute(""" INSERT OR REPLACE INTO alert_records ( - alert_id, camera_id, roi_id, alert_type, + alert_id, camera_id, roi_id, bind_id, alert_type, target_class, confidence, bbox, message, image_path, status, created_at, processed_at, - duration_minutes - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + duration_minutes, detections + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( record['alert_id'], record['camera_id'], record['roi_id'], + record.get('bind_id'), record['alert_type'], record.get('target_class'), record.get('confidence'), @@ -261,6 +373,7 @@ class SQLiteManager: record['created_at'], record.get('processed_at'), record.get('duration_minutes'), + record.get('detections'), )) self._conn.commit() @@ -285,6 +398,7 @@ class SQLiteManager: 'alert_id': alert.alert_id, 'camera_id': alert.camera_id, 'roi_id': alert.roi_id, + 'bind_id': alert.bind_id, 'alert_type': alert.alert_type, 'target_class': alert.target_class, 'confidence': alert.confidence, @@ -295,6 +409,7 @@ class SQLiteManager: 'created_at': alert.created_at.isoformat(), 'processed_at': alert.processed_at.isoformat() if alert.processed_at else None, 'duration_minutes': alert.duration_minutes, + 'detections': alert.detections, } self._write_queue.put(record) @@ -427,14 +542,15 @@ class SQLiteManager: cursor.execute(""" INSERT OR REPLACE INTO camera_configs ( camera_id, rtsp_url, camera_name, status, enabled, - location, extra_params, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + location, roi_group_id, extra_params, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( camera_id, rtsp_url, kwargs.get('camera_name'), kwargs.get('status', True), kwargs.get('enabled', True), kwargs.get('location'), + kwargs.get('roi_group_id'), str(kwargs.get('extra_params')) if kwargs.get('extra_params') else None, now )) @@ -452,7 +568,7 @@ class SQLiteManager: row = cursor.fetchone() if row: columns = ['camera_id', 'rtsp_url', 'camera_name', 'status', - 'enabled', 'location', 'extra_params', 'updated_at'] + 'enabled', 'location', 'roi_group_id', 'extra_params', 'updated_at'] return dict(zip(columns, row)) return None except Exception as e: @@ -465,7 +581,7 @@ class SQLiteManager: cursor = self._conn.cursor() cursor.execute("SELECT * FROM camera_configs ORDER BY camera_id") columns = ['camera_id', 'rtsp_url', 'camera_name', 'status', - 'enabled', 'location', 'extra_params', 'updated_at'] + 'enabled', 'location', 'roi_group_id', 'extra_params', 'updated_at'] return [dict(zip(columns, row)) for row in cursor.fetchall()] except Exception as e: logger.error(f"获取所有摄像头配置失败: {e}") @@ -483,29 +599,21 @@ class SQLiteManager: return False def save_roi_config(self, roi_id: str, camera_id: str, roi_type: str, - coordinates: List, algorithm_type: str, **kwargs) -> bool: - """保存ROI配置""" + coordinates: List, **kwargs) -> bool: + """保存ROI配置(空间信息,不包含业务参数)""" try: cursor = self._conn.cursor() now = datetime.now().isoformat() cursor.execute(""" INSERT OR REPLACE INTO roi_configs ( - roi_id, camera_id, roi_type, coordinates, algorithm_type, - alert_threshold, alert_cooldown, enabled, extra_params, - working_hours, confirm_on_duty_sec, confirm_leave_sec, - cooldown_sec, target_class, updated_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + roi_id, camera_id, roi_type, coordinates, + enabled, priority, extra_params, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, ( - roi_id, camera_id, roi_type, str(coordinates), algorithm_type, - kwargs.get('alert_threshold', 3), - kwargs.get('alert_cooldown', 300), + roi_id, camera_id, roi_type, str(coordinates), kwargs.get('enabled', True), + kwargs.get('priority', 0), str(kwargs.get('extra_params')) if kwargs.get('extra_params') else None, - str(kwargs.get('working_hours')) if kwargs.get('working_hours') else None, - kwargs.get('confirm_on_duty_sec', 10), - kwargs.get('confirm_leave_sec', 10), - kwargs.get('cooldown_sec', 300), - kwargs.get('target_class', 'person'), now )) self._conn.commit() @@ -522,20 +630,12 @@ class SQLiteManager: row = cursor.fetchone() if row: columns = ['roi_id', 'camera_id', 'roi_type', 'coordinates', - 'algorithm_type', 'alert_threshold', 'alert_cooldown', - 'enabled', 'extra_params', 'working_hours', - 'confirm_on_duty_sec', 'confirm_leave_sec', 'cooldown_sec', - 'target_class', 'updated_at'] + 'enabled', 'priority', 'extra_params', 'updated_at'] result = dict(zip(columns, row)) try: result['coordinates'] = eval(result['coordinates']) except: pass - try: - if result.get('working_hours'): - result['working_hours'] = eval(result['working_hours']) - except: - pass return result return None except Exception as e: @@ -548,10 +648,7 @@ class SQLiteManager: cursor = self._conn.cursor() cursor.execute("SELECT * FROM roi_configs WHERE camera_id = ?", (camera_id,)) columns = ['roi_id', 'camera_id', 'roi_type', 'coordinates', - 'algorithm_type', 'alert_threshold', 'alert_cooldown', - 'enabled', 'extra_params', 'working_hours', - 'confirm_on_duty_sec', 'confirm_leave_sec', 'cooldown_sec', - 'target_class', 'updated_at'] + 'enabled', 'priority', 'extra_params', 'updated_at'] results = [] for row in cursor.fetchall(): r = dict(zip(columns, row)) @@ -559,11 +656,6 @@ class SQLiteManager: r['coordinates'] = eval(r['coordinates']) except: pass - try: - if r.get('working_hours'): - r['working_hours'] = eval(r['working_hours']) - except: - pass results.append(r) return results except Exception as e: @@ -576,10 +668,7 @@ class SQLiteManager: cursor = self._conn.cursor() cursor.execute("SELECT * FROM roi_configs ORDER BY camera_id, roi_id") columns = ['roi_id', 'camera_id', 'roi_type', 'coordinates', - 'algorithm_type', 'alert_threshold', 'alert_cooldown', - 'enabled', 'extra_params', 'working_hours', - 'confirm_on_duty_sec', 'confirm_leave_sec', 'cooldown_sec', - 'target_class', 'updated_at'] + 'enabled', 'priority', 'extra_params', 'updated_at'] results = [] for row in cursor.fetchall(): r = dict(zip(columns, row)) @@ -587,11 +676,6 @@ class SQLiteManager: r['coordinates'] = eval(r['coordinates']) except: pass - try: - if r.get('working_hours'): - r['working_hours'] = eval(r['working_hours']) - except: - pass results.append(r) return results except Exception as e: @@ -609,6 +693,200 @@ class SQLiteManager: logger.error(f"删除ROI配置失败: {e}") return False + def save_algorithm(self, algo_code: str, algo_name: str, target_class: str = "person", + param_schema: Optional[str] = None, description: Optional[str] = None, + is_active: bool = True) -> bool: + """保存算法配置""" + try: + cursor = self._conn.cursor() + now = datetime.now().isoformat() + cursor.execute(""" + INSERT OR REPLACE INTO algorithm_registry ( + algo_code, algo_name, target_class, param_schema, description, + is_active, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, ( + algo_code, algo_name, target_class, param_schema, description, + is_active, now, now + )) + self._conn.commit() + return True + except Exception as e: + logger.error(f"保存算法配置失败: {e}") + return False + + def get_algorithm(self, algo_code: str) -> Optional[Dict[str, Any]]: + """获取算法配置""" + try: + cursor = self._conn.cursor() + cursor.execute("SELECT * FROM algorithm_registry WHERE algo_code = ?", (algo_code,)) + row = cursor.fetchone() + if row: + columns = ['algo_code', 'algo_name', 'target_class', 'param_schema', + 'description', 'is_active', 'created_at', 'updated_at'] + result = dict(zip(columns, row)) + try: + if result.get('param_schema'): + result['param_schema'] = json.loads(result['param_schema']) + except: + pass + return result + return None + except Exception as e: + logger.error(f"获取算法配置失败: {e}") + return None + + def get_all_algorithms(self, active_only: bool = True) -> List[Dict[str, Any]]: + """获取所有算法配置""" + try: + cursor = self._conn.cursor() + query = "SELECT * FROM algorithm_registry" + if active_only: + query += " WHERE is_active = 1" + cursor.execute(query) + columns = ['algo_code', 'algo_name', 'target_class', 'param_schema', + 'description', 'is_active', 'created_at', 'updated_at'] + results = [] + for row in cursor.fetchall(): + r = dict(zip(columns, row)) + try: + if r.get('param_schema'): + r['param_schema'] = json.loads(r['param_schema']) + except: + pass + results.append(r) + return results + except Exception as e: + logger.error(f"获取所有算法配置失败: {e}") + return [] + + def save_roi_algo_bind(self, bind_id: str, roi_id: str, algo_code: str, + params: Dict[str, Any], priority: int = 0, + enabled: bool = True) -> bool: + """保存ROI与算法的绑定关系""" + try: + cursor = self._conn.cursor() + now = datetime.now().isoformat() + cursor.execute(""" + INSERT OR REPLACE INTO roi_algo_bind ( + bind_id, roi_id, algo_code, params, priority, + enabled, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, ( + bind_id, roi_id, algo_code, json.dumps(params), + priority, enabled, now, now + )) + self._conn.commit() + return True + except Exception as e: + logger.error(f"保存ROI算法绑定失败: {e}") + return False + + def get_roi_algo_bind(self, bind_id: str) -> Optional[Dict[str, Any]]: + """获取ROI算法绑定配置""" + try: + cursor = self._conn.cursor() + cursor.execute("SELECT * FROM roi_algo_bind WHERE bind_id = ?", (bind_id,)) + row = cursor.fetchone() + if row: + columns = ['bind_id', 'roi_id', 'algo_code', 'params', + 'priority', 'enabled', 'created_at', 'updated_at'] + result = dict(zip(columns, row)) + try: + if result.get('params'): + result['params'] = json.loads(result['params']) + except: + pass + return result + return None + except Exception as e: + logger.error(f"获取ROI算法绑定失败: {e}") + return None + + def get_bindings_by_roi(self, roi_id: str) -> List[Dict[str, Any]]: + """获取指定ROI的所有算法绑定""" + try: + cursor = self._conn.cursor() + cursor.execute(""" + SELECT b.*, a.algo_name, a.target_class + FROM roi_algo_bind b + LEFT JOIN algorithm_registry a ON b.algo_code = a.algo_code + WHERE b.roi_id = ? AND b.enabled = 1 + ORDER BY b.priority DESC + """, (roi_id,)) + results = [] + for row in cursor.fetchall(): + result = dict(zip( + ['bind_id', 'roi_id', 'algo_code', 'params', 'priority', + 'enabled', 'created_at', 'updated_at', 'algo_name', 'target_class'], + row + )) + try: + if result.get('params'): + result['params'] = json.loads(result['params']) + except: + pass + results.append(result) + return results + except Exception as e: + logger.error(f"获取ROI算法绑定失败: {e}") + return [] + + def get_bindings_by_camera(self, camera_id: str) -> List[Dict[str, Any]]: + """获取指定摄像头的所有ROI算法绑定""" + try: + cursor = self._conn.cursor() + cursor.execute(""" + SELECT b.*, a.algo_name, a.target_class, r.roi_type, r.coordinates + FROM roi_algo_bind b + LEFT JOIN algorithm_registry a ON b.algo_code = a.algo_code + LEFT JOIN roi_configs r ON b.roi_id = r.roi_id + WHERE r.camera_id = ? AND b.enabled = 1 AND r.enabled = 1 + ORDER BY r.priority DESC, b.priority DESC + """, (camera_id,)) + results = [] + for row in cursor.fetchall(): + result = dict(zip( + ['bind_id', 'roi_id', 'algo_code', 'params', 'priority', + 'enabled', 'created_at', 'updated_at', 'algo_name', 'target_class', + 'roi_type', 'coordinates'], + row + )) + try: + if result.get('params'): + result['params'] = json.loads(result['params']) + if result.get('coordinates'): + result['coordinates'] = eval(result['coordinates']) + except: + pass + results.append(result) + return results + except Exception as e: + logger.error(f"获取摄像头算法绑定失败: {e}") + return [] + + def delete_roi_algo_bind(self, bind_id: str) -> bool: + """删除ROI算法绑定""" + try: + cursor = self._conn.cursor() + cursor.execute("DELETE FROM roi_algo_bind WHERE bind_id = ?", (bind_id,)) + self._conn.commit() + return cursor.rowcount > 0 + except Exception as e: + logger.error(f"删除ROI算法绑定失败: {e}") + return False + + def delete_bindings_by_roi(self, roi_id: str) -> int: + """删除指定ROI的所有算法绑定""" + try: + cursor = self._conn.cursor() + cursor.execute("DELETE FROM roi_algo_bind WHERE roi_id = ?", (roi_id,)) + self._conn.commit() + return cursor.rowcount + except Exception as e: + logger.error(f"删除ROI算法绑定失败: {e}") + return 0 + def log_config_update( self, config_type: str, diff --git a/core/config_sync.py b/core/config_sync.py index 2fcf6e7..a7b8ee5 100644 --- a/core/config_sync.py +++ b/core/config_sync.py @@ -21,7 +21,7 @@ from redis.client import PubSub from config.settings import get_settings, RedisConfig from config.database import get_sqlite_manager, SQLiteManager -from config.config_models import CameraInfo as CameraInfoModel, ROIInfo, ConfigVersion +from config.config_models import CameraInfo as CameraInfoModel, ROIInfo, ConfigVersion, ROIInfoNew, ROIAlgoBind from utils.version_control import get_version_control logger = logging.getLogger(__name__) @@ -282,7 +282,7 @@ class ConfigSyncManager: def get_roi_configs(self, camera_id: Optional[str] = None, force_refresh: bool = False) -> List[ROIInfo]: - """获取ROI配置列表""" + """获取ROI配置列表(兼容旧版本)""" cache_key = f"rois_{camera_id}" if camera_id else "rois_all" if not force_refresh: @@ -312,6 +312,59 @@ class ConfigSyncManager: cached = self._cache.get(cache_key) return cached or [] + def get_roi_configs_with_bindings(self, camera_id: Optional[str] = None, + force_refresh: bool = False) -> List[ROIInfoNew]: + """获取ROI配置列表(包含算法绑定信息)""" + cache_key = f"rois_bindings_{camera_id}" if camera_id else "rois_bindings_all" + + if not force_refresh: + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + self._init_database() + + if self._db_manager is None: + logger.warning("数据库管理器不可用,返回空ROI配置列表") + return [] + + try: + if camera_id: + roi_configs = self._db_manager.get_rois_by_camera(camera_id) + bindings_list = self._db_manager.get_bindings_by_camera(camera_id) + else: + roi_configs = self._db_manager.get_all_roi_configs() + bindings_list = [] + for roi in roi_configs: + bindings = self._db_manager.get_bindings_by_roi(roi['roi_id']) + bindings_list.extend(bindings) + + roi_dict = {r['roi_id']: r for r in roi_configs} + bindings_dict = {} + for b in bindings_list: + roi_id = b['roi_id'] + if roi_id not in bindings_dict: + bindings_dict[roi_id] = [] + bindings_dict[roi_id].append(b) + + result = [] + for roi_id, roi_data in roi_dict.items(): + roi_info = ROIInfoNew.from_dict(roi_data) + if roi_id in bindings_dict: + roi_info.bindings = [ROIAlgoBind.from_dict(b) for b in bindings_dict[roi_id]] + result.append(roi_info) + + result.sort(key=lambda x: x.priority, reverse=True) + + self._cache.set(cache_key, result) + logger.info(f"已加载ROI配置(含绑定): {len(result)} 个") + return result + + except Exception as e: + logger.error(f"获取ROI配置(含绑定)失败: {e}") + cached = self._cache.get(cache_key) + return cached or [] + def get_camera_rois(self, camera_id: str) -> List[ROIInfo]: """获取指定摄像头的ROI配置""" return self.get_roi_configs(camera_id=camera_id) @@ -397,13 +450,8 @@ class ConfigSyncManager: "camera_id": roi_config.get("camera_id", ""), "roi_type": roi_config.get("roi_type", ""), "coordinates": str(roi_config.get("coordinates", [])), - "algorithm_type": roi_config.get("algorithm_type", ""), - "working_hours": str(roi_config.get("working_hours", [])), - "confirm_on_duty_sec": str(roi_config.get("confirm_on_duty_sec", 10)), - "confirm_leave_sec": str(roi_config.get("confirm_leave_sec", 10)), - "cooldown_sec": str(roi_config.get("cooldown_sec", 300)), - "target_class": roi_config.get("target_class", "person"), "enabled": str(roi_config.get("enabled", True)), + "priority": str(roi_config.get("priority", 0)), }) self._redis_client.expire(key, 3600) @@ -413,6 +461,33 @@ class ConfigSyncManager: logger.error(f"缓存ROI配置到Redis失败: {e}") return False + def _cache_algo_bind_to_redis(self, bind_config: Dict[str, Any]) -> bool: + """将ROI算法绑定配置缓存到Redis""" + if not self._redis_client: + return False + + try: + bind_id = bind_config.get("bind_id") + key = f"config:bind:{bind_id}" + + self._redis_client.hset(key, mapping={ + "bind_id": bind_id, + "roi_id": bind_config.get("roi_id", ""), + "algo_code": bind_config.get("algo_code", ""), + "params": str(bind_config.get("params", {})), + "priority": str(bind_config.get("priority", 0)), + "enabled": str(bind_config.get("enabled", True)), + "algo_name": bind_config.get("algo_name", ""), + "target_class": bind_config.get("target_class", "person"), + }) + + self._redis_client.expire(key, 3600) + logger.debug(f"ROI算法绑定配置已缓存到Redis: {key}") + return True + except Exception as e: + logger.error(f"缓存ROI算法绑定配置到Redis失败: {e}") + return False + def _cache_camera_to_redis(self, camera_config: Dict[str, Any]) -> bool: """将摄像头配置缓存到Redis""" if not self._redis_client: @@ -454,6 +529,12 @@ class ConfigSyncManager: if self._cache_roi_to_redis(roi): count += 1 + self.clear_redis_cache("bind") + bindings = self._db_manager.get_bindings_by_camera("") + for bind in bindings: + if self._cache_algo_bind_to_redis(bind): + count += 1 + logger.info(f"已同步 {count} 条配置到Redis缓存") return count except Exception as e: @@ -469,13 +550,8 @@ class ConfigSyncManager: key = f"config:roi:{roi_id}" data = self._redis_client.hgetall(key) if data: - if data.get('coordinates'): - data['coordinates'] = eval(data['coordinates']) - if data.get('working_hours'): - data['working_hours'] = eval(data['working_hours']) - data['confirm_on_duty_sec'] = int(data.get('confirm_on_duty_sec', 10)) - data['confirm_leave_sec'] = int(data.get('confirm_leave_sec', 10)) - data['cooldown_sec'] = int(data.get('cooldown_sec', 300)) + data['coordinates'] = eval(data['coordinates']) if data.get('coordinates') else [] + data['priority'] = int(data.get('priority', 0)) data['enabled'] = data.get('enabled', 'True') == 'True' return data return None @@ -483,6 +559,47 @@ class ConfigSyncManager: logger.error(f"从Redis获取ROI配置失败: {e}") return None + def get_algo_bind_from_redis(self, bind_id: str) -> Optional[Dict[str, Any]]: + """从Redis获取ROI算法绑定配置""" + if not self._redis_client: + return None + + try: + key = f"config:bind:{bind_id}" + data = self._redis_client.hgetall(key) + if data: + data['params'] = eval(data['params']) if data.get('params') else {} + data['priority'] = int(data.get('priority', 0)) + data['enabled'] = data.get('enabled', 'True') == 'True' + data['target_class'] = data.get('target_class', 'person') + return data + return None + except Exception as e: + logger.error(f"从Redis获取ROI算法绑定配置失败: {e}") + return None + + def get_bindings_from_redis(self, roi_id: str) -> List[Dict[str, Any]]: + """从Redis获取ROI的所有算法绑定""" + if not self._redis_client: + return [] + + try: + pattern = "config:bind:*" + keys = self._redis_client.keys(pattern) + results = [] + for key in keys: + data = self._redis_client.hgetall(key) + if data and data.get('roi_id') == roi_id: + data['params'] = eval(data['params']) if data.get('params') else {} + data['priority'] = int(data.get('priority', 0)) + data['enabled'] = data.get('enabled', 'True') == 'True' + data['target_class'] = data.get('target_class', 'person') + results.append(data) + return sorted(results, key=lambda x: x['priority'], reverse=True) + except Exception as e: + logger.error(f"从Redis获取ROI算法绑定列表失败: {e}") + return [] + def get_camera_from_redis(self, camera_id: str) -> Optional[Dict[str, Any]]: """从Redis获取摄像头配置""" if not self._redis_client: @@ -529,6 +646,10 @@ class ConfigSyncManager: keys = self._redis_client.keys("config:camera:*") if keys: self._redis_client.delete(*keys) + elif config_type == "bind": + keys = self._redis_client.keys("config:bind:*") + if keys: + self._redis_client.delete(*keys) else: keys = self._redis_client.keys("config:*") if keys: @@ -543,6 +664,7 @@ class ConfigSyncManager: self.clear_redis_cache() count = self.sync_all_to_redis() self.notify_config_change("roi", []) + self.notify_config_change("bind", []) logger.info(f"算法配置已重新加载,更新了 {count} 条缓存") def close(self): diff --git a/main.py b/main.py index 7e5f24f..07abed2 100644 --- a/main.py +++ b/main.py @@ -182,19 +182,26 @@ class EdgeInferenceService: def _process_frame(self, camera_id: str, frame: VideoFrame): """处理视频帧 - 批量处理多 ROI""" try: - roi_configs = self._config_manager.get_roi_configs(camera_id) + roi_configs = self._config_manager.get_roi_configs_with_bindings(camera_id) roi_items = [] for roi in roi_configs: if not roi.enabled: continue - try: - cropped, scale_info = self._preprocessor.preprocess_single( - frame.image, roi - ) - roi_items.append((camera_id, roi, frame, cropped, scale_info)) - except Exception as e: - self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}") + if not roi.bindings: + continue + + for bind in roi.bindings: + if not bind.enabled: + continue + + try: + cropped, scale_info = self._preprocessor.preprocess_single( + frame.image, roi + ) + roi_items.append((camera_id, roi, bind, frame, cropped, scale_info)) + except Exception as e: + self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}") if not roi_items: return @@ -222,8 +229,8 @@ class EdgeInferenceService: batch_size = len(roi_items) try: - images = [item[3] for item in roi_items] - scale_infos = [item[4] for item in roi_items] + images = [item[4] for item in roi_items] + scale_infos = [item[5] for item in roi_items] batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch( images @@ -241,12 +248,12 @@ class EdgeInferenceService: conf_threshold=self._settings.inference.conf_threshold ) - for idx, (camera_id, roi, frame, _, scale_info) in enumerate(roi_items): + for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(roi_items): boxes, scores, class_ids = batch_results[idx] if len(boxes) > 0: self._handle_detections( - camera_id, roi, frame, + camera_id, roi, bind, frame, boxes, scores, class_ids, scale_info ) @@ -287,6 +294,7 @@ class EdgeInferenceService: self, camera_id: str, roi, + bind, frame: VideoFrame, boxes: any, scores: any, @@ -300,19 +308,14 @@ class EdgeInferenceService: return roi_id = roi.roi_id - algorithm_type = roi.algorithm_type - algo_type_str = algorithm_type.value if hasattr(algorithm_type, 'value') else algorithm_type + algo_code = bind.algo_code + algo_params = bind.params or {} self._algorithm_manager.register_algorithm( roi_id=roi_id, - algorithm_type=algo_type_str, - params={ - "working_hours": roi.working_hours, - "confirm_on_duty_sec": roi.confirm_on_duty_sec, - "confirm_leave_sec": roi.confirm_leave_sec, - "cooldown_sec": roi.cooldown_sec, - "target_class": roi.target_class, - } + bind_id=bind.bind_id, + algorithm_type=algo_code, + params=algo_params ) tracks = self._build_tracks(roi, boxes, scores, class_ids, scale_info) @@ -322,8 +325,9 @@ class EdgeInferenceService: alerts = self._algorithm_manager.process( roi_id=roi_id, + bind_id=bind.bind_id, camera_id=camera_id, - algorithm_type=algo_type_str, + algorithm_type=algo_code, tracks=tracks, current_time=frame.timestamp ) @@ -333,11 +337,12 @@ class EdgeInferenceService: from core.result_reporter import AlertInfo alert_info = AlertInfo( - alert_id=f"{roi_id}_{int(frame.timestamp.timestamp())}", + alert_id=f"{roi_id}_{bind.bind_id}_{int(frame.timestamp.timestamp())}", camera_id=camera_id, roi_id=roi_id, + bind_id=bind.bind_id, alert_type=alert.get("alert_type", "detection"), - target_class=alert.get("class", "unknown"), + target_class=alert.get("class", bind.target_class or "unknown"), confidence=alert.get("confidence", 1.0), bbox=alert.get("bbox", []), message=alert.get("message", ""),