feat: 批处理多ROI推理 - 添加 batch_process_detections 批量后处理 - 重构 _process_frame 收集多ROI统一推理
This commit is contained in:
90
main.py
90
main.py
@@ -10,7 +10,7 @@ import threading
|
||||
import signal
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
from config.settings import get_settings, Settings
|
||||
from core.config_sync import get_config_sync_manager, ConfigSyncManager
|
||||
@@ -57,6 +57,9 @@ class EdgeInferenceService:
|
||||
"uptime_seconds": 0,
|
||||
}
|
||||
|
||||
self._batch_roi_queue: List[tuple] = []
|
||||
self._batch_lock = threading.Lock()
|
||||
|
||||
self._logger.info("Edge_Inference_Service 初始化开始")
|
||||
|
||||
def _init_database(self):
|
||||
@@ -177,46 +180,53 @@ class EdgeInferenceService:
|
||||
return callback
|
||||
|
||||
def _process_frame(self, camera_id: str, frame: VideoFrame):
|
||||
"""处理视频帧"""
|
||||
"""处理视频帧 - 批量处理多 ROI"""
|
||||
try:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
roi_configs = self._config_manager.get_roi_configs(camera_id)
|
||||
|
||||
roi_items = []
|
||||
for roi in roi_configs:
|
||||
self._process_roi_frame(camera_id, frame, roi)
|
||||
if not roi.enabled:
|
||||
continue
|
||||
try:
|
||||
cropped, scale_info = self._preprocessor.preprocess_single(
|
||||
frame.image, roi
|
||||
)
|
||||
roi_items.append((camera_id, roi, frame, cropped, scale_info))
|
||||
except Exception as e:
|
||||
self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}")
|
||||
|
||||
processing_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
if not roi_items:
|
||||
return
|
||||
|
||||
with self._batch_lock:
|
||||
self._batch_roi_queue.extend(roi_items)
|
||||
|
||||
batch_size = len(self._batch_roi_queue)
|
||||
if batch_size >= 1:
|
||||
self._batch_process_rois()
|
||||
|
||||
self._performance_stats["total_frames_processed"] += 1
|
||||
|
||||
self._frame_counter = getattr(self, '_frame_counter', 0) + 1
|
||||
if self._frame_counter % 10 == 0:
|
||||
self._logger.log_inference_latency(
|
||||
processing_time_ms,
|
||||
batch_size=1
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"处理帧失败 {camera_id}: {e}")
|
||||
|
||||
def _process_roi_frame(
|
||||
self,
|
||||
camera_id: str,
|
||||
frame: VideoFrame,
|
||||
roi
|
||||
):
|
||||
"""处理ROI帧,batch=1 推理"""
|
||||
try:
|
||||
if not roi.enabled:
|
||||
def _batch_process_rois(self):
|
||||
"""批量处理 ROI - 预处理、推理、后处理"""
|
||||
with self._batch_lock:
|
||||
roi_items = self._batch_roi_queue
|
||||
if not roi_items:
|
||||
return
|
||||
|
||||
cropped = self._preprocessor.preprocess_single(frame.image, roi)
|
||||
|
||||
processed_image, scale_info = cropped
|
||||
self._batch_roi_queue = []
|
||||
|
||||
batch_size = len(roi_items)
|
||||
|
||||
try:
|
||||
images = [item[3] for item in roi_items]
|
||||
scale_infos = [item[4] for item in roi_items]
|
||||
|
||||
batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
|
||||
[processed_image]
|
||||
images
|
||||
)
|
||||
|
||||
engine = self._engine_manager.get_engine("default")
|
||||
@@ -225,20 +235,28 @@ class EdgeInferenceService:
|
||||
|
||||
outputs, inference_time_ms = engine.infer(batch_data)
|
||||
|
||||
boxes, scores, class_ids = self._postprocessor.process_detections(
|
||||
batch_results = self._postprocessor.batch_process_detections(
|
||||
outputs,
|
||||
batch_size,
|
||||
conf_threshold=self._settings.inference.conf_threshold
|
||||
)
|
||||
|
||||
if len(boxes) > 0:
|
||||
self._handle_detections(
|
||||
camera_id, roi, frame,
|
||||
boxes, scores, class_ids,
|
||||
scale_info
|
||||
)
|
||||
|
||||
for idx, (camera_id, roi, frame, _, scale_info) in enumerate(roi_items):
|
||||
boxes, scores, class_ids = batch_results[idx]
|
||||
|
||||
if len(boxes) > 0:
|
||||
self._handle_detections(
|
||||
camera_id, roi, frame,
|
||||
boxes, scores, class_ids,
|
||||
scale_info
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"处理ROI帧失败: {e}")
|
||||
self._logger.error(f"批量处理 ROI 失败: {e}")
|
||||
|
||||
def _process_roi_frame(self, camera_id: str, frame: VideoFrame, roi):
|
||||
"""收集 ROI 帧数据 - 已集成到 _process_frame 中"""
|
||||
pass
|
||||
|
||||
def _build_tracks(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user