# TensorRT engine build script
# Usage: python scripts/build_engine.py
import os
import sys
import argparse
import shutil

# Make the project root importable when the script is run from scripts/.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

# NOTE(review): torch is not referenced directly in this script; the import is
# kept as-is in case import order matters for the CUDA libraries -- confirm.
import torch
from ultralytics import YOLO

# (min, opt, max) batch sizes used for the optimization profile.
_DYNAMIC_BATCHES = (1, 4, 8)
_STATIC_BATCHES = (1, 1, 1)
_WORKSPACE_BYTES = 4 << 30  # 4GB


def _ensure_parent_dir(path):
    """Create the parent directory of *path* when it has one."""
    parent = os.path.dirname(path)
    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _profile_shapes(input_dims, batches, imgsz):
    """Return (min, opt, max) input shapes for an optimization profile.

    Only dimensions that the network reports as dynamic (-1) are filled in
    from ``(batch, 3, imgsz, imgsz)``; dimensions that are already fixed in
    the ONNX model are kept unchanged.
    """
    if len(input_dims) != 4:
        raise ValueError(
            f"Expected a 4-D NCHW input, got shape {tuple(input_dims)}"
        )
    shapes = []
    for batch in batches:
        fallback = (batch, 3, imgsz, imgsz)
        shapes.append(
            tuple(fb if dim < 0 else dim for dim, fb in zip(input_dims, fallback))
        )
    return shapes


def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True, imgsz=480):
    """Build a serialized TensorRT engine from an ONNX model.

    Args:
        onnx_path: Path of the ONNX model to parse.
        engine_path: Path the serialized engine is written to.
        fp16: Enable the FP16 builder flag when true.
        dynamic_batch: When true, dynamic batch dimensions get a 1/4/8
            (min/opt/max) profile; otherwise they are pinned to batch 1.
        imgsz: Height/width used for dynamic spatial dimensions.

    Raises:
        FileNotFoundError: If ``onnx_path`` does not exist.
        RuntimeError: If ONNX parsing or the engine build fails.
        ValueError: If the network input is dynamic but not 4-D.
    """
    # Imported lazily so the rest of the script (e.g. --help) works on
    # machines without TensorRT installed.
    import tensorrt as trt

    if not os.path.isfile(onnx_path):
        raise FileNotFoundError(f"ONNX model not found: {onnx_path}")

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)

    # Explicit batch is required by the ONNX parser on TensorRT 8.x; the
    # getattr guard keeps the script working should the flag be absent.
    explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
    network_flags = 0 if explicit_batch is None else 1 << int(explicit_batch)
    network = builder.create_network(network_flags)

    onnx_parser = trt.OnnxParser(network, logger)
    if not onnx_parser.parse_from_file(onnx_path):
        errors = "\n".join(
            str(onnx_parser.get_error(i)) for i in range(onnx_parser.num_errors)
        )
        raise RuntimeError(f"Failed to parse ONNX model {onnx_path}:\n{errors}")

    config = builder.create_builder_config()
    # max_workspace_size was superseded by memory pool limits; support both.
    if hasattr(config, "set_memory_pool_limit"):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, _WORKSPACE_BYTES)
    else:
        config.max_workspace_size = _WORKSPACE_BYTES

    # Dynamic shape configuration. A profile is needed whenever the parsed
    # network has dynamic dimensions, even if the caller wants a fixed batch.
    input_tensor = network.get_input(0)
    input_dims = tuple(input_tensor.shape)
    if any(dim < 0 for dim in input_dims):
        batches = _DYNAMIC_BATCHES if dynamic_batch else _STATIC_BATCHES
        min_shape, opt_shape, max_shape = _profile_shapes(input_dims, batches, imgsz)
        profile = builder.create_optimization_profile()
        # Use the real input name from the network instead of guessing it.
        profile.set_shape(input_tensor.name, min_shape, opt_shape, max_shape)
        # Without this call the profile is never used by the build.
        config.add_optimization_profile(profile)
    elif dynamic_batch:
        print("⚠️ ONNX模型输入为静态形状, 已忽略动态Batch设置")

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError("TensorRT engine build failed; see the TensorRT log output")

    _ensure_parent_dir(engine_path)
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)
    print(f"✅ TensorRT引擎已保存: {engine_path}")


def export_onnx(model_path, onnx_path, imgsz=480):
    """Export a YOLO model to ONNX and place the result at ``onnx_path``.

    Args:
        model_path: Path of the YOLO weights to load.
        onnx_path: Destination path of the exported ONNX file.
        imgsz: Square export resolution.
    """
    model = YOLO(model_path)
    exported_path = model.export(
        format="onnx",
        imgsz=[imgsz, imgsz],
        simplify=True,
        opset=12,
        dynamic=True,
    )
    # export() chooses its own output location, so the file has to be moved
    # to the path the caller asked for.
    # NOTE(review): the fallback assumes the export lands next to the weights
    # with an .onnx suffix when no path is returned -- confirm.
    if not exported_path:
        exported_path = os.path.splitext(model_path)[0] + ".onnx"
    exported_path = str(exported_path)
    if os.path.abspath(exported_path) != os.path.abspath(onnx_path):
        _ensure_parent_dir(onnx_path)
        shutil.move(exported_path, onnx_path)
    print(f"✅ ONNX模型已导出: {onnx_path}")


def main():
    """Parse command-line arguments, export ONNX if needed, build the engine."""
    parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
    parser.add_argument("--model", type=str, default="models/yolo11n.pt",
                        help="YOLO模型路径")
    parser.add_argument("--engine", type=str, default="models/yolo11n_fp16_480.engine",
                        help="输出引擎路径")
    parser.add_argument("--onnx", type=str, default="models/yolo11n_480.onnx",
                        help="临时ONNX路径")
    parser.add_argument("--fp16", action="store_true", default=True,
                        help="启用FP16")
    # --fp16 defaults to True, so an explicit switch is needed to disable it.
    parser.add_argument("--no-fp16", dest="fp16", action="store_false",
                        help="禁用FP16")
    parser.add_argument("--no-dynamic", action="store_true",
                        help="禁用动态Batch")
    args = parser.parse_args()

    _ensure_parent_dir(args.engine)

    # Reuse an existing ONNX file; export only when it is missing.
    if not os.path.exists(args.onnx):
        export_onnx(args.model, args.onnx)

    build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic)


if __name__ == "__main__":
    main()