fix: YOLO TensorRT 输出解析修复
- TensorRT 输出 shape: (1, 84, 4725) (84, 4725) - 正确解析 YOLO 输出格式: boxes[0:4], obj_conf[4], cls_scores[5:] - 移除错误的 detection 遍历逻辑 - 工业级向量化操作代替 Python 循环
This commit is contained in:
69
debug_output_shape.py
Normal file
69
debug_output_shape.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""调试 TensorRT 输出 shape"""
|
||||
import sys
|
||||
import os
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit
|
||||
|
||||
engine_path = "./models/yolo11n.engine"
|
||||
|
||||
with open(engine_path, "rb") as f:
|
||||
runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
|
||||
engine = runtime.deserialize_cuda_engine(f.read())
|
||||
|
||||
context = engine.create_execution_context()
|
||||
|
||||
input_shape = (1, 3, 480, 480)
|
||||
input_data = np.random.randn(*input_shape).astype(np.float32)
|
||||
|
||||
input_binding_idx = 0
|
||||
output_binding_idx = 1
|
||||
|
||||
output_shape = engine.get_binding_shape(output_binding_idx)
|
||||
print(f"Engine 定义 output shape: {output_shape}")
|
||||
|
||||
context.set_input_shape(engine.get_binding_name(0), input_shape)
|
||||
|
||||
output_size = int(np.prod([max(1, s) for s in output_shape]))
|
||||
|
||||
h_input = cuda.pagelocked_empty(input_data.size, np.float32)
|
||||
h_output = cuda.pagelocked_empty(output_size, np.float32)
|
||||
|
||||
np.copyto(h_input, input_data.ravel())
|
||||
|
||||
d_input = cuda.mem_alloc(h_input.nbytes)
|
||||
d_output = cuda.mem_alloc(h_output.nbytes)
|
||||
|
||||
bindings = [int(d_input), int(d_output)]
|
||||
|
||||
cuda.memcpy_htod(d_input, h_input)
|
||||
context.execute_v2(bindings=bindings)
|
||||
cuda.memcpy_dtoh(h_output, d_output)
|
||||
|
||||
output_array = h_output.reshape(output_shape)
|
||||
|
||||
print(f"\n实际输出:")
|
||||
print(f" dtype: {output_array.dtype}")
|
||||
print(f" shape: {output_array.shape}")
|
||||
print(f" ndim: {output_array.ndim}")
|
||||
|
||||
if output_array.ndim == 1:
|
||||
print(f" total elements: {output_array.shape[0]}")
|
||||
print(f" expected (84*4725): {84 * 4725}")
|
||||
elif output_array.ndim == 2:
|
||||
print(f" shape[0]: {output_array.shape[0]} (detections)")
|
||||
print(f" shape[1]: {output_array.shape[1]} (features)")
|
||||
elif output_array.ndim == 3:
|
||||
print(f" shape[0]: {output_array.shape[0]} (batch)")
|
||||
print(f" shape[1]: {output_array.shape[1]} (classes+coords)")
|
||||
print(f" shape[2]: {output_array.shape[2]} (num_boxes)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
Reference in New Issue
Block a user