fix(edge): 修复 batch 推理超过 MAX_BATCH_SIZE 导致缓冲区溢出
队列中 ROI 数量超过 8 时(多摄像头多 ROI 绑定场景), 一次性送入 TensorRT 引擎导致 np.copyto 溢出。 改为按 max_batch_size 分块推理。 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
88
main.py
88
main.py
@@ -562,59 +562,63 @@ class EdgeInferenceService:
|
|||||||
self._logger.error(f"处理帧失败 {camera_id}: {e}")
|
self._logger.error(f"处理帧失败 {camera_id}: {e}")
|
||||||
|
|
||||||
def _batch_process_rois(self):
|
def _batch_process_rois(self):
|
||||||
"""批量处理 ROI - 真正的 batch 推理"""
|
"""批量处理 ROI - 真正的 batch 推理(按 max_batch_size 分块)"""
|
||||||
with self._batch_lock:
|
with self._batch_lock:
|
||||||
roi_items = self._batch_roi_queue
|
roi_items = self._batch_roi_queue
|
||||||
if not roi_items:
|
if not roi_items:
|
||||||
return
|
return
|
||||||
self._batch_roi_queue = []
|
self._batch_roi_queue = []
|
||||||
|
|
||||||
try:
|
engine = self._engine_manager.get_engine("default")
|
||||||
images = [item[4] for item in roi_items]
|
if engine is None:
|
||||||
scale_infos = [item[5] for item in roi_items]
|
return
|
||||||
|
|
||||||
# 真正的 batch: 将所有 ROI 裁剪图拼成 [N,3,H,W] 一次推理
|
# 按 max_batch_size 分块处理
|
||||||
batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
|
for chunk_start in range(0, len(roi_items), self._max_batch_size):
|
||||||
images
|
chunk = roi_items[chunk_start:chunk_start + self._max_batch_size]
|
||||||
)
|
|
||||||
|
|
||||||
engine = self._engine_manager.get_engine("default")
|
try:
|
||||||
if engine is None:
|
images = [item[4] for item in chunk]
|
||||||
return
|
scale_infos = [item[5] for item in chunk]
|
||||||
|
|
||||||
# 一次性推理整个 batch
|
# 真正的 batch: 将所有 ROI 裁剪图拼成 [N,3,H,W] 一次推理
|
||||||
outputs, inference_time_ms = engine.infer(batch_data)
|
batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
|
||||||
|
images
|
||||||
# 诊断:输出原始推理结果形状(非告警诊断日志,使用 DEBUG 级别)
|
|
||||||
import numpy as np
|
|
||||||
if isinstance(outputs, np.ndarray):
|
|
||||||
self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, output shape={outputs.shape}, 耗时={inference_time_ms:.1f}ms")
|
|
||||||
elif isinstance(outputs, (list, tuple)):
|
|
||||||
shapes = [o.shape if hasattr(o, 'shape') else type(o) for o in outputs]
|
|
||||||
self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
|
|
||||||
|
|
||||||
batch_size = len(roi_items)
|
|
||||||
batch_results = self._postprocessor.batch_process_detections(
|
|
||||||
outputs,
|
|
||||||
batch_size,
|
|
||||||
conf_threshold=self._settings.inference.conf_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
total_detections = sum(len(r[0]) for r in batch_results)
|
|
||||||
self._logger.debug(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
|
|
||||||
|
|
||||||
for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(roi_items):
|
|
||||||
boxes, scores, class_ids = batch_results[idx]
|
|
||||||
|
|
||||||
# 无论是否检测到目标都要调用算法(离岗检测需要"无人"信号)
|
|
||||||
self._handle_detections(
|
|
||||||
camera_id, roi, bind, frame,
|
|
||||||
boxes, scores, class_ids,
|
|
||||||
scale_info
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
# 一次性推理整个 batch
|
||||||
self._logger.error(f"批量处理 ROI 失败: {e}")
|
outputs, inference_time_ms = engine.infer(batch_data)
|
||||||
|
|
||||||
|
# 诊断:输出原始推理结果形状(非告警诊断日志,使用 DEBUG 级别)
|
||||||
|
import numpy as np
|
||||||
|
if isinstance(outputs, np.ndarray):
|
||||||
|
self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, output shape={outputs.shape}, 耗时={inference_time_ms:.1f}ms")
|
||||||
|
elif isinstance(outputs, (list, tuple)):
|
||||||
|
shapes = [o.shape if hasattr(o, 'shape') else type(o) for o in outputs]
|
||||||
|
self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
|
||||||
|
|
||||||
|
batch_size = len(chunk)
|
||||||
|
batch_results = self._postprocessor.batch_process_detections(
|
||||||
|
outputs,
|
||||||
|
batch_size,
|
||||||
|
conf_threshold=self._settings.inference.conf_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
total_detections = sum(len(r[0]) for r in batch_results)
|
||||||
|
self._logger.debug(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
|
||||||
|
|
||||||
|
for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(chunk):
|
||||||
|
boxes, scores, class_ids = batch_results[idx]
|
||||||
|
|
||||||
|
# 无论是否检测到目标都要调用算法(离岗检测需要"无人"信号)
|
||||||
|
self._handle_detections(
|
||||||
|
camera_id, roi, bind, frame,
|
||||||
|
boxes, scores, class_ids,
|
||||||
|
scale_info
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"批量处理 ROI 失败: {e}")
|
||||||
|
|
||||||
def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
|
def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
|
||||||
"""收集 ROI 帧数据 - 已集成到 _process_frame 中"""
|
"""收集 ROI 帧数据 - 已集成到 _process_frame 中"""
|
||||||
|
|||||||
Reference in New Issue
Block a user