44 lines
1.2 KiB
Python
44 lines
1.2 KiB
Python
|
|
"""检查 TensorRT Engine 的实际 shape"""
|
||
|
|
import sys
|
||
|
|
import os
|
||
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
|
|
|
||
|
|
try:
|
||
|
|
import tensorrt as trt
|
||
|
|
|
||
|
|
engine_path = "./models/yolo11n.engine"
|
||
|
|
|
||
|
|
with open(engine_path, "rb") as f:
|
||
|
|
runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
|
||
|
|
engine = runtime.deserialize_cuda_engine(f.read())
|
||
|
|
|
||
|
|
print("=" * 60)
|
||
|
|
print("Engine Binding Information")
|
||
|
|
print("=" * 60)
|
||
|
|
|
||
|
|
for i in range(engine.num_bindings):
|
||
|
|
name = engine.get_binding_name(i)
|
||
|
|
shape = engine.get_binding_shape(i)
|
||
|
|
dtype = trt.nptype(engine.get_binding_dtype(i))
|
||
|
|
is_input = engine.binding_is_input(i)
|
||
|
|
|
||
|
|
size = trt.volume(shape)
|
||
|
|
|
||
|
|
print(f"\nBinding {i}:")
|
||
|
|
print(f" Name: {name}")
|
||
|
|
print(f" Shape: {shape}")
|
||
|
|
print(f" Dtype: {dtype}")
|
||
|
|
print(f" Size: {size}")
|
||
|
|
print(f" Is Input: {is_input}")
|
||
|
|
|
||
|
|
if is_input:
|
||
|
|
print(f" Total Elements: {size}")
|
||
|
|
print(f" Expected Batch Size: {shape[0] if len(shape) > 0 else 'N/A'}")
|
||
|
|
|
||
|
|
print("\n" + "=" * 60)
|
||
|
|
|
||
|
|
except Exception as e:
|
||
|
|
print(f"Error: {e}")
|
||
|
|
import traceback
|
||
|
|
traceback.print_exc()
|