feat: 批处理多ROI推理 - 添加 batch_process_detections 批量后处理 - 重构 _process_frame 收集多ROI统一推理

This commit is contained in:
2026-02-03 11:17:54 +08:00
parent d7f56683c7
commit 1caba41625
2 changed files with 146 additions and 36 deletions

View File

@@ -701,6 +701,98 @@ class PostProcessor:
self._buffer_class_ids[:valid_count].copy()
)
def batch_process_detections(
    self,
    batch_outputs: List[np.ndarray],
    batch_size: int,
    conf_threshold: Optional[float] = None,
    nms_threshold: Optional[float] = None
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """Post-process detection outputs for a batch of ROIs in one pass.

    Args:
        batch_outputs: Raw engine outputs; either a list with a single
            batched (batch, C, N) array, a single (C, N) array, or a list
            with one entry per ROI.
        batch_size: Number of ROIs in the batch.
        conf_threshold: Confidence threshold; falls back to the instance
            default when None.
        nms_threshold: NMS IoU threshold; falls back to the instance
            default when None.

    Returns:
        One (boxes, scores, class_ids) tuple per ROI; empty arrays for an
        ROI with no detections or an unexpected output layout.
    """
    if conf_threshold is None:
        conf_threshold = self.conf_threshold
    if nms_threshold is None:
        nms_threshold = self.nms_threshold

    # Normalize the engine output into a single ndarray. Fix: the original
    # branched on `first_output.shape[0] == batch_size` but both branches
    # assigned the same value — the dead conditional is removed.
    if len(batch_outputs) == 1:
        first_output = batch_outputs[0]
        if isinstance(first_output, np.ndarray) and first_output.ndim == 3:
            outputs_array = first_output
        else:
            outputs_array = np.array([first_output])
    else:
        outputs_array = np.array(batch_outputs)

    results: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = []
    # One NMS processor is shared across every ROI in the batch.
    nms_processor = NMSProcessor(nms_threshold, use_gpu=True)

    for batch_idx in range(batch_size):
        # Pick this ROI's (C, N) slice; if the leading dim does not match
        # batch_size, fall back to reusing the first slice.
        if outputs_array.ndim == 3:
            if outputs_array.shape[0] == batch_size:
                output = outputs_array[batch_idx]
            else:
                output = outputs_array[0]
        elif outputs_array.ndim == 2:
            output = outputs_array
        else:
            if batch_idx < len(outputs_array):
                output = outputs_array[batch_idx]
            else:
                output = outputs_array[0]
        if output.ndim == 3:
            output = output[0]

        # Only the 84-row layout is handled; anything else yields an empty
        # result for this ROI instead of raising.
        if output.shape[0] != 84:
            results.append((np.array([]), np.array([]), np.array([])))
            continue

        boxes_xywh = output[0:4, :].T  # (N, 4) center-format boxes
        obj_conf = output[4, :]
        # NOTE(review): row 5 is treated as the person-class score and row 4
        # as objectness; a standard YOLOv8 head has no objectness row —
        # confirm the exported model actually emits (obj, person) at 4/5.
        person_scores = output[5, :]
        scores = obj_conf * person_scores

        # Cheap coarse pre-filter before NMS; 0.3 is intentionally looser
        # than conf_threshold, which is applied inside the NMS call below.
        coarse_mask = scores > 0.3
        if not np.any(coarse_mask):
            results.append((np.array([]), np.array([]), np.array([])))
            continue

        boxes = boxes_xywh[coarse_mask]
        scores_coarse = scores[coarse_mask]
        class_ids = np.zeros(len(boxes), dtype=np.int32)  # all class id 0
        valid_count = len(boxes)

        # Convert cx,cy,w,h -> x1,y1,x2,y2 into the preallocated buffer.
        # Assumes valid_count fits self._buffer_xyxy — TODO confirm capacity.
        half_w = boxes[:, 2] / 2
        half_h = boxes[:, 3] / 2
        self._buffer_xyxy[:valid_count, 0] = boxes[:, 0] - half_w
        self._buffer_xyxy[:valid_count, 1] = boxes[:, 1] - half_h
        self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + half_w
        self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + half_h

        keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
            self._buffer_xyxy[:valid_count],
            scores_coarse.astype(np.float32),
            class_ids,
            conf_threshold
        )
        results.append((keep_boxes, keep_scores, keep_classes))
    return results
