feat: 批处理多ROI推理 - 添加 batch_process_detections 批量后处理 - 重构 _process_frame 收集多ROI统一推理
This commit is contained in:
@@ -701,6 +701,98 @@ class PostProcessor:
|
|||||||
self._buffer_class_ids[:valid_count].copy()
|
self._buffer_class_ids[:valid_count].copy()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def batch_process_detections(
    self,
    batch_outputs: List[np.ndarray],
    batch_size: int,
    conf_threshold: Optional[float] = None,
    nms_threshold: Optional[float] = None
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Post-process the raw model output of a multi-ROI batch.

    Args:
        batch_outputs: Model outputs. Either a one-element list holding a
            single batched array (leading axis = sample), or a list with
            one array per sample.
        batch_size: Number of ROIs in the batch; the returned list always
            has exactly this many entries.
        conf_threshold: Confidence threshold. Falls back to
            ``self.conf_threshold`` when ``None``.
        nms_threshold: NMS threshold. Falls back to ``self.nms_threshold``
            when ``None``.

    Returns:
        One ``(boxes, scores, class_ids)`` tuple per ROI, in batch order.
        ROIs with no usable output or no candidate above the coarse gate
        get three empty arrays.
    """
    import logging  # local import: keeps this block self-contained

    expected_rows = 84      # per-sample output must be (84, num_anchors)
    coarse_gate_cap = 0.3   # upper bound for the cheap pre-NMS score gate

    if conf_threshold is None:
        conf_threshold = self.conf_threshold
    if nms_threshold is None:
        nms_threshold = self.nms_threshold

    # Split the engine output into one array per sample.
    if len(batch_outputs) == 1:
        packed = np.asarray(batch_outputs[0])
        # A batched array is split along its leading axis; a bare 2-D
        # array is treated as the output of a single sample.
        samples = list(packed) if packed.ndim >= 3 else [packed]
    else:
        samples = [np.asarray(item) for item in batch_outputs]

    if len(samples) < batch_size:
        # Previously sample 0 was silently reused for every missing slot,
        # which attributed its detections to unrelated ROIs.
        logging.getLogger(__name__).warning(
            "batch_process_detections: %d outputs for batch_size=%d; "
            "missing ROIs get empty results",
            len(samples), batch_size
        )

    # Never let the coarse gate discard scores that the caller's own
    # (lower) confidence threshold would have accepted.
    coarse_threshold = min(coarse_gate_cap, conf_threshold)

    results = []
    nms_processor = None  # built lazily: not needed when nothing survives

    for batch_idx in range(batch_size):
        if batch_idx >= len(samples):
            results.append((np.array([]), np.array([]), np.array([])))
            continue

        output = samples[batch_idx]
        # Per-sample outputs may still carry a leading singleton axis.
        if output.ndim == 3 and output.shape[0] > 0:
            output = output[0]

        if output.ndim != 2 or output.shape[0] != expected_rows:
            results.append((np.array([]), np.array([]), np.array([])))
            continue

        boxes_xywh = output[0:4, :].T
        # NOTE(review): rows 4 and 5 are multiplied as objectness * person
        # score, as in the original code. Some 84-row YOLO exports have no
        # objectness row (4 box + 80 class rows) -- confirm against the
        # deployed model.
        obj_conf = output[4, :]
        person_scores = output[5, :]
        scores = obj_conf * person_scores

        coarse_mask = scores > coarse_threshold
        if not np.any(coarse_mask):
            results.append((np.array([]), np.array([]), np.array([])))
            continue

        boxes = boxes_xywh[coarse_mask]
        scores_coarse = scores[coarse_mask]
        class_ids = np.zeros(len(boxes), dtype=np.int32)

        # Convert centre/size boxes to corner form in a fresh array. The
        # shared self._buffer_xyxy scratch buffer is not used here: it is
        # overwritten on every iteration (results of earlier ROIs could
        # alias it) and overflows when a ROI has more candidates than the
        # buffer has rows.
        half_w = boxes[:, 2] / 2
        half_h = boxes[:, 3] / 2
        xyxy = np.empty((len(boxes), 4), dtype=np.float32)
        xyxy[:, 0] = boxes[:, 0] - half_w
        xyxy[:, 1] = boxes[:, 1] - half_h
        xyxy[:, 2] = boxes[:, 0] + half_w
        xyxy[:, 3] = boxes[:, 1] + half_h

        if nms_processor is None:
            nms_processor = NMSProcessor(nms_threshold, use_gpu=True)

        keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
            xyxy,
            scores_coarse.astype(np.float32),
            class_ids,
            conf_threshold
        )

        results.append((keep_boxes, keep_scores, keep_classes))

    return results
|
||||||
|
|
||||||
def filter_by_roi(
|
def filter_by_roi(
|
||||||
self,
|
self,
|
||||||
boxes: np.ndarray,
|
boxes: np.ndarray,
|
||||||
|
|||||||
90
main.py
90
main.py
@@ -10,7 +10,7 @@ import threading
|
|||||||
import signal
|
import signal
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional, List, Tuple
|
||||||
|
|
||||||
from config.settings import get_settings, Settings
|
from config.settings import get_settings, Settings
|
||||||
from core.config_sync import get_config_sync_manager, ConfigSyncManager
|
from core.config_sync import get_config_sync_manager, ConfigSyncManager
|
||||||
@@ -57,6 +57,9 @@ class EdgeInferenceService:
|
|||||||
"uptime_seconds": 0,
|
"uptime_seconds": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self._batch_roi_queue: List[tuple] = []
|
||||||
|
self._batch_lock = threading.Lock()
|
||||||
|
|
||||||
self._logger.info("Edge_Inference_Service 初始化开始")
|
self._logger.info("Edge_Inference_Service 初始化开始")
|
||||||
|
|
||||||
def _init_database(self):
|
def _init_database(self):
|
||||||
@@ -177,46 +180,53 @@ class EdgeInferenceService:
|
|||||||
return callback
|
return callback
|
||||||
|
|
||||||
def _process_frame(self, camera_id: str, frame: VideoFrame):
    """Process one video frame by batching all of its enabled ROIs.

    Each enabled ROI of the camera is preprocessed individually, the
    resulting items are appended to the shared batch queue, and the queue
    is then handed to ``_batch_process_rois`` for a single inference pass.

    Args:
        camera_id: Identifier of the camera that produced the frame.
        frame: The captured frame; ``frame.image`` is passed to the
            preprocessor.

    Errors are logged and never propagated. A preprocessing failure of one
    ROI does not prevent the remaining ROIs from being processed.
    """
    try:
        roi_configs = self._config_manager.get_roi_configs(camera_id)

        roi_items = []
        for roi in roi_configs:
            if not roi.enabled:
                continue
            try:
                cropped, scale_info = self._preprocessor.preprocess_single(
                    frame.image, roi
                )
            except Exception as e:
                self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}")
            else:
                roi_items.append((camera_id, roi, frame, cropped, scale_info))

        # Frames without any usable ROI are not counted as processed.
        if not roi_items:
            return

        with self._batch_lock:
            self._batch_roi_queue.extend(roi_items)

        # _batch_process_rois drains the queue under the lock and returns
        # early when it is empty, so the queue length is not read here
        # outside the lock.
        self._batch_process_rois()

        self._performance_stats["total_frames_processed"] += 1

    except Exception as e:
        self._logger.error(f"处理帧失败 {camera_id}: {e}")
|
||||||
|
|
||||||
def _process_roi_frame(
|
def _batch_process_rois(self):
|
||||||
self,
|
"""批量处理 ROI - 预处理、推理、后处理"""
|
||||||
camera_id: str,
|
with self._batch_lock:
|
||||||
frame: VideoFrame,
|
roi_items = self._batch_roi_queue
|
||||||
roi
|
if not roi_items:
|
||||||
):
|
|
||||||
"""处理ROI帧,batch=1 推理"""
|
|
||||||
try:
|
|
||||||
if not roi.enabled:
|
|
||||||
return
|
return
|
||||||
|
self._batch_roi_queue = []
|
||||||
cropped = self._preprocessor.preprocess_single(frame.image, roi)
|
|
||||||
|
batch_size = len(roi_items)
|
||||||
processed_image, scale_info = cropped
|
|
||||||
|
try:
|
||||||
|
images = [item[3] for item in roi_items]
|
||||||
|
scale_infos = [item[4] for item in roi_items]
|
||||||
|
|
||||||
batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
|
batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
|
||||||
[processed_image]
|
images
|
||||||
)
|
)
|
||||||
|
|
||||||
engine = self._engine_manager.get_engine("default")
|
engine = self._engine_manager.get_engine("default")
|
||||||
@@ -225,20 +235,28 @@ class EdgeInferenceService:
|
|||||||
|
|
||||||
outputs, inference_time_ms = engine.infer(batch_data)
|
outputs, inference_time_ms = engine.infer(batch_data)
|
||||||
|
|
||||||
boxes, scores, class_ids = self._postprocessor.process_detections(
|
batch_results = self._postprocessor.batch_process_detections(
|
||||||
outputs,
|
outputs,
|
||||||
|
batch_size,
|
||||||
conf_threshold=self._settings.inference.conf_threshold
|
conf_threshold=self._settings.inference.conf_threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(boxes) > 0:
|
for idx, (camera_id, roi, frame, _, scale_info) in enumerate(roi_items):
|
||||||
self._handle_detections(
|
boxes, scores, class_ids = batch_results[idx]
|
||||||
camera_id, roi, frame,
|
|
||||||
boxes, scores, class_ids,
|
if len(boxes) > 0:
|
||||||
scale_info
|
self._handle_detections(
|
||||||
)
|
camera_id, roi, frame,
|
||||||
|
boxes, scores, class_ids,
|
||||||
|
scale_info
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._logger.error(f"处理ROI帧失败: {e}")
|
self._logger.error(f"批量处理 ROI 失败: {e}")
|
||||||
|
|
||||||
|
def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
    """Collect ROI frame data -- now integrated into ``_process_frame``.

    Intentionally a no-op; all arguments are ignored.
    """
    return None
|
||||||
|
|
||||||
def _build_tracks(
|
def _build_tracks(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user