#!/usr/bin/env python3
"""
Dynamic-batch TensorRT engine builder (TensorRT 10.14.1 compatibility version).

Supports YOLO ONNX models.
Supports batch sizes 1-32.
"""

import os
import time

import torch
def build_dynamic_tensorrt_engine(
    onnx_path,
    engine_path,
    use_fp16=True,
    min_bs=1,
    opt_bs=8,
    max_bs=32,
    input_hw=(640, 640),
    workspace_bytes=8 << 30,
):
    """Build a dynamic-batch TensorRT engine from an ONNX model.

    Args:
        onnx_path: Path to the source ONNX file.
        engine_path: Destination for the serialized engine; an existing
            file at this path is deleted first so a failed build can't
            leave a stale engine behind.
        use_fp16: Enable FP16 kernels when True.
        min_bs / opt_bs / max_bs: Batch-size bounds of the single
            optimization profile (TensorRT tunes kernels for opt_bs).
        input_hw: (height, width) of the network input. Defaults to the
            YOLO-standard 640x640 (previously hard-coded).
        workspace_bytes: Workspace memory-pool limit; default 8 GiB
            (previously hard-coded).

    Returns:
        engine_path on success, or None on any failure. Errors are
        printed (with traceback), never raised.
    """
    print("🔧 第二步: 构建 TensorRT 引擎...")

    try:
        # Imported lazily so this module can be loaded on machines
        # without TensorRT; the failure is reported via the except below.
        import tensorrt as trt

        if os.path.exists(engine_path):
            os.remove(engine_path)
            print(f"🗑️ 删除旧 engine 文件: {engine_path}")

        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        # EXPLICIT_BATCH is required for dynamic shapes (and is the only
        # mode that exists in TRT 10.x).
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)

        print(f"📁 解析 ONNX 模型: {onnx_path}")
        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                print("❌ ONNX 解析失败:")
                for i in range(parser.num_errors):
                    print(f" {parser.get_error(i)}")
                return None

        config = builder.create_builder_config()
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
        if use_fp16:
            config.set_flag(trt.BuilderFlag.FP16)

        # One optimization profile covering the full dynamic batch range.
        height, width = input_hw
        profile = builder.create_optimization_profile()
        input_name = network.get_input(0).name
        profile.set_shape(
            input_name,
            (min_bs, 3, height, width),
            (opt_bs, 3, height, width),
            (max_bs, 3, height, width),
        )
        config.add_optimization_profile(profile)

        print(f"🎯 动态形状配置: min={min_bs}, opt={opt_bs}, max={max_bs}")

        print("⏳ 开始构建 TensorRT 引擎(可能需要几分钟)...")
        start_time = time.time()

        # TRT 10.x build path: serialize directly, no intermediate ICudaEngine.
        serialized_engine = builder.build_serialized_network(network, config)

        build_time = time.time() - start_time

        if serialized_engine is None:
            print("❌ TensorRT 引擎构建失败")
            return None

        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

        file_size = os.path.getsize(engine_path) / (1024 * 1024)
        print(f"✅ 引擎构建完成: {engine_path} ({file_size:.1f} MB)")
        print(f"⏱️ 构建耗时: {build_time:.1f} 秒")

        return engine_path

    except Exception as e:
        # Best-effort builder: report and return None rather than raise.
        print(f"❌ 引擎构建失败: {e}")
        import traceback
        traceback.print_exc()
        return None
def test_dynamic_engine_shapes(engine_path):
|
||
print(f"\n🧪 测试动态 engine 支持的批次: {engine_path}")
|
||
|
||
try:
|
||
import tensorrt as trt
|
||
import pycuda.driver as cuda
|
||
import pycuda.autoinit # noqa
|
||
|
||
logger = trt.Logger(trt.Logger.WARNING)
|
||
runtime = trt.Runtime(logger)
|
||
|
||
with open(engine_path, "rb") as f:
|
||
engine = runtime.deserialize_cuda_engine(f.read())
|
||
|
||
if engine is None:
|
||
print("❌ 引擎加载失败")
|
||
return []
|
||
|
||
print(
|
||
"⚡ 引擎是否使用 EXPLICIT_BATCH:",
|
||
not engine.has_implicit_batch_dimension,
|
||
)
|
||
|
||
context = engine.create_execution_context()
|
||
|
||
# TRT 10.x 必须 async 选 profile
|
||
stream = cuda.Stream()
|
||
context.set_optimization_profile_async(0, stream.handle)
|
||
|
||
# -------- TensorRT 10.x 正确获取输入 tensor --------
|
||
input_name = None
|
||
for i in range(engine.num_io_tensors):
|
||
name = engine.get_tensor_name(i)
|
||
mode = engine.get_tensor_mode(name)
|
||
if mode == trt.TensorIOMode.INPUT:
|
||
input_name = name
|
||
break
|
||
|
||
if input_name is None:
|
||
print("❌ 找不到输入张量")
|
||
return []
|
||
|
||
print(f"📊 输入张量: {input_name}")
|
||
|
||
supported_batches = []
|
||
|
||
for batch_size in [1, 2, 4, 8, 16, 32]:
|
||
try:
|
||
context.set_input_shape(
|
||
input_name,
|
||
(batch_size, 3, 640, 640)
|
||
)
|
||
|
||
if context.all_binding_shapes_specified:
|
||
supported_batches.append(batch_size)
|
||
print(f" ✅ 批次 {batch_size} 支持")
|
||
else:
|
||
print(f" ❌ 批次 {batch_size} 形状未就绪")
|
||
|
||
except Exception as e:
|
||
print(f" ❌ 批次 {batch_size} 不支持: {e}")
|
||
|
||
print(f"\n🎯 支持的批次大小: {supported_batches}")
|
||
return supported_batches
|
||
|
||
except Exception as e:
|
||
print(f"❌ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return []
|
||
|
||
|
||
def main():
    """Entry point: check prerequisites, build the engine, then probe it.

    Uses hard-coded local paths (NOTE(review): consider argparse or env
    vars instead). Each missing prerequisite — model file, CUDA device,
    exported ONNX — aborts early with a printed reason.
    """
    print("动态批次 TensorRT 引擎构建器(TensorRT 10.14.1 终极兼容版)")
    print("=" * 60)

    model_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
    onnx_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.onnx"
    engine_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.engine"

    # Guard clauses: bail out with a message on each missing prerequisite.
    if not os.path.exists(model_path):
        print(f"❌ 模型文件不存在: {model_path}")
        return

    if not torch.cuda.is_available():
        print("❌ CUDA 不可用")
        return

    print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}")

    if not os.path.exists(onnx_path):
        print("❌ ONNX 不存在,请先导出动态 ONNX")
        return
    print(f"✅ ONNX 文件已存在: {onnx_path}")

    engine_path = build_dynamic_tensorrt_engine(
        onnx_path,
        engine_path,
        use_fp16=True,
        min_bs=1,
        opt_bs=8,
        max_bs=32,
    )

    # Builder returns None on failure; nothing to test in that case.
    if not engine_path:
        return

    supported_batches = test_dynamic_engine_shapes(engine_path)

    if supported_batches:
        print(f"\n🎉 TensorRT 引擎准备就绪! 支持批次: {supported_batches}")
    else:
        print("⚠️ 引擎构建完成但不支持任何动态批次")
# Script entry point: only run when executed directly, not on import.
if __name__ == "__main__":
    main()