diff --git a/core/postprocessor.py b/core/postprocessor.py
index 8b76ccd..e7e5d89 100644
--- a/core/postprocessor.py
+++ b/core/postprocessor.py
@@ -701,6 +701,98 @@ class PostProcessor:
             self._buffer_class_ids[:valid_count].copy()
         )
 
+    def batch_process_detections(
+        self,
+        batch_outputs: List[np.ndarray],
+        batch_size: int,
+        conf_threshold: Optional[float] = None,
+        nms_threshold: Optional[float] = None
+    ) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+        """
+        批量处理多 ROI 检测结果
+
+        Args:
+            batch_outputs: 模型输出 (可能是 [output] 或单个 batch 数组)
+            batch_size: batch 大小
+            conf_threshold: 置信度阈值
+            nms_threshold: NMS阈值
+
+        Returns:
+            List of (检测框, 置信度, 类别ID) per ROI
+        """
+        if conf_threshold is None:
+            conf_threshold = self.conf_threshold
+        if nms_threshold is None:
+            nms_threshold = self.nms_threshold
+
+        if len(batch_outputs) == 1:
+            first_output = batch_outputs[0]
+            if isinstance(first_output, np.ndarray) and first_output.ndim == 3:
+                # The old shape[0] == batch_size check assigned the same value
+                # on both branches (dead code); any batch-dim mismatch is
+                # handled by the per-index fallback in the loop below.
+                outputs_array = first_output
+            else:
+                outputs_array = np.array([first_output])
+        else:
+            outputs_array = np.array(batch_outputs)
+
+        results = []
+        nms_processor = NMSProcessor(nms_threshold, use_gpu=True)
+
+        for batch_idx in range(batch_size):
+            if outputs_array.ndim == 3:
+                if outputs_array.shape[0] == batch_size:
+                    output = outputs_array[batch_idx]
+                else:
+                    output = outputs_array[0]
+            elif outputs_array.ndim == 2:
+                output = outputs_array
+            else:
+                if batch_idx < len(outputs_array):
+                    output = outputs_array[batch_idx]
+                else:
+                    output = outputs_array[0]
+
+            if output.ndim == 3:
+                output = output[0]
+
+            if output.shape[0] != 84:
+                results.append((np.array([]), np.array([]), np.array([])))
+                continue
+
+            boxes_xywh = output[0:4, :].T
+            # 84-row head = 4 box + 80 class scores (no objectness row): row 4
+            # is already the person confidence; old row4*row5 mixed two classes.
+            scores = output[4, :]
+
+            coarse_mask = scores > 0.3
+            if not np.any(coarse_mask):
+                results.append((np.array([]), np.array([]), np.array([])))
+                continue
+
+            boxes = boxes_xywh[coarse_mask]
+            scores_coarse = scores[coarse_mask]
+            class_ids = np.zeros(len(boxes), dtype=np.int32)
+            # NOTE(review): assumes len(boxes) <= _buffer_xyxy capacity — confirm
+            valid_count = len(boxes)
+
+            self._buffer_xyxy[:valid_count, 0] = boxes[:, 0] - boxes[:, 2] / 2
+            self._buffer_xyxy[:valid_count, 1] = boxes[:, 1] - boxes[:, 3] / 2
+            self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + boxes[:, 2] / 2
+            self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + boxes[:, 3] / 2
+
+            keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
+                self._buffer_xyxy[:valid_count],
+                scores_coarse.astype(np.float32),
+                class_ids,
+                conf_threshold
+            )
+
+            results.append((keep_boxes, keep_scores, keep_classes))
+
+        return results
+
     def filter_by_roi(
         self,
         boxes: np.ndarray,
diff --git a/main.py b/main.py
index 3c7b270..7e5f24f 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@ import threading
 import signal
 import time
 from datetime import datetime
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, List, Tuple
 
 from config.settings import get_settings, Settings
 from core.config_sync import get_config_sync_manager, ConfigSyncManager
@@ -57,6 +57,9 @@ class EdgeInferenceService:
             "uptime_seconds": 0,
         }
 
