#!/usr/bin/env python3
"""
Dynamic-batch TensorRT engine builder (TensorRT 10.14.1 compatible).
Works with YOLO ONNX models.
Supported batch sizes: 1-32.
"""
import os
import time

import torch


def build_dynamic_tensorrt_engine(
    onnx_path, engine_path, use_fp16=True, min_bs=1, opt_bs=8, max_bs=32
):
    print("🔧 Step 2: building the TensorRT engine...")
    try:
        import tensorrt as trt

        if os.path.exists(engine_path):
            os.remove(engine_path)
            print(f"🗑️ Removed old engine file: {engine_path}")

        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        # EXPLICIT_BATCH is deprecated (and ignored) in TensorRT 10.x, where
        # networks are always explicit-batch; keeping the flag is harmless.
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)

        print(f"📁 Parsing ONNX model: {onnx_path}")
        with open(onnx_path, "rb") as f:
            if not parser.parse(f.read()):
                print("❌ ONNX parsing failed:")
                for i in range(parser.num_errors):
                    print(f"  {parser.get_error(i)}")
                return None

        config = builder.create_builder_config()
        # 8 GiB workspace limit.
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 8 << 30)
        if use_fp16:
            config.set_flag(trt.BuilderFlag.FP16)

        # Optimization profile for the dynamic batch dimension (min/opt/max).
        profile = builder.create_optimization_profile()
        input_name = network.get_input(0).name
        profile.set_shape(
            input_name,
            (min_bs, 3, 640, 640),
            (opt_bs, 3, 640, 640),
            (max_bs, 3, 640, 640),
        )
        config.add_optimization_profile(profile)
        print(f"🎯 Dynamic shape profile: min={min_bs}, opt={opt_bs}, max={max_bs}")

        print("⏳ Building TensorRT engine (this may take several minutes)...")
        start_time = time.time()
        serialized_engine = builder.build_serialized_network(network, config)
        build_time = time.time() - start_time

        if serialized_engine is None:
            print("❌ TensorRT engine build failed")
            return None

        with open(engine_path, "wb") as f:
            f.write(serialized_engine)

        file_size = os.path.getsize(engine_path) / (1024 * 1024)
        print(f"✅ Engine built: {engine_path} ({file_size:.1f} MB)")
        print(f"⏱️ Build time: {build_time:.1f} s")
        return engine_path

    except Exception as e:
        print(f"❌ Engine build failed: {e}")
        import traceback

        traceback.print_exc()
        return None


def test_dynamic_engine_shapes(engine_path):
    print(f"\n🧪 Testing batch sizes supported by the dynamic engine: {engine_path}")
    try:
        import tensorrt as trt
        import pycuda.driver as cuda
        import pycuda.autoinit  # noqa

        logger = trt.Logger(trt.Logger.WARNING)
        runtime = trt.Runtime(logger)
        with open(engine_path, "rb") as f:
            engine = runtime.deserialize_cuda_engine(f.read())
        if engine is None:
            print("❌ Failed to load the engine")
            return []

        # has_implicit_batch_dimension is deprecated in TensorRT 10.x
        # (engines are always explicit-batch), so guard the attribute access.
        print(
            "⚡ Engine uses explicit batch:",
            not getattr(engine, "has_implicit_batch_dimension", False),
        )

        context = engine.create_execution_context()

        # TensorRT 10.x: the optimization profile must be selected asynchronously.
        stream = cuda.Stream()
        context.set_optimization_profile_async(0, stream.handle)

        # -------- TensorRT 10.x: resolve the input tensor by name --------
        input_name = None
        for i in range(engine.num_io_tensors):
            name = engine.get_tensor_name(i)
            mode = engine.get_tensor_mode(name)
            if mode == trt.TensorIOMode.INPUT:
                input_name = name
                break
        if input_name is None:
            print("❌ No input tensor found")
            return []
        print(f"📊 Input tensor: {input_name}")

        supported_batches = []
        for batch_size in [1, 2, 4, 8, 16, 32]:
            try:
                context.set_input_shape(input_name, (batch_size, 3, 640, 640))
                # all_binding_shapes_specified is a deprecated (binding-era)
                # attribute in TensorRT 10.x; default to True if it is absent.
                if getattr(context, "all_binding_shapes_specified", True):
                    supported_batches.append(batch_size)
                    print(f"  ✅ Batch {batch_size} supported")
                else:
                    print(f"  ❌ Batch {batch_size}: shapes not fully specified")
            except Exception as e:
                print(f"  ❌ Batch {batch_size} not supported: {e}")

        print(f"\n🎯 Supported batch sizes: {supported_batches}")
        return supported_batches

    except Exception as e:
        print(f"❌ Test failed: {e}")
        import traceback

        traceback.print_exc()
        return []
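

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original script, not called by main()):
# one way to produce the dynamic-batch ONNX that main() expects, assuming the
# Ultralytics package is installed. The function name is hypothetical; the
# export arguments follow the public Ultralytics YOLO.export() API.
# ---------------------------------------------------------------------------
def export_dynamic_onnx_sketch(model_path, imgsz=640):
    from ultralytics import YOLO

    model = YOLO(model_path)
    # dynamic=True exports the batch dimension as dynamic, so the TensorRT
    # optimization profile above (batch 1..32) can be applied to the network.
    return model.export(format="onnx", dynamic=True, imgsz=imgsz)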
"C:/Users/16337/PycharmProjects/Security/yolo11n.engine" if not os.path.exists(model_path): print(f"❌ 模型文件不存在: {model_path}") return if not torch.cuda.is_available(): print("❌ CUDA 不可用") return print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}") if not os.path.exists(onnx_path): print("❌ ONNX 不存在,请先导出动态 ONNX") return else: print(f"✅ ONNX 文件已存在: {onnx_path}") engine_path = build_dynamic_tensorrt_engine( onnx_path, engine_path, use_fp16=True, min_bs=1, opt_bs=8, max_bs=32, ) if not engine_path: return supported_batches = test_dynamic_engine_shapes(engine_path) if supported_batches: print( f"\n🎉 TensorRT 引擎准备就绪! 支持批次: {supported_batches}" ) else: print("⚠️ 引擎构建完成但不支持任何动态批次") if __name__ == "__main__": main()