TensorRT测试

This commit is contained in:
2026-01-20 11:14:10 +08:00
commit 189c6fa786
51 changed files with 9251 additions and 0 deletions

View File

@@ -0,0 +1,214 @@
#!/usr/bin/env python3
"""
动态批次 TensorRT 引擎构建器TensorRT 10.14.1 终极兼容版)
支持 YOLO ONNX
支持 batch: 1-32
"""
import os
import time
import torch
def build_dynamic_tensorrt_engine(
    onnx_path,
    engine_path,
    use_fp16=True,
    min_bs=1,
    opt_bs=8,
    max_bs=32
):
    """Build a dynamic-batch TensorRT engine from a YOLO ONNX model.

    Uses the TensorRT 10.x explicit-batch API with a single optimization
    profile covering batch sizes ``min_bs``..``max_bs`` at 3x640x640 input.

    Args:
        onnx_path: Path to the (dynamic-axis) ONNX model to parse.
        engine_path: Destination path for the serialized engine file.
        use_fp16: Enable the FP16 builder flag when True.
        min_bs: Minimum batch size of the optimization profile.
        opt_bs: Batch size TensorRT optimizes kernels for.
        max_bs: Maximum batch size of the optimization profile.

    Returns:
        ``engine_path`` on success, ``None`` on any failure (including
        ``tensorrt`` not being importable — all errors are caught here).
    """
    print("🔧 第二步: 构建 TensorRT 引擎...")
    try:
        import tensorrt as trt

        # Remove any stale engine so a failed build can't leave callers
        # silently using an outdated file.
        if os.path.exists(engine_path):
            os.remove(engine_path)
            print(f"🗑️ 删除旧 engine 文件: {engine_path}")

        trt_logger = trt.Logger(trt.Logger.INFO)
        trt_builder = trt.Builder(trt_logger)
        # TRT 10.x requires an explicit-batch network definition.
        net_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        net = trt_builder.create_network(net_flags)
        onnx_parser = trt.OnnxParser(net, trt_logger)

        print(f"📁 解析 ONNX 模型: {onnx_path}")
        with open(onnx_path, "rb") as model_file:
            parse_ok = onnx_parser.parse(model_file.read())
        if not parse_ok:
            print("❌ ONNX 解析失败:")
            for err_idx in range(onnx_parser.num_errors):
                print(f" {onnx_parser.get_error(err_idx)}")
            return None

        build_cfg = trt_builder.create_builder_config()
        # 8 GiB workspace pool for tactic selection.
        build_cfg.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 8 << 30)
        if use_fp16:
            build_cfg.set_flag(trt.BuilderFlag.FP16)

        # One profile spanning the requested dynamic batch range.
        shape_profile = trt_builder.create_optimization_profile()
        first_input = net.get_input(0).name
        shape_profile.set_shape(
            first_input,
            (min_bs, 3, 640, 640),
            (opt_bs, 3, 640, 640),
            (max_bs, 3, 640, 640),
        )
        build_cfg.add_optimization_profile(shape_profile)
        print(f"🎯 动态形状配置: min={min_bs}, opt={opt_bs}, max={max_bs}")

        print("⏳ 开始构建 TensorRT 引擎(可能需要几分钟)...")
        started = time.time()
        serialized = trt_builder.build_serialized_network(net, build_cfg)
        build_time = time.time() - started
        if serialized is None:
            print("❌ TensorRT 引擎构建失败")
            return None

        with open(engine_path, "wb") as engine_file:
            engine_file.write(serialized)
        file_size = os.path.getsize(engine_path) / (1024 * 1024)
        print(f"✅ 引擎构建完成: {engine_path} ({file_size:.1f} MB)")
        print(f"⏱️ 构建耗时: {build_time:.1f}秒")
        return engine_path
    except Exception as e:
        print(f"❌ 引擎构建失败: {e}")
        import traceback
        traceback.print_exc()
        return None
def test_dynamic_engine_shapes(engine_path):
print(f"\n🧪 测试动态 engine 支持的批次: {engine_path}")
try:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open(engine_path, "rb") as f:
engine = runtime.deserialize_cuda_engine(f.read())
if engine is None:
print("❌ 引擎加载失败")
return []
print(
"⚡ 引擎是否使用 EXPLICIT_BATCH:",
not engine.has_implicit_batch_dimension,
)
context = engine.create_execution_context()
# TRT 10.x 必须 async 选 profile
stream = cuda.Stream()
context.set_optimization_profile_async(0, stream.handle)
# -------- TensorRT 10.x 正确获取输入 tensor --------
input_name = None
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
mode = engine.get_tensor_mode(name)
if mode == trt.TensorIOMode.INPUT:
input_name = name
break
if input_name is None:
print("❌ 找不到输入张量")
return []
print(f"📊 输入张量: {input_name}")
supported_batches = []
for batch_size in [1, 2, 4, 8, 16, 32]:
try:
context.set_input_shape(
input_name,
(batch_size, 3, 640, 640)
)
if context.all_binding_shapes_specified:
supported_batches.append(batch_size)
print(f" ✅ 批次 {batch_size} 支持")
else:
print(f" ❌ 批次 {batch_size} 形状未就绪")
except Exception as e:
print(f" ❌ 批次 {batch_size} 不支持: {e}")
print(f"\n🎯 支持的批次大小: {supported_batches}")
return supported_batches
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return []
def main(
    model_path="C:/Users/16337/PycharmProjects/Security/yolo11n.pt",
    onnx_path="C:/Users/16337/PycharmProjects/Security/yolo11n.onnx",
    engine_path="C:/Users/16337/PycharmProjects/Security/yolo11n.engine",
):
    """Validate prerequisites, build the dynamic-batch engine, and test it.

    Generalized: the previously hard-coded file locations are now keyword
    parameters whose defaults preserve the original behavior, so the script
    is usable outside the original author's machine.

    Args:
        model_path: Path to the source ``.pt`` checkpoint (existence check only).
        onnx_path: Path to the pre-exported dynamic-axis ONNX model.
        engine_path: Destination path for the built ``.engine`` file.

    Returns:
        ``None``. Progress and errors are reported via ``print``.
    """
    print("动态批次 TensorRT 引擎构建器TensorRT 10.14.1 终极兼容版)")
    print("=" * 60)
    # Guard clauses: bail out early on any missing prerequisite.
    if not os.path.exists(model_path):
        print(f"❌ 模型文件不存在: {model_path}")
        return
    if not torch.cuda.is_available():
        print("❌ CUDA 不可用")
        return
    print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}")
    if not os.path.exists(onnx_path):
        print("❌ ONNX 不存在,请先导出动态 ONNX")
        return
    print(f"✅ ONNX 文件已存在: {onnx_path}")
    engine_path = build_dynamic_tensorrt_engine(
        onnx_path,
        engine_path,
        use_fp16=True,
        min_bs=1,
        opt_bs=8,
        max_bs=32,
    )
    if not engine_path:
        # Builder already printed the failure reason.
        return
    supported_batches = test_dynamic_engine_shapes(engine_path)
    if supported_batches:
        print(
            f"\n🎉 TensorRT 引擎准备就绪! 支持批次: {supported_batches}"
        )
    else:
        print("⚠️ 引擎构建完成但不支持任何动态批次")
if __name__ == "__main__":
    main()