feat: 批处理多ROI推理 - 添加 batch_process_detections 批量后处理 - 重构 _process_frame 收集多ROI统一推理

This commit is contained in:
2026-02-03 11:17:54 +08:00
parent d7f56683c7
commit 1caba41625
2 changed files with 146 additions and 36 deletions

View File

@@ -701,6 +701,98 @@ class PostProcessor:
self._buffer_class_ids[:valid_count].copy()
)
def batch_process_detections(
self,
batch_outputs: List[np.ndarray],
batch_size: int,
conf_threshold: Optional[float] = None,
nms_threshold: Optional[float] = None
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
"""
批量处理多 ROI 检测结果
Args:
batch_outputs: 模型输出 (可能是 [output] 或单个 batch 数组)
batch_size: batch 大小
conf_threshold: 置信度阈值
nms_threshold: NMS阈值
Returns:
List of (检测框, 置信度, 类别ID) per ROI
"""
if conf_threshold is None:
conf_threshold = self.conf_threshold
if nms_threshold is None:
nms_threshold = self.nms_threshold
if len(batch_outputs) == 1:
first_output = batch_outputs[0]
if isinstance(first_output, np.ndarray) and first_output.ndim == 3:
if first_output.shape[0] == batch_size:
outputs_array = first_output
else:
outputs_array = first_output
else:
outputs_array = np.array([first_output])
else:
outputs_array = np.array(batch_outputs)
results = []
nms_processor = NMSProcessor(nms_threshold, use_gpu=True)
for batch_idx in range(batch_size):
if outputs_array.ndim == 3:
if outputs_array.shape[0] == batch_size:
output = outputs_array[batch_idx]
else:
output = outputs_array[0]
elif outputs_array.ndim == 2:
output = outputs_array
else:
if batch_idx < len(outputs_array):
output = outputs_array[batch_idx]
else:
output = outputs_array[0]
if output.ndim == 3:
output = output[0]
if output.shape[0] != 84:
results.append((np.array([]), np.array([]), np.array([])))
continue
boxes_xywh = output[0:4, :].T
obj_conf = output[4, :]
person_scores = output[5, :]
scores = obj_conf * person_scores
coarse_mask = scores > 0.3
if not np.any(coarse_mask):
results.append((np.array([]), np.array([]), np.array([])))
continue
boxes = boxes_xywh[coarse_mask]
scores_coarse = scores[coarse_mask]
class_ids = np.zeros(len(boxes), dtype=np.int32)
valid_count = len(boxes)
self._buffer_xyxy[:valid_count, 0] = boxes[:, 0] - boxes[:, 2] / 2
self._buffer_xyxy[:valid_count, 1] = boxes[:, 1] - boxes[:, 3] / 2
self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + boxes[:, 2] / 2
self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + boxes[:, 3] / 2
keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
self._buffer_xyxy[:valid_count],
scores_coarse.astype(np.float32),
class_ids,
conf_threshold
)
results.append((keep_boxes, keep_scores, keep_classes))
return results
def filter_by_roi(
self,
boxes: np.ndarray,