Files
Test_AI/dynamic_batch_tensorrt_builder.py
2026-01-20 11:14:10 +08:00

215 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Dynamic-batch TensorRT engine builder (TensorRT 10.14.1 compatibility version).

Parses a YOLO ONNX model and builds a serialized TensorRT engine whose
input has a dynamic batch dimension (batch sizes 1-32).
"""
import os
import time
import torch
def build_dynamic_tensorrt_engine(
    onnx_path,
    engine_path,
    use_fp16=True,
    min_bs=1,
    opt_bs=8,
    max_bs=32,
    input_shape=(3, 640, 640),
    workspace_gb=8,
):
    """Build a TensorRT engine with a dynamic batch dimension from an ONNX model.

    Args:
        onnx_path: Path to the source ONNX file.
        engine_path: Destination path for the serialized engine (any existing
            file at this path is deleted first).
        use_fp16: Enable FP16 precision when the platform has fast FP16 kernels.
        min_bs: Minimum batch size of the optimization profile.
        opt_bs: Batch size TensorRT optimizes kernels for.
        max_bs: Maximum batch size of the optimization profile.
        input_shape: Per-sample input shape (C, H, W); default matches a
            YOLO 640x640 export.
        workspace_gb: Workspace memory-pool limit in GiB.

    Returns:
        ``engine_path`` on success, ``None`` on any failure (errors are printed).
    """
    print("🔧 第二步: 构建 TensorRT 引擎...")
    try:
        # Imported lazily so the module can be loaded on machines without TRT.
        import tensorrt as trt

        if not (min_bs <= opt_bs <= max_bs):
            print(f"❌ 非法批次范围: min={min_bs}, opt={opt_bs}, max={max_bs}")
            return None

        # Remove a stale engine so a failed build cannot leave an old file behind.
        if os.path.exists(engine_path):
            os.remove(engine_path)
            print(f"🗑️ 删除旧 engine 文件: {engine_path}")

        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        # EXPLICIT_BATCH is required for dynamic shapes in modern TensorRT.
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)

        print(f"📁 解析 ONNX 模型: {onnx_path}")
        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                print("❌ ONNX 解析失败:")
                for i in range(parser.num_errors):
                    print(f" {parser.get_error(i)}")
                return None

        config = builder.create_builder_config()
        config.set_memory_pool_limit(
            trt.MemoryPoolType.WORKSPACE, workspace_gb << 30
        )
        if use_fp16:
            # Only set the FP16 flag when the hardware actually supports it;
            # otherwise fall back to FP32 instead of building a slow engine.
            if builder.platform_has_fast_fp16:
                config.set_flag(trt.BuilderFlag.FP16)
            else:
                print("⚠️ 平台不支持快速 FP16, 使用 FP32")

        # One optimization profile covering the requested dynamic batch range.
        profile = builder.create_optimization_profile()
        input_name = network.get_input(0).name
        profile.set_shape(
            input_name,
            (min_bs, *input_shape),
            (opt_bs, *input_shape),
            (max_bs, *input_shape),
        )
        config.add_optimization_profile(profile)
        print(f"🎯 动态形状配置: min={min_bs}, opt={opt_bs}, max={max_bs}")

        print("⏳ 开始构建 TensorRT 引擎(可能需要几分钟)...")
        start_time = time.time()
        serialized_engine = builder.build_serialized_network(network, config)
        build_time = time.time() - start_time
        if serialized_engine is None:
            print("❌ TensorRT 引擎构建失败")
            return None

        with open(engine_path, "wb") as f:
            f.write(serialized_engine)
        file_size = os.path.getsize(engine_path) / (1024 * 1024)
        print(
            f"✅ 引擎构建完成: {engine_path} ({file_size:.1f} MB)"
        )
        print(f"⏱️ 构建耗时: {build_time:.1f}")
        return engine_path
    except Exception as e:
        print(f"❌ 引擎构建失败: {e}")
        import traceback
        traceback.print_exc()
        return None
def test_dynamic_engine_shapes(engine_path):
print(f"\n🧪 测试动态 engine 支持的批次: {engine_path}")
try:
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit # noqa
logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open(engine_path, "rb") as f:
engine = runtime.deserialize_cuda_engine(f.read())
if engine is None:
print("❌ 引擎加载失败")
return []
print(
"⚡ 引擎是否使用 EXPLICIT_BATCH:",
not engine.has_implicit_batch_dimension,
)
context = engine.create_execution_context()
# TRT 10.x 必须 async 选 profile
stream = cuda.Stream()
context.set_optimization_profile_async(0, stream.handle)
# -------- TensorRT 10.x 正确获取输入 tensor --------
input_name = None
for i in range(engine.num_io_tensors):
name = engine.get_tensor_name(i)
mode = engine.get_tensor_mode(name)
if mode == trt.TensorIOMode.INPUT:
input_name = name
break
if input_name is None:
print("❌ 找不到输入张量")
return []
print(f"📊 输入张量: {input_name}")
supported_batches = []
for batch_size in [1, 2, 4, 8, 16, 32]:
try:
context.set_input_shape(
input_name,
(batch_size, 3, 640, 640)
)
if context.all_binding_shapes_specified:
supported_batches.append(batch_size)
print(f" ✅ 批次 {batch_size} 支持")
else:
print(f" ❌ 批次 {batch_size} 形状未就绪")
except Exception as e:
print(f" ❌ 批次 {batch_size} 不支持: {e}")
print(f"\n🎯 支持的批次大小: {supported_batches}")
return supported_batches
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return []
def main():
    """Entry point: verify prerequisites, build the engine, then probe batch support."""
    print("动态批次 TensorRT 引擎构建器TensorRT 10.14.1 终极兼容版)")
    print("=" * 60)

    model_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
    onnx_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.onnx"
    engine_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.engine"

    # Guard clauses: abort early when any prerequisite is missing.
    if not os.path.exists(model_path):
        print(f"❌ 模型文件不存在: {model_path}")
        return
    if not torch.cuda.is_available():
        print("❌ CUDA 不可用")
        return
    print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}")

    if os.path.exists(onnx_path):
        print(f"✅ ONNX 文件已存在: {onnx_path}")
    else:
        print("❌ ONNX 不存在,请先导出动态 ONNX")
        return

    built_engine = build_dynamic_tensorrt_engine(
        onnx_path,
        engine_path,
        use_fp16=True,
        min_bs=1,
        opt_bs=8,
        max_bs=32,
    )
    if not built_engine:
        return

    supported = test_dynamic_engine_shapes(built_engine)
    if supported:
        print(
            f"\n🎉 TensorRT 引擎准备就绪! 支持批次: {supported}"
        )
    else:
        print("⚠️ 引擎构建完成但不支持任何动态批次")
# Run the build-and-verify workflow only when executed as a script.
if __name__ == "__main__":
    main()