feat: 批处理多ROI推理 - 添加 batch_process_detections 批量后处理 - 重构 _process_frame 收集多ROI统一推理
This commit is contained in:
@@ -701,6 +701,98 @@ class PostProcessor:
|
||||
self._buffer_class_ids[:valid_count].copy()
|
||||
)
|
||||
|
||||
def batch_process_detections(
|
||||
self,
|
||||
batch_outputs: List[np.ndarray],
|
||||
batch_size: int,
|
||||
conf_threshold: Optional[float] = None,
|
||||
nms_threshold: Optional[float] = None
|
||||
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
|
||||
"""
|
||||
批量处理多 ROI 检测结果
|
||||
|
||||
Args:
|
||||
batch_outputs: 模型输出 (可能是 [output] 或单个 batch 数组)
|
||||
batch_size: batch 大小
|
||||
conf_threshold: 置信度阈值
|
||||
nms_threshold: NMS阈值
|
||||
|
||||
Returns:
|
||||
List of (检测框, 置信度, 类别ID) per ROI
|
||||
"""
|
||||
if conf_threshold is None:
|
||||
conf_threshold = self.conf_threshold
|
||||
if nms_threshold is None:
|
||||
nms_threshold = self.nms_threshold
|
||||
|
||||
if len(batch_outputs) == 1:
|
||||
first_output = batch_outputs[0]
|
||||
if isinstance(first_output, np.ndarray) and first_output.ndim == 3:
|
||||
if first_output.shape[0] == batch_size:
|
||||
outputs_array = first_output
|
||||
else:
|
||||
outputs_array = first_output
|
||||
else:
|
||||
outputs_array = np.array([first_output])
|
||||
else:
|
||||
outputs_array = np.array(batch_outputs)
|
||||
|
||||
results = []
|
||||
nms_processor = NMSProcessor(nms_threshold, use_gpu=True)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
if outputs_array.ndim == 3:
|
||||
if outputs_array.shape[0] == batch_size:
|
||||
output = outputs_array[batch_idx]
|
||||
else:
|
||||
output = outputs_array[0]
|
||||
elif outputs_array.ndim == 2:
|
||||
output = outputs_array
|
||||
else:
|
||||
if batch_idx < len(outputs_array):
|
||||
output = outputs_array[batch_idx]
|
||||
else:
|
||||
output = outputs_array[0]
|
||||
|
||||
if output.ndim == 3:
|
||||
output = output[0]
|
||||
|
||||
if output.shape[0] != 84:
|
||||
results.append((np.array([]), np.array([]), np.array([])))
|
||||
continue
|
||||
|
||||
boxes_xywh = output[0:4, :].T
|
||||
obj_conf = output[4, :]
|
||||
person_scores = output[5, :]
|
||||
scores = obj_conf * person_scores
|
||||
|
||||
coarse_mask = scores > 0.3
|
||||
if not np.any(coarse_mask):
|
||||
results.append((np.array([]), np.array([]), np.array([])))
|
||||
continue
|
||||
|
||||
boxes = boxes_xywh[coarse_mask]
|
||||
scores_coarse = scores[coarse_mask]
|
||||
class_ids = np.zeros(len(boxes), dtype=np.int32)
|
||||
|
||||
valid_count = len(boxes)
|
||||
|
||||
self._buffer_xyxy[:valid_count, 0] = boxes[:, 0] - boxes[:, 2] / 2
|
||||
self._buffer_xyxy[:valid_count, 1] = boxes[:, 1] - boxes[:, 3] / 2
|
||||
self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + boxes[:, 2] / 2
|
||||
self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + boxes[:, 3] / 2
|
||||
|
||||
keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
|
||||
self._buffer_xyxy[:valid_count],
|
||||
scores_coarse.astype(np.float32),
|
||||
class_ids,
|
||||
conf_threshold
|
||||
)
|
||||
|
||||
results.append((keep_boxes, keep_scores, keep_classes))
|
||||
|
||||
return results
|
||||
|
||||
def filter_by_roi(
|
||||
self,
|
||||
boxes: np.ndarray,
|
||||
|
||||
Reference in New Issue
Block a user