Files
security-ai-edge/check_engine.py

44 lines
1.2 KiB
Python
Raw Normal View History

"""检查 TensorRT Engine 的实际 shape"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
try:
import tensorrt as trt
engine_path = "./models/yolo11n.engine"
with open(engine_path, "rb") as f:
runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
engine = runtime.deserialize_cuda_engine(f.read())
print("=" * 60)
print("Engine Binding Information")
print("=" * 60)
for i in range(engine.num_bindings):
name = engine.get_binding_name(i)
shape = engine.get_binding_shape(i)
dtype = trt.nptype(engine.get_binding_dtype(i))
is_input = engine.binding_is_input(i)
size = trt.volume(shape)
print(f"\nBinding {i}:")
print(f" Name: {name}")
print(f" Shape: {shape}")
print(f" Dtype: {dtype}")
print(f" Size: {size}")
print(f" Is Input: {is_input}")
if is_input:
print(f" Total Elements: {size}")
print(f" Expected Batch Size: {shape[0] if len(shape) > 0 else 'N/A'}")
print("\n" + "=" * 60)
except Exception as e:
print(f"Error: {e}")
import traceback
traceback.print_exc()