70 lines
2.2 KiB
Python
70 lines
2.2 KiB
Python
|
|
"""调试 TensorRT 输出 shape"""
|
||
|
|
import sys
|
||
|
|
import os
|
||
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
try:
    import tensorrt as trt
    import pycuda.driver as cuda
    import pycuda.autoinit  # noqa: F401 -- imported for its side effect: creates a CUDA context

    # Default path is relative to the current working directory; an optional
    # first command-line argument overrides it.
    engine_path = sys.argv[1] if len(sys.argv) > 1 else "./models/yolo11n.engine"

    with open(engine_path, "rb") as f:
        runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        engine = runtime.deserialize_cuda_engine(f.read())
    # deserialize_cuda_engine returns None (instead of raising) on failure.
    if engine is None:
        raise RuntimeError(f"Failed to deserialize TensorRT engine: {engine_path}")

    context = engine.create_execution_context()

    # Input shape to probe. It must match the engine's static input shape or be
    # accepted by its optimization profile, otherwise set_input_shape fails below.
    input_shape = (1, 3, 480, 480)

    # Locate I/O tensors by name and mode rather than assuming binding 0 is the
    # input and binding 1 the output. This name-based API exists since
    # TensorRT 8.5 (the same floor set_input_shape already required) and
    # replaces the get_binding_* calls that were removed in TensorRT 10.
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    input_names = [
        name for name in tensor_names
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT
    ]
    output_names = [
        name for name in tensor_names
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
    ]
    if len(input_names) != 1 or not output_names:
        raise RuntimeError(
            f"Expected exactly 1 input and at least 1 output, got "
            f"inputs={input_names}, outputs={output_names}"
        )
    input_name = input_names[0]
    output_name = output_names[0]  # only the first output is inspected below

    # Shape as declared in the engine; dynamic axes are reported as -1 here.
    output_shape = engine.get_tensor_shape(output_name)
    print(f"Engine 定义 output shape: {output_shape}")

    if not context.set_input_shape(input_name, input_shape):
        raise RuntimeError(
            f"Engine rejected input shape {input_shape} for tensor '{input_name}'"
        )

    # Concrete output shape for this particular input shape. Sizing the buffer
    # from the engine-declared shape (treating -1 as 1) would under-allocate the
    # output and let TensorRT write past the end of the device allocation.
    resolved_output_shape = tuple(context.get_tensor_shape(output_name))
    print(f"Context 解析后 output shape: {resolved_output_shape}")

    # Allocate page-locked host memory and device memory for every I/O tensor.
    # TensorRT needs an address for each one, and each buffer uses that
    # tensor's declared dtype rather than an assumed float32.
    host_buffers = {}
    device_buffers = {}
    for name in tensor_names:
        shape = tuple(context.get_tensor_shape(name))
        if any(dim < 0 for dim in shape):
            raise RuntimeError(f"Tensor '{name}' still has unresolved shape {shape}")
        dtype = trt.nptype(engine.get_tensor_dtype(name))
        host = cuda.pagelocked_empty(int(np.prod(shape)), dtype)
        # mem_alloc rejects a zero-byte request, so allocate at least 1 byte.
        device = cuda.mem_alloc(max(host.nbytes, 1))
        host_buffers[name] = host
        device_buffers[name] = device
        context.set_tensor_address(name, int(device))

    # Random test input, cast to the engine's declared input dtype.
    h_input = host_buffers[input_name]
    input_data = np.random.randn(*input_shape).astype(h_input.dtype)
    np.copyto(h_input, input_data.ravel())

    h_output = host_buffers[output_name]
    stream = cuda.Stream()
    cuda.memcpy_htod_async(device_buffers[input_name], h_input, stream)
    if not context.execute_async_v3(stream_handle=stream.handle):
        raise RuntimeError("TensorRT execution failed")
    cuda.memcpy_dtoh_async(h_output, device_buffers[output_name], stream)
    stream.synchronize()

    output_array = h_output.reshape(resolved_output_shape)

    print("\n实际输出:")
    print(f" dtype: {output_array.dtype}")
    print(f" shape: {output_array.shape}")
    print(f" ndim: {output_array.ndim}")

    if output_array.ndim == 1:
        print(f" total elements: {output_array.shape[0]}")
        # 4725 = 60*60 + 30*30 + 15*15, presumably the prediction-grid cells at
        # strides 8/16/32 for a 480x480 input; 84 presumably = 4 box coords +
        # 80 classes. This constant is only valid for input_shape above.
        print(f" expected (84*4725): {84 * 4725}")
    elif output_array.ndim == 2:
        print(f" shape[0]: {output_array.shape[0]} (detections)")
        print(f" shape[1]: {output_array.shape[1]} (features)")
    elif output_array.ndim == 3:
        print(f" shape[0]: {output_array.shape[0]} (batch)")
        print(f" shape[1]: {output_array.shape[1]} (classes+coords)")
        print(f" shape[2]: {output_array.shape[2]} (num_boxes)")

except Exception as e:
    print(f"Error: {e}")
    import traceback
    traceback.print_exc()
    # Report the failure to the shell/CI instead of exiting with status 0.
    sys.exit(1)