fix: YOLO TensorRT 输出解析修复

- TensorRT 输出 shape: (1, 84, 4725)  (84, 4725)
- 正确解析 YOLO 输出格式: boxes[0:4], obj_conf[4], cls_scores[5:]
- 移除错误的 detection 遍历逻辑
- 工业级向量化操作代替 Python 循环
This commit is contained in:
2026-02-02 15:02:58 +08:00
parent 745cadc8e7
commit 3dd4e56f99
12 changed files with 9951 additions and 45 deletions

View File

@@ -600,54 +600,44 @@ class PostProcessor:
if output.ndim == 3:
output = output[0]
if output.ndim == 2:
output = output.reshape(-1)
num_detections = output.shape[0] // 85
boxes = []
scores = []
class_ids = []
for i in range(num_detections):
start_idx = i * 85
detection = output[start_idx:start_idx + 85]
x_center = detection[0]
y_center = detection[1]
width = detection[2]
height = detection[3]
obj_conf = detection[4]
class_scores = detection[5:]
if len(class_scores) == 0:
continue
class_id = np.argmax(class_scores)
class_conf = class_scores[class_id]
total_conf = obj_conf * class_conf
if total_conf < 0.0:
continue
x1 = x_center - width / 2
y1 = y_center - height / 2
x2 = x_center + width / 2
y2 = y_center + height / 2
boxes.append([x1, y1, x2, y2])
scores.append(float(total_conf))
class_ids.append(int(class_id))
if not boxes:
if output.ndim != 2:
return np.array([]), np.array([]), np.array([])
if output.shape[0] != 84:
return np.array([]), np.array([]), np.array([])
num_boxes = output.shape[1]
boxes_xywh = output[0:4, :].T
obj_conf = output[4, :]
cls_scores = output[5:, :]
cls_ids = np.argmax(cls_scores, axis=0)
cls_conf = cls_scores[cls_ids, np.arange(num_boxes)]
scores = obj_conf * cls_conf
valid_mask = scores > self._conf_threshold
if not np.any(valid_mask):
return np.array([]), np.array([]), np.array([])
boxes = boxes_xywh[valid_mask]
scores = scores[valid_mask]
class_ids = cls_ids[valid_mask]
boxes_xyxy = np.zeros_like(boxes)
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2
return (
np.array(boxes, dtype=np.float32),
np.array(scores, dtype=np.float32),
np.array(class_ids, dtype=np.int32)
boxes_xyxy.astype(np.float32),
scores.astype(np.float32),
class_ids.astype(np.int32)
)
def filter_by_roi(