+        self._batch_roi_queue: List[tuple] = []
+        self._batch_lock = threading.Lock()
+
         self._logger.info("Edge_Inference_Service 初始化开始")
 
     def _init_database(self):
@@ -177,46 +180,53 @@ class EdgeInferenceService:
         return callback
 
     def _process_frame(self, camera_id: str, frame: VideoFrame):
-        """处理视频帧"""
+        """处理视频帧 - 批量处理多 ROI"""
         try:
-            start_time = time.perf_counter()
-
             roi_configs = self._config_manager.get_roi_configs(camera_id)
 
+            roi_items = []
             for roi in roi_configs:
-                self._process_roi_frame(camera_id, frame, roi)
+                if not roi.enabled:
+                    continue
+                try:
+                    cropped, scale_info = self._preprocessor.preprocess_single(
+                        frame.image, roi
+                    )
+                    roi_items.append((camera_id, roi, frame, cropped, scale_info))
+                except Exception as e:
+                    self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}")
 
-            processing_time_ms = (time.perf_counter() - start_time) * 1000
+            if not roi_items:
+                return
+
+            with self._batch_lock:
+                self._batch_roi_queue.extend(roi_items)
+
+            # Drain immediately; old `batch_size >= 1` gate was a tautology and
+            # read the queue length outside the lock (racy). Same behaviour.
+            self._batch_process_rois()
 
             self._performance_stats["total_frames_processed"] += 1
 
-            self._frame_counter = getattr(self, '_frame_counter', 0) + 1
-            if self._frame_counter % 10 == 0:
-                self._logger.log_inference_latency(
-                    processing_time_ms,
-                    batch_size=1
-                )
-
         except Exception as e:
             self._logger.error(f"处理帧失败 {camera_id}: {e}")
 
-    def _process_roi_frame(
-        self,
-        camera_id: str,
-        frame: VideoFrame,
-        roi
-    ):
-        """处理ROI帧,batch=1 推理"""
-        try:
-            if not roi.enabled:
+    def _batch_process_rois(self):
+        """批量处理 ROI - 预处理、推理、后处理"""
+        with self._batch_lock:
+            roi_items = self._batch_roi_queue
+            if not roi_items:
                 return
-
-            cropped = self._preprocessor.preprocess_single(frame.image, roi)
-
-            processed_image, scale_info = cropped
+            self._batch_roi_queue = []
+
+        batch_size = len(roi_items)
+
+        try:
+            images = [item[3] for item in roi_items]
+            # (per-item scale_info is unpacked from roi_items in the loop below)
 
             batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
-                [processed_image]
+                images
             )
 
             engine = self._engine_manager.get_engine("default")
@@ -225,20 +235,28 @@ class EdgeInferenceService:
 
             outputs, inference_time_ms = engine.infer(batch_data)
 
-            boxes, scores, class_ids = self._postprocessor.process_detections(
+            batch_results = self._postprocessor.batch_process_detections(
                 outputs,
+                batch_size,
                 conf_threshold=self._settings.inference.conf_threshold
            )
 
-            if len(boxes) > 0:
-                self._handle_detections(
-                    camera_id, roi, frame,
-                    boxes, scores, class_ids,
-                    scale_info
-                )
-
+            for idx, (camera_id, roi, frame, _, scale_info) in enumerate(roi_items):
+                boxes, scores, class_ids = batch_results[idx]
+
+                if len(boxes) > 0:
+                    self._handle_detections(
+                        camera_id, roi, frame,
+                        boxes, scores, class_ids,
+                        scale_info
+                    )
+
         except Exception as e:
-            self._logger.error(f"处理ROI帧失败: {e}")
+            self._logger.error(f"批量处理 ROI 失败: {e}")
+
+    def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
+        """收集 ROI 帧数据 - 已集成到 _process_frame 中"""
+        pass
 
     def _build_tracks(
         self,