def filter_by_roi(
self,
boxes: np.ndarray,

90
main.py
View File

@@ -10,7 +10,7 @@ import threading
import signal
import time
from datetime import datetime
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, List, Tuple
from config.settings import get_settings, Settings
from core.config_sync import get_config_sync_manager, ConfigSyncManager
@@ -57,6 +57,9 @@ class EdgeInferenceService:
"uptime_seconds": 0,
}
self._batch_roi_queue: List[tuple] = []
self._batch_lock = threading.Lock()
self._logger.info("Edge_Inference_Service 初始化开始")
def _init_database(self):
@@ -177,46 +180,53 @@ class EdgeInferenceService:
return callback
def _process_frame(self, camera_id: str, frame: VideoFrame):
# NOTE(review): this span is a unified-diff rendering with the +/- gutter
# stripped, so removed and added lines are interleaved and the block is not
# runnable as shown. Annotations below mark what appears old vs. new.
"""处理视频帧"""
# ^ old docstring (removed); next line is the new docstring:
# "Process video frame - batch multi-ROI processing"
"""处理视频帧 - 批量处理多 ROI"""
try:
start_time = time.perf_counter()
roi_configs = self._config_manager.get_roi_configs(camera_id)
roi_items = []
for roi in roi_configs:
# old per-ROI path (removed): delegated each ROI to _process_roi_frame
self._process_roi_frame(camera_id, frame, roi)
# new path: skip disabled ROIs, preprocess, queue for batched inference
if not roi.enabled:
continue
try:
cropped, scale_info = self._preprocessor.preprocess_single(
frame.image, roi
)
roi_items.append((camera_id, roi, frame, cropped, scale_info))
except Exception as e:
self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}")
processing_time_ms = (time.perf_counter() - start_time) * 1000
if not roi_items:
return
# Enqueue under the lock, then flush the queue.
with self._batch_lock:
self._batch_roi_queue.extend(roi_items)
batch_size = len(self._batch_roi_queue)
# NOTE(review): ">= 1" flushes on every frame, so no cross-frame batching
# ever accumulates — confirm whether a larger threshold was intended.
if batch_size >= 1:
self._batch_process_rois()
self._performance_stats["total_frames_processed"] += 1
self._frame_counter = getattr(self, '_frame_counter', 0) + 1
# Log latency only every 10th frame to limit log volume.
if self._frame_counter % 10 == 0:
self._logger.log_inference_latency(
processing_time_ms,
batch_size=1
)
except Exception as e:
self._logger.error(f"处理帧失败 {camera_id}: {e}")
def _process_roi_frame(
self,
camera_id: str,
frame: VideoFrame,
roi
):
# NOTE(review): diff rendering with the +/- gutter stripped — the old
# _process_roi_frame above was removed and the new _batch_process_rois
# below was added; fragments of both versions interleave in this span.
"""处理ROI帧batch=1 推理"""
try:
if not roi.enabled:
def _batch_process_rois(self):
"""批量处理 ROI - 预处理、推理、后处理"""
# Drain the shared ROI queue under the lock before the heavy work.
with self._batch_lock:
roi_items = self._batch_roi_queue
if not roi_items:
return
# Old single-ROI preprocessing lines (removed in the new version):
cropped = self._preprocessor.preprocess_single(frame.image, roi)
processed_image, scale_info = cropped
self._batch_roi_queue = []
batch_size = len(roi_items)
try:
# roi_items entries are (camera_id, roi, frame, cropped, scale_info).
images = [item[3] for item in roi_items]
# NOTE(review): scale_infos is never used below — scale_info is taken
# from roi_items directly in the loop; consider dropping it.
scale_infos = [item[4] for item in roi_items]
# NOTE(review): reaches into a private attribute of the preprocessor.
batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
[processed_image]
images
)
engine = self._engine_manager.get_engine("default")
@@ -225,20 +235,28 @@ class EdgeInferenceService:
outputs, inference_time_ms = engine.infer(batch_data)
# Old single-ROI post-processing call (removed):
boxes, scores, class_ids = self._postprocessor.process_detections(
# New batched post-processing (added):
batch_results = self._postprocessor.batch_process_detections(
outputs,
batch_size,
conf_threshold=self._settings.inference.conf_threshold
)
# Old single-ROI dispatch (removed):
if len(boxes) > 0:
self._handle_detections(
camera_id, roi, frame,
boxes, scores, class_ids,
scale_info
)
# New per-ROI dispatch over the batched results (added):
for idx, (camera_id, roi, frame, _, scale_info) in enumerate(roi_items):
boxes, scores, class_ids = batch_results[idx]
if len(boxes) > 0:
self._handle_detections(
camera_id, roi, frame,
boxes, scores, class_ids,
scale_info
)
except Exception as e:
# Old vs. new error message (removed/added pair):
self._logger.error(f"处理ROI失败: {e}")
self._logger.error(f"批量处理 ROI 失败: {e}")
def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
    """Legacy per-ROI hook, kept as a no-op; collection moved into _process_frame."""
    return None
def _build_tracks(
self,