""" 主入口模块 整合所有模块,启动推理服务 """ import logging import os import sys import threading import signal import time from datetime import datetime from typing import Dict, Any, Optional, List, Tuple from config.settings import get_settings, Settings from core.config_sync import get_config_sync_manager, ConfigSyncManager from core.video_stream import MultiStreamManager, VideoFrame from core.preprocessor import ImagePreprocessor from core.tensorrt_engine import TensorRTEngine, EngineManager from core.postprocessor import PostProcessor from core.result_reporter import ResultReporter from algorithms import AlgorithmManager from utils.logger import get_logger, StructuredLogger from utils.version_control import get_version_control logger = logging.getLogger(__name__) class EdgeInferenceService: """边缘推理服务主类 整合所有模块,提供完整的推理服务 """ def __init__(self): self._running = False self._settings = get_settings() self._logger = get_logger("main") self._version_control = get_version_control() self._db_manager = None self._config_manager: Optional[ConfigSyncManager] = None self._stream_manager: Optional[MultiStreamManager] = None self._preprocessor: Optional[ImagePreprocessor] = None self._engine_manager: Optional[EngineManager] = None self._postprocessor: Optional[PostProcessor] = None self._reporter: Optional[ResultReporter] = None self._algorithm_manager: Optional[AlgorithmManager] = None self._processing_threads: Dict[str, threading.Thread] = {} self._stop_event = threading.Event() self._performance_stats = { "start_time": None, "total_frames_processed": 0, "total_alerts_generated": 0, "uptime_seconds": 0, } self._batch_roi_queue: List[tuple] = [] self._batch_lock = threading.Lock() self._batch_event = threading.Event() self._inference_thread: Optional[threading.Thread] = None self._max_batch_size = 8 self._batch_timeout_sec = 0.05 # 50ms 攒批窗口 self._logger.info("Edge_Inference_Service 初始化开始") def _init_database(self): """初始化数据库""" try: from config.database import SQLiteManager self._db_manager = SQLiteManager() if self._db_manager._conn: self._logger.info("数据库初始化成功") else: self._logger.warning("数据库不可见,服务将在无数据库模式下运行") except Exception as e: self._logger.warning(f"数据库初始化失败,服务将在无数据库模式下运行: {e}") self._db_manager = None def _init_config_manager(self): """初始化配置管理器""" try: self._config_manager = get_config_sync_manager() self._config_manager.start_config_subscription() self._logger.info("配置管理器初始化成功") except Exception as e: self._logger.error(f"配置管理器初始化失败: {e}") raise def _init_stream_manager(self): """初始化流管理器""" self._stream_manager = MultiStreamManager() self._logger.info("流管理器初始化成功") def _init_preprocessor(self): """初始化预处理器""" self._preprocessor = ImagePreprocessor() self._logger.info("预处理器初始化成功") def _init_engine(self): """初始化推理引擎""" try: self._engine_manager = EngineManager() engine_path = self._settings.inference.model_path if os.path.exists(engine_path): self._engine_manager.load_engine("default", engine_path) self._logger.info(f"推理引擎加载成功: {engine_path}") else: self._logger.warning(f"引擎文件不存在: {engine_path}") except Exception as e: self._logger.error(f"推理引擎初始化失败: {e}") raise def _init_postprocessor(self): """初始化后处理器""" self._postprocessor = PostProcessor() self._logger.info("后处理器初始化成功") def _init_reporter(self): """初始化结果上报器""" self._reporter = ResultReporter() self._logger.info("结果上报器初始化成功") def _init_algorithm_manager(self): """初始化算法管理器""" try: self._algorithm_manager = AlgorithmManager() self._algorithm_manager.start_config_subscription() self._logger.info("算法管理器初始化成功") except Exception as e: self._logger.error(f"算法管理器初始化失败: {e}") def initialize(self): """初始化所有组件""" self._logger.info("=" * 50) self._logger.info("Edge_Inference_Service 启动") self._logger.info("=" * 50) self._init_database() self._init_config_manager() self._init_stream_manager() self._init_preprocessor() self._init_engine() self._init_postprocessor() self._init_reporter() self._init_algorithm_manager() self._performance_stats["start_time"] = datetime.now() self._version_control.record_update( version="1.0.0", update_type="启动", description="Edge_Inference_Service 启动运行", updated_by="系统", affected_items=["全局"], ) self._logger.info("所有组件初始化完成") def _load_cameras(self): """加载摄像头配置""" cameras = self._config_manager.get_cameras() for camera in cameras: try: self._stream_manager.add_stream( camera_id=camera.camera_id, rtsp_url=camera.rtsp_url, target_fps=self._settings.video_stream.default_fps, on_frame_callback=self._create_frame_callback(camera.camera_id) ) self._logger.info(f"已添加摄像头: {camera.camera_id}") except Exception as e: self._logger.error(f"添加摄像头失败 {camera.camera_id}: {e}") def _create_frame_callback(self, camera_id: str): """创建帧处理回调""" def callback(frame): self._process_frame(camera_id, frame) return callback def _process_frame(self, camera_id: str, frame: VideoFrame): """处理视频帧 - 批量处理多 ROI""" try: roi_configs = self._config_manager.get_roi_configs_with_bindings(camera_id) roi_items = [] for roi in roi_configs: if not roi.enabled: continue if not roi.bindings: continue for bind in roi.bindings: if not bind.enabled: continue try: cropped, scale_info = self._preprocessor.preprocess_single( frame.image, roi ) roi_items.append((camera_id, roi, bind, frame, cropped, scale_info)) except Exception as e: self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}") if not roi_items: return with self._batch_lock: self._batch_roi_queue.extend(roi_items) # 通知推理线程有新数据 self._batch_event.set() self._performance_stats["total_frames_processed"] += 1 except Exception as e: self._logger.error(f"处理帧失败 {camera_id}: {e}") def _batch_process_rois(self): """批量处理 ROI - 真正的 batch 推理""" with self._batch_lock: roi_items = self._batch_roi_queue if not roi_items: return self._batch_roi_queue = [] try: images = [item[4] for item in roi_items] scale_infos = [item[5] for item in roi_items] # 真正的 batch: 将所有 ROI 裁剪图拼成 [N,3,H,W] 一次推理 batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch( images ) engine = self._engine_manager.get_engine("default") if engine is None: return # 一次性推理整个 batch outputs, inference_time_ms = engine.infer(batch_data) batch_size = len(roi_items) batch_results = self._postprocessor.batch_process_detections( outputs, batch_size, conf_threshold=self._settings.inference.conf_threshold ) for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(roi_items): boxes, scores, class_ids = batch_results[idx] if len(boxes) > 0: self._handle_detections( camera_id, roi, bind, frame, boxes, scores, class_ids, scale_info ) except Exception as e: self._logger.error(f"批量处理 ROI 失败: {e}") def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi): """收集 ROI 帧数据 - 已集成到 _process_frame 中""" pass def _build_tracks( self, roi, boxes: any, scores: any, class_ids: any, scale_info: tuple ) -> list: """将检测结果转换为算法所需的 tracks 格式 坐标从 letterbox 空间还原到 ROI 裁剪空间 """ tracks = [] class_names = getattr(self._settings, 'class_names', ['person']) # 将 letterbox 坐标还原到 ROI 裁剪空间 reverted_boxes = self._preprocessor.revert_boxes( [box.tolist() if hasattr(box, 'tolist') else list(box) for box in boxes], scale_info ) for i, box in enumerate(reverted_boxes): class_id = int(class_ids[i]) if class_ids[i] else 0 track = { "track_id": f"{roi.roi_id}_{i}", "class": class_names[class_id] if class_id < len(class_names) else f"class_{class_id}", "confidence": float(scores[i]), "bbox": box, "matched_rois": [{"roi_id": roi.roi_id}], } tracks.append(track) return tracks def _handle_detections( self, camera_id: str, roi, bind, frame: VideoFrame, boxes: any, scores: any, class_ids: any, scale_info: tuple ): """处理检测结果 - 算法接管判断权""" try: if self._algorithm_manager is None: self._logger.warning("算法管理器不可用,跳过算法处理") return roi_id = roi.roi_id algo_code = bind.algo_code algo_params = bind.params or {} self._algorithm_manager.register_algorithm( roi_id=roi_id, bind_id=bind.bind_id, algorithm_type=algo_code, params=algo_params ) tracks = self._build_tracks(roi, boxes, scores, class_ids, scale_info) if not tracks: return alerts = self._algorithm_manager.process( roi_id=roi_id, bind_id=bind.bind_id, camera_id=camera_id, algorithm_type=algo_code, tracks=tracks, current_time=frame.timestamp ) for alert in alerts: self._performance_stats["total_alerts_generated"] += 1 from core.result_reporter import AlertInfo alert_info = AlertInfo( alert_id=f"{roi_id}_{bind.bind_id}_{int(frame.timestamp.timestamp())}", camera_id=camera_id, roi_id=roi_id, bind_id=bind.bind_id, alert_type=alert.get("alert_type", "detection"), target_class=alert.get("class", bind.target_class or "unknown"), confidence=alert.get("confidence", 1.0), bbox=alert.get("bbox", []), message=alert.get("message", ""), timestamp=frame.timestamp, duration_minutes=alert.get("duration_minutes"), ) self._reporter.report_alert(alert_info, screenshot=frame.image) self._logger.info( f"告警已生成: type={alert.get('alert_type', 'detection')}, " f"camera={camera_id}, roi={roi_id}, " f"confidence={alert.get('confidence', 1.0)}" ) except Exception as e: self._logger.error(f"处理检测结果失败: {e}") def _inference_worker(self): """推理线程:攒批窗口内收集 ROI 请求,批量推理""" while not self._stop_event.is_set(): # 等待有新数据到达或超时 self._batch_event.wait(timeout=self._batch_timeout_sec) self._batch_event.clear() with self._batch_lock: queue_size = len(self._batch_roi_queue) # 攒批窗口:等到攒够 max_batch 或超时后处理 if queue_size > 0 and queue_size < self._max_batch_size: # 再等一小段时间凑更多 self._batch_event.wait(timeout=self._batch_timeout_sec) self._batch_event.clear() with self._batch_lock: if not self._batch_roi_queue: continue self._batch_process_rois() def start(self): """启动服务""" if self._running: return self._running = True self._stop_event.clear() self._load_cameras() # 启动独立推理线程(生产者-消费者模式) self._inference_thread = threading.Thread( target=self._inference_worker, name="InferenceWorker", daemon=True ) self._inference_thread.start() self._logger.info("推理线程已启动") self._stream_manager.start_all() self._logger.info("Edge_Inference_Service 已启动") self._start_heartbeat_thread() self._register_signal_handlers() self._wait_for_shutdown() def _start_heartbeat_thread(self): """启动心跳线程""" def heartbeat(): while not self._stop_event.is_set(): try: uptime = (datetime.now() - self._performance_stats["start_time"]).total_seconds() self._performance_stats["uptime_seconds"] = uptime status = { "running": True, "uptime_seconds": uptime, "frames_processed": self._performance_stats["total_frames_processed"], "alerts_generated": self._performance_stats["total_alerts_generated"], "stream_stats": self._stream_manager.get_statistics() if self._stream_manager else {}, } if self._reporter: self._reporter.report_heartbeat("edge_inference_device", status) except Exception as e: self._logger.error(f"心跳上报失败: {e}") time.sleep(30) thread = threading.Thread(target=heartbeat, name="Heartbeat", daemon=True) thread.start() def _register_signal_handlers(self): """注册信号处理器""" def handle_signal(signum, frame): self._logger.info(f"收到信号 {signum}, 正在停止服务...") self.stop() signal.signal(signal.SIGINT, handle_signal) signal.signal(signal.SIGTERM, handle_signal) def _wait_for_shutdown(self): """等待关闭信号""" while not self._stop_event.is_set(): time.sleep(1) def stop(self): """停止服务""" if not self._running: return self._running = False self._stop_event.set() self._batch_event.set() # 唤醒推理线程以退出 if self._inference_thread and self._inference_thread.is_alive(): self._inference_thread.join(timeout=5) if self._stream_manager: self._stream_manager.stop_all() self._stream_manager.close() if self._engine_manager: self._engine_manager.release_all() if self._config_manager: self._config_manager.stop_config_subscription() self._config_manager.close() if self._algorithm_manager: self._algorithm_manager.stop_config_subscription() if self._reporter: self._reporter.close() self._performance_stats["uptime_seconds"] = ( (datetime.now() - self._performance_stats["start_time"]).total_seconds() ) self._logger.info("Edge_Inference_Service 已停止") self._logger.info(f"运行统计: {self._performance_stats}") def get_status(self) -> Dict[str, Any]: """获取服务状态""" return { "running": self._running, "start_time": ( self._performance_stats["start_time"].isoformat() if self._performance_stats["start_time"] else None ), "uptime_seconds": self._performance_stats["uptime_seconds"], "total_frames_processed": self._performance_stats["total_frames_processed"], "total_alerts_generated": self._performance_stats["total_alerts_generated"], "stream_manager": ( self._stream_manager.get_statistics() if self._stream_manager else {} ), "config_version": ( self._config_manager.config_version if self._config_manager else None ), } def main(): """主函数入口""" service = EdgeInferenceService() try: service.initialize() service.start() except KeyboardInterrupt: service.stop() except Exception as e: logger.error(f"服务异常: {e}") raise if __name__ == "__main__": main()