# TensorRT engine build script.
# Usage: python scripts/build_engine.py [--model PT] [--onnx ONNX] [--engine ENGINE]
#                                       [--imgsz N] [--no-fp16] [--no-dynamic]
import os
import sys
import argparse

# Make the project root importable when the script is run directly.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

import torch  # noqa: F401  (not referenced directly in this script)
from ultralytics import YOLO
import tensorrt as trt
import onnx  # noqa: F401  (not referenced directly in this script)


def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True, imgsz=640,
                 opt_batch=4, max_batch=8):
    """Build a serialized TensorRT engine from an ONNX model and write it to disk.

    Args:
        onnx_path: Path of the ONNX model to parse.
        engine_path: Destination path of the serialized engine.
        fp16: Enable the TensorRT FP16 builder flag.
        dynamic_batch: If the ONNX input has dynamic dimensions, build an
            optimization profile covering batch sizes 1..max_batch (optimized
            for opt_batch). If False, the profile is pinned to batch size 1.
        imgsz: Square spatial size (H == W) used in the optimization profile.
        opt_batch: Batch size TensorRT tunes for in dynamic-batch mode.
        max_batch: Largest batch size accepted in dynamic-batch mode.

    Raises:
        RuntimeError: If ONNX parsing or engine building fails.
    """
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)

    # EXPLICIT_BATCH is required before TensorRT 10 and deprecated afterwards;
    # look it up defensively so the script keeps working if it is removed.
    explicit_batch = getattr(trt.NetworkDefinitionCreationFlag, "EXPLICIT_BATCH", None)
    network_flags = 0 if explicit_batch is None else 1 << int(explicit_batch)
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, logger)

    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("ONNX 解析失败")

    config = builder.create_builder_config()

    # Take the input name/shape from the parsed network instead of assuming "images".
    # NOTE(review): assumes a single-input network (true for YOLO exports) - confirm
    # before reusing this for other models.
    input_tensor = network.get_input(0)
    input_name = input_tensor.name
    has_dynamic_dims = any(dim == -1 for dim in input_tensor.shape)

    if has_dynamic_dims:
        # TensorRT cannot build a network with dynamic inputs without an
        # optimization profile, so one is added even when dynamic_batch is False.
        if dynamic_batch:
            min_b, opt_b, max_b = 1, opt_batch, max_batch
        else:
            min_b = opt_b = max_b = 1
        profile = builder.create_optimization_profile()
        profile.set_shape(
            input_name,
            (min_b, 3, imgsz, imgsz),
            (opt_b, 3, imgsz, imgsz),
            (max_b, 3, imgsz, imgsz),
        )
        config.add_optimization_profile(profile)
    elif dynamic_batch:
        # A profile with min != max on a static input would make the build fail,
        # so the request for dynamic batching is ignored for static models.
        print(f"警告: ONNX 输入 '{input_name}' 为静态形状, 忽略动态Batch设置")

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        # build_serialized_network signals failure by returning None; the reason
        # is reported through the logger above.
        raise RuntimeError("TensorRT 引擎构建失败")

    with open(engine_path, "wb") as f:
        f.write(serialized_engine)
    print(f"TensorRT引擎已保存: {engine_path}")


def export_onnx(model_path, onnx_path, imgsz=640):
    """Export a YOLO model to ONNX (dynamic axes, opset 12) and move it to onnx_path.

    Args:
        model_path: Path of the Ultralytics YOLO weights to export.
        onnx_path: Final location of the exported ONNX file.
        imgsz: Square export image size.

    Raises:
        RuntimeError: If the exported ONNX file cannot be located.
    """
    import shutil

    model = YOLO(model_path)
    exported_path = model.export(
        format="onnx",
        imgsz=[imgsz, imgsz],
        simplify=True,
        opset=12,
        dynamic=True,
    )
    # export() returns the exported file path; if it returns nothing usable,
    # presumably the file sits next to the weights with an .onnx suffix.
    if not exported_path:
        exported_path = os.path.splitext(model_path)[0] + ".onnx"
    exported_path = str(exported_path)
    if not os.path.exists(exported_path):
        raise RuntimeError(f"ONNX 导出失败, 未找到导出文件: {exported_path}")

    if os.path.abspath(exported_path) != os.path.abspath(onnx_path):
        shutil.move(exported_path, onnx_path)
    print(f"ONNX模型已导出: {onnx_path}")


def main():
    """Parse CLI arguments, export ONNX if it is missing, then build the engine."""
    parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
    parser.add_argument("--model", type=str, default="models/yolo11n.pt", help="YOLO模型路径")
    parser.add_argument("--engine", type=str, default="models/yolo11n.engine", help="输出引擎路径")
    parser.add_argument("--onnx", type=str, default="models/yolo11n.onnx", help="ONNX模型路径")
    parser.add_argument("--imgsz", type=int, default=640, help="输入图像尺寸")
    # --fp16 is kept for backward compatibility (FP16 is on by default);
    # --no-fp16 shares the same dest so FP16 can actually be turned off.
    parser.add_argument("--fp16", action="store_true", default=True, help="启用FP16")
    parser.add_argument("--no-fp16", dest="fp16", action="store_false", help="禁用FP16")
    parser.add_argument("--no-dynamic", action="store_true", help="禁用动态Batch")
    args = parser.parse_args()

    # dirname() is "" for bare file names, and os.makedirs("") raises.
    engine_dir = os.path.dirname(args.engine) or "."
    os.makedirs(engine_dir, exist_ok=True)
    onnx_dir = os.path.dirname(args.onnx) or "."
    os.makedirs(onnx_dir, exist_ok=True)

    if not os.path.exists(args.onnx):
        export_onnx(args.model, args.onnx, args.imgsz)

    build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic, args.imgsz)


if __name__ == "__main__":
    main()