feat: 批处理多ROI推理 - 添加 batch_process_detections 批量后处理 - 重构 _process_frame 收集多ROI统一推理

This commit is contained in:
2026-02-03 11:17:54 +08:00
parent d7f56683c7
commit 1caba41625
2 changed files with 146 additions and 36 deletions

View File

@@ -701,6 +701,98 @@ class PostProcessor:
self._buffer_class_ids[:valid_count].copy() self._buffer_class_ids[:valid_count].copy()
) )
def batch_process_detections(
    self,
    batch_outputs: List[np.ndarray],
    batch_size: int,
    conf_threshold: Optional[float] = None,
    nms_threshold: Optional[float] = None
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Post-process detection outputs for a batch of ROIs in one pass.

    Args:
        batch_outputs: Raw model outputs. Either a one-element list wrapping a
            batched 3-D array (leading dim expected to be ``batch_size``), or a
            list of per-ROI arrays.
        batch_size: Number of ROIs in the batch.
        conf_threshold: Confidence cutoff; defaults to ``self.conf_threshold``.
        nms_threshold: NMS IoU cutoff; defaults to ``self.nms_threshold``.

    Returns:
        One ``(boxes, scores, class_ids)`` tuple per ROI, in batch order.
        ROIs with no surviving detections get empty arrays.
    """
    if conf_threshold is None:
        conf_threshold = self.conf_threshold
    if nms_threshold is None:
        nms_threshold = self.nms_threshold

    # Normalize the possible output layouts into a single ndarray; whether
    # the leading dim actually equals batch_size is resolved per index below.
    # (Original code had an if/else here whose branches were identical.)
    if len(batch_outputs) == 1:
        first_output = batch_outputs[0]
        if isinstance(first_output, np.ndarray) and first_output.ndim == 3:
            outputs_array = first_output
        else:
            outputs_array = np.array([first_output])
    else:
        outputs_array = np.array(batch_outputs)

    # BUG FIX: the coarse pre-filter was a fixed `scores > 0.3`, which
    # silently discarded detections whenever conf_threshold < 0.3. Never
    # pre-filter more aggressively than the requested threshold.
    coarse_threshold = min(0.3, conf_threshold)

    results: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = []
    # NOTE(review): a fresh NMSProcessor per call may be wasteful if its
    # construction is heavy — consider caching one on self.
    nms_processor = NMSProcessor(nms_threshold, use_gpu=True)

    for batch_idx in range(batch_size):
        # Pick this ROI's output slice, tolerating mismatched layouts.
        if outputs_array.ndim == 3:
            if outputs_array.shape[0] == batch_size:
                output = outputs_array[batch_idx]
            else:
                output = outputs_array[0]
        elif outputs_array.ndim == 2:
            output = outputs_array
        else:
            output = outputs_array[min(batch_idx, len(outputs_array) - 1)]
        if output.ndim == 3:
            output = output[0]

        # Expect a (84, anchors) head; anything else yields no detections.
        if output.shape[0] != 84:
            results.append((np.array([]), np.array([]), np.array([])))
            continue

        boxes_xywh = output[0:4, :].T  # (anchors, 4) center-format boxes
        obj_conf = output[4, :]
        person_scores = output[5, :]
        # assumes row 4 is objectness and row 5 the person-class score —
        # TODO confirm against the exported model's head layout (an 84-row
        # YOLOv8-style head has 80 class rows and no objectness row)
        scores = obj_conf * person_scores

        coarse_mask = scores > coarse_threshold
        if not np.any(coarse_mask):
            results.append((np.array([]), np.array([]), np.array([])))
            continue

        boxes = boxes_xywh[coarse_mask]
        scores_coarse = scores[coarse_mask]
        class_ids = np.zeros(len(boxes), dtype=np.int32)

        # Convert cx,cy,w,h -> x1,y1,x2,y2 into the preallocated buffer.
        # NOTE(review): assumes valid_count never exceeds the buffer
        # capacity — verify _buffer_xyxy sizing against max anchor count.
        valid_count = len(boxes)
        self._buffer_xyxy[:valid_count, 0] = boxes[:, 0] - boxes[:, 2] / 2
        self._buffer_xyxy[:valid_count, 1] = boxes[:, 1] - boxes[:, 3] / 2
        self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + boxes[:, 2] / 2
        self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + boxes[:, 3] / 2

        keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
            self._buffer_xyxy[:valid_count],
            scores_coarse.astype(np.float32),
            class_ids,
            conf_threshold
        )
        results.append((keep_boxes, keep_scores, keep_classes))

    return results
def filter_by_roi( def filter_by_roi(
self, self,
boxes: np.ndarray, boxes: np.ndarray,

90
main.py
View File

@@ -10,7 +10,7 @@ import threading
import signal import signal
import time import time
from datetime import datetime from datetime import datetime
from typing import Dict, Any, Optional from typing import Dict, Any, Optional, List, Tuple
from config.settings import get_settings, Settings from config.settings import get_settings, Settings
from core.config_sync import get_config_sync_manager, ConfigSyncManager from core.config_sync import get_config_sync_manager, ConfigSyncManager
@@ -57,6 +57,9 @@ class EdgeInferenceService:
"uptime_seconds": 0, "uptime_seconds": 0,
} }
self._batch_roi_queue: List[tuple] = []
self._batch_lock = threading.Lock()
self._logger.info("Edge_Inference_Service 初始化开始") self._logger.info("Edge_Inference_Service 初始化开始")
def _init_database(self): def _init_database(self):
@@ -177,46 +180,53 @@ class EdgeInferenceService:
return callback return callback
def _process_frame(self, camera_id: str, frame: VideoFrame): def _process_frame(self, camera_id: str, frame: VideoFrame):
"""处理视频帧""" """处理视频帧 - 批量处理多 ROI"""
try: try:
start_time = time.perf_counter()
roi_configs = self._config_manager.get_roi_configs(camera_id) roi_configs = self._config_manager.get_roi_configs(camera_id)
roi_items = []
for roi in roi_configs: for roi in roi_configs:
self._process_roi_frame(camera_id, frame, roi) if not roi.enabled:
continue
try:
cropped, scale_info = self._preprocessor.preprocess_single(
frame.image, roi
)
roi_items.append((camera_id, roi, frame, cropped, scale_info))
except Exception as e:
self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}")
processing_time_ms = (time.perf_counter() - start_time) * 1000 if not roi_items:
return
with self._batch_lock:
self._batch_roi_queue.extend(roi_items)
batch_size = len(self._batch_roi_queue)
if batch_size >= 1:
self._batch_process_rois()
self._performance_stats["total_frames_processed"] += 1 self._performance_stats["total_frames_processed"] += 1
self._frame_counter = getattr(self, '_frame_counter', 0) + 1
if self._frame_counter % 10 == 0:
self._logger.log_inference_latency(
processing_time_ms,
batch_size=1
)
except Exception as e: except Exception as e:
self._logger.error(f"处理帧失败 {camera_id}: {e}") self._logger.error(f"处理帧失败 {camera_id}: {e}")
def _process_roi_frame( def _batch_process_rois(self):
self, """批量处理 ROI - 预处理、推理、后处理"""
camera_id: str, with self._batch_lock:
frame: VideoFrame, roi_items = self._batch_roi_queue
roi if not roi_items:
):
"""处理ROI帧batch=1 推理"""
try:
if not roi.enabled:
return return
self._batch_roi_queue = []
cropped = self._preprocessor.preprocess_single(frame.image, roi)
batch_size = len(roi_items)
processed_image, scale_info = cropped
try:
images = [item[3] for item in roi_items]
scale_infos = [item[4] for item in roi_items]
batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch( batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
[processed_image] images
) )
engine = self._engine_manager.get_engine("default") engine = self._engine_manager.get_engine("default")
@@ -225,20 +235,28 @@ class EdgeInferenceService:
outputs, inference_time_ms = engine.infer(batch_data) outputs, inference_time_ms = engine.infer(batch_data)
boxes, scores, class_ids = self._postprocessor.process_detections( batch_results = self._postprocessor.batch_process_detections(
outputs, outputs,
batch_size,
conf_threshold=self._settings.inference.conf_threshold conf_threshold=self._settings.inference.conf_threshold
) )
if len(boxes) > 0: for idx, (camera_id, roi, frame, _, scale_info) in enumerate(roi_items):
self._handle_detections( boxes, scores, class_ids = batch_results[idx]
camera_id, roi, frame,
boxes, scores, class_ids, if len(boxes) > 0:
scale_info self._handle_detections(
) camera_id, roi, frame,
boxes, scores, class_ids,
scale_info
)
except Exception as e: except Exception as e:
self._logger.error(f"处理ROI失败: {e}") self._logger.error(f"批量处理 ROI 失败: {e}")
def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
"""收集 ROI 帧数据 - 已集成到 _process_frame 中"""
pass
def _build_tracks( def _build_tracks(
self, self,