From fdd3e9ac883c281597bd13b59db62b1e4aeb46ea Mon Sep 17 00:00:00 2001 From: 16337 <1633794139@qq.com> Date: Fri, 30 Jan 2026 14:24:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=AE=97=E6=B3=95=E5=B1=82=E6=8E=A5?= =?UTF-8?q?=E7=AE=A1=E5=88=A4=E6=96=AD=E6=9D=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加 AlgorithmManager 集成 - 新增 _build_tracks() 将检测转为算法输入格式 - 重构 _handle_detections() 让算法决定是否告警 - 注册算法时从 ROI 配置读取参数 - 停止使用 PostProcessor.check_alarm_condition() 做业务判断 新的执行链路: ROI裁剪 480p预处理 YOLO推理 algorithm.process() 告警 --- main.py | 136 +++++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 96 insertions(+), 40 deletions(-) diff --git a/main.py b/main.py index fdc821b..02ef3b0 100644 --- a/main.py +++ b/main.py @@ -20,6 +20,7 @@ from core.preprocessor import ImagePreprocessor from core.tensorrt_engine import TensorRTEngine, EngineManager from core.postprocessor import PostProcessor from core.result_reporter import ResultReporter +from algorithms import AlgorithmManager from utils.logger import get_logger, StructuredLogger from utils.version_control import get_version_control @@ -45,6 +46,7 @@ class EdgeInferenceService: self._engine_manager: Optional[EngineManager] = None self._postprocessor: Optional[PostProcessor] = None self._reporter: Optional[ResultReporter] = None + self._algorithm_manager: Optional[AlgorithmManager] = None self._processing_threads: Dict[str, threading.Thread] = {} self._stop_event = threading.Event() @@ -117,6 +119,15 @@ class EdgeInferenceService: self._reporter = ResultReporter() self._logger.info("结果上报器初始化成功") + def _init_algorithm_manager(self): + """初始化算法管理器""" + try: + self._algorithm_manager = AlgorithmManager() + self._algorithm_manager.start_config_subscription() + self._logger.info("算法管理器初始化成功") + except Exception as e: + self._logger.error(f"算法管理器初始化失败: {e}") + def initialize(self): """初始化所有组件""" self._logger.info("=" * 50) @@ -130,6 +141,7 @@ class EdgeInferenceService: self._init_engine() self._init_postprocessor() self._init_reporter() + self._init_algorithm_manager() self._performance_stats["start_time"] = datetime.now() @@ -195,6 +207,9 @@ class EdgeInferenceService: ): """处理ROI帧""" try: + if not roi.enabled: + return + cropped = self._preprocessor.preprocess_single(frame.image, roi) processed_image, scale_info = cropped @@ -216,7 +231,7 @@ class EdgeInferenceService: if len(boxes) > 0: self._handle_detections( - camera_id, roi.roi_id, frame, + camera_id, roi, frame, boxes, scores, class_ids, scale_info ) @@ -224,61 +239,99 @@ class EdgeInferenceService: except Exception as e: self._logger.error(f"处理ROI帧失败: {e}") + def _build_tracks( + self, + roi, + boxes: any, + scores: any, + class_ids: any, + scale_info: tuple + ) -> list: + """将检测结果转换为算法所需的 tracks 格式""" + tracks = [] + class_names = getattr(self._settings, 'class_names', ['person']) + + for i, box in enumerate(boxes): + class_id = int(class_ids[i]) if class_ids[i] else 0 + track = { + "track_id": f"{roi.roi_id}_{i}", + "class": class_names[class_id] if class_id < len(class_names) else f"class_{class_id}", + "confidence": float(scores[i]), + "bbox": box, + "matched_rois": [{"roi_id": roi.roi_id}], + } + tracks.append(track) + + return tracks + def _handle_detections( self, camera_id: str, - roi_id: str, + roi, frame: VideoFrame, boxes: any, scores: any, class_ids: any, scale_info: tuple ): - """处理检测结果""" + """处理检测结果 - 算法接管判断权""" try: - from config.config_models import AlgorithmType + if self._algorithm_manager is None: + self._logger.warning("算法管理器不可用,跳过算法处理") + return - mapped_boxes = self._postprocessor.map_coordinates( - boxes, scale_info, - (frame.width, frame.height) + roi_id = roi.roi_id + algorithm_type = roi.algorithm_type + algo_type_str = algorithm_type.value if hasattr(algorithm_type, 'value') else algorithm_type + + self._algorithm_manager.register_algorithm( + roi_id=roi_id, + algorithm_type=algo_type_str, + params={ + "working_hours": roi.working_hours, + "confirm_on_duty_sec": roi.confirm_on_duty_sec, + "confirm_leave_sec": roi.confirm_leave_sec, + "cooldown_sec": roi.cooldown_sec, + "target_class": roi.target_class, + } ) - for i, box in enumerate(mapped_boxes): - detection_result = { - "class_id": int(class_ids[i]) if len(class_ids) > 0 else 0, - "confidence": float(scores[i]), - "bbox": box, - } + tracks = self._build_tracks(roi, boxes, scores, class_ids, scale_info) + + if not tracks: + return + + alerts = self._algorithm_manager.process( + roi_id=roi_id, + camera_id=camera_id, + algorithm_type=algo_type_str, + tracks=tracks, + current_time=frame.timestamp + ) + + for alert in alerts: + self._performance_stats["total_alerts_generated"] += 1 - alert_result = self._postprocessor.check_alarm_condition( - roi_id, True, frame.timestamp + self._reporter.report_detection_alert( + camera_id=camera_id, + roi_id=roi_id, + alert_type=alert.get("alert_type", "detection"), + detection={ + "class_name": alert.get("message", "未知目标"), + "confidence": alert.get("confidence", 1.0), + "bbox": alert.get("bbox", []), + "message": alert.get("message", ""), + }, + screenshot=frame.image + ) + + self._logger.log_alert( + alert.get("alert_type", "detection"), + camera_id, + roi_id, + alert.get("confidence", 1.0) ) - if alert_result["should_alert"]: - self._performance_stats["total_alerts_generated"] += 1 - - screenshot = frame.image - - self._reporter.report_detection_alert( - camera_id=camera_id, - roi_id=roi_id, - alert_type="detection", - detection={ - "class_name": f"class_{detection_result['class_id']}", - "confidence": detection_result["confidence"], - "bbox": detection_result["bbox"], - "message": f"检测到目标" - }, - screenshot=screenshot - ) - - self._logger.log_alert( - "detection", - camera_id, - roi_id, - detection_result["confidence"] - ) - except Exception as e: self._logger.error(f"处理检测结果失败: {e}") @@ -362,6 +415,9 @@ class EdgeInferenceService: self._config_manager.stop_config_subscription() self._config_manager.close() + if self._algorithm_manager: + self._algorithm_manager.stop_config_subscription() + if self._reporter: self._reporter.close()