fix: YOLO TensorRT 输出解析修复
- TensorRT 输出 shape: (1, 84, 4725) (84, 4725) - 正确解析 YOLO 输出格式: boxes[0:4], obj_conf[4], cls_scores[5:] - 移除错误的 detection 遍历逻辑 - 工业级向量化操作代替 Python 循环
This commit is contained in:
@@ -600,54 +600,44 @@ class PostProcessor:
|
||||
if output.ndim == 3:
|
||||
output = output[0]
|
||||
|
||||
if output.ndim == 2:
|
||||
output = output.reshape(-1)
|
||||
|
||||
num_detections = output.shape[0] // 85
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
class_ids = []
|
||||
|
||||
for i in range(num_detections):
|
||||
start_idx = i * 85
|
||||
detection = output[start_idx:start_idx + 85]
|
||||
|
||||
x_center = detection[0]
|
||||
y_center = detection[1]
|
||||
width = detection[2]
|
||||
height = detection[3]
|
||||
|
||||
obj_conf = detection[4]
|
||||
|
||||
class_scores = detection[5:]
|
||||
if len(class_scores) == 0:
|
||||
continue
|
||||
|
||||
class_id = np.argmax(class_scores)
|
||||
class_conf = class_scores[class_id]
|
||||
|
||||
total_conf = obj_conf * class_conf
|
||||
|
||||
if total_conf < 0.0:
|
||||
continue
|
||||
|
||||
x1 = x_center - width / 2
|
||||
y1 = y_center - height / 2
|
||||
x2 = x_center + width / 2
|
||||
y2 = y_center + height / 2
|
||||
|
||||
boxes.append([x1, y1, x2, y2])
|
||||
scores.append(float(total_conf))
|
||||
class_ids.append(int(class_id))
|
||||
|
||||
if not boxes:
|
||||
if output.ndim != 2:
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
if output.shape[0] != 84:
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
num_boxes = output.shape[1]
|
||||
|
||||
boxes_xywh = output[0:4, :].T
|
||||
|
||||
obj_conf = output[4, :]
|
||||
|
||||
cls_scores = output[5:, :]
|
||||
|
||||
cls_ids = np.argmax(cls_scores, axis=0)
|
||||
cls_conf = cls_scores[cls_ids, np.arange(num_boxes)]
|
||||
|
||||
scores = obj_conf * cls_conf
|
||||
|
||||
valid_mask = scores > self._conf_threshold
|
||||
|
||||
if not np.any(valid_mask):
|
||||
return np.array([]), np.array([]), np.array([])
|
||||
|
||||
boxes = boxes_xywh[valid_mask]
|
||||
scores = scores[valid_mask]
|
||||
class_ids = cls_ids[valid_mask]
|
||||
|
||||
boxes_xyxy = np.zeros_like(boxes)
|
||||
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2
|
||||
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2
|
||||
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2
|
||||
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2
|
||||
|
||||
return (
|
||||
np.array(boxes, dtype=np.float32),
|
||||
np.array(scores, dtype=np.float32),
|
||||
np.array(class_ids, dtype=np.int32)
|
||||
boxes_xyxy.astype(np.float32),
|
||||
scores.astype(np.float32),
|
||||
class_ids.astype(np.int32)
|
||||
)
|
||||
|
||||
def filter_by_roi(
|
||||
|
||||
Reference in New Issue
Block a user