diff --git a/main.py b/main.py
index ff8edaa..149f1e3 100644
--- a/main.py
+++ b/main.py
@@ -562,59 +562,63 @@ class EdgeInferenceService:
             self._logger.error(f"处理帧失败 {camera_id}: {e}")
 
     def _batch_process_rois(self):
-        """批量处理 ROI - 真正的 batch 推理"""
+        """批量处理 ROI - 真正的 batch 推理(按 max_batch_size 分块)"""
         with self._batch_lock:
             roi_items = self._batch_roi_queue
             if not roi_items:
                 return
             self._batch_roi_queue = []
 
-        try:
-            images = [item[4] for item in roi_items]
-            scale_infos = [item[5] for item in roi_items]
+        engine = self._engine_manager.get_engine("default")
+        if engine is None:
+            return
 
-            # 真正的 batch: 将所有 ROI 裁剪图拼成 [N,3,H,W] 一次推理
-            batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
-                images
-            )
+        # 按 max_batch_size 分块处理
+        for chunk_start in range(0, len(roi_items), self._max_batch_size):
+            chunk = roi_items[chunk_start:chunk_start + self._max_batch_size]
 
-            engine = self._engine_manager.get_engine("default")
-            if engine is None:
-                return
+            try:
+                images = [item[4] for item in chunk]
+                scale_infos = [item[5] for item in chunk]
 
-            # 一次性推理整个 batch
-            outputs, inference_time_ms = engine.infer(batch_data)
-
-            # 诊断:输出原始推理结果形状(非告警诊断日志,使用 DEBUG 级别)
-            import numpy as np
-            if isinstance(outputs, np.ndarray):
-                self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, output shape={outputs.shape}, 耗时={inference_time_ms:.1f}ms")
-            elif isinstance(outputs, (list, tuple)):
-                shapes = [o.shape if hasattr(o, 'shape') else type(o) for o in outputs]
-                self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
-
-            batch_size = len(roi_items)
-            batch_results = self._postprocessor.batch_process_detections(
-                outputs,
-                batch_size,
-                conf_threshold=self._settings.inference.conf_threshold
-            )
-
-            total_detections = sum(len(r[0]) for r in batch_results)
-            self._logger.debug(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
-
-            for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(roi_items):
-                boxes, scores, class_ids = batch_results[idx]
-
-                # 无论是否检测到目标都要调用算法(离岗检测需要"无人"信号)
-                self._handle_detections(
-                    camera_id, roi, bind, frame,
-                    boxes, scores, class_ids,
-                    scale_info
+                # 真正的 batch: 将所有 ROI 裁剪图拼成 [N,3,H,W] 一次推理
+                batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
+                    images
                 )
-        except Exception as e:
-            self._logger.error(f"批量处理 ROI 失败: {e}")
+                # 一次性推理整个 batch
+                outputs, inference_time_ms = engine.infer(batch_data)
+
+                # 诊断:输出原始推理结果形状(非告警诊断日志,使用 DEBUG 级别)
+                import numpy as np
+                if isinstance(outputs, np.ndarray):
+                    self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, output shape={outputs.shape}, 耗时={inference_time_ms:.1f}ms")
+                elif isinstance(outputs, (list, tuple)):
+                    shapes = [o.shape if hasattr(o, 'shape') else type(o) for o in outputs]
+                    self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
+
+                batch_size = len(chunk)
+                batch_results = self._postprocessor.batch_process_detections(
+                    outputs,
+                    batch_size,
+                    conf_threshold=self._settings.inference.conf_threshold
+                )
+
+                total_detections = sum(len(r[0]) for r in batch_results)
+                self._logger.debug(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
+
+                for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(chunk):
+                    boxes, scores, class_ids = batch_results[idx]
+
+                    # 无论是否检测到目标都要调用算法(离岗检测需要"无人"信号)
+                    self._handle_detections(
+                        camera_id, roi, bind, frame,
+                        boxes, scores, class_ids,
+                        scale_info
+                    )
+
+            except Exception as e:
+                self._logger.error(f"批量处理 ROI 失败: {e}")
 
     def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
         """收集 ROI 帧数据 - 已集成到 _process_frame 中"""
 