97 lines
3.1 KiB
Python
97 lines
3.1 KiB
Python
# TensorRT Engine 生成脚本
|
|
# 使用方法: python scripts/build_engine.py
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
|
|
# Absolute path of the directory two levels above this file, i.e. the parent
# of the directory that contains this script.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
# Put project_root at the front of the module search path so it takes
# precedence over other entries (presumably so project-local packages can be
# imported when this script is run directly -- confirm against the repo layout).
sys.path.insert(0, project_root)
|
|
|
|
import torch
|
|
from ultralytics import YOLO
|
|
import tensorrt as trt
|
|
import onnx
|
|
|
|
|
|
def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True, imgsz=640):
    """Build a serialized TensorRT engine from an ONNX file and save it.

    Args:
        onnx_path: Path of the ONNX model to parse.
        engine_path: Path the serialized engine is written to.
        fp16: When true, set the FP16 builder flag.
        dynamic_batch: When true, register an optimization profile for the
            first network input covering batch sizes 1 (min), 4 (opt) and
            8 (max) at ``3 x imgsz x imgsz``.
        imgsz: Square input resolution used for the profile shapes.

    Raises:
        RuntimeError: If the ONNX model cannot be parsed, or if TensorRT
            fails to build the engine.
    """
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)

    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)

    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        parsed = parser.parse(f.read())
    if not parsed:
        # Surface every parser diagnostic before giving up.
        for index in range(parser.num_errors):
            print(parser.get_error(index))
        raise RuntimeError("ONNX 解析失败")

    config = builder.create_builder_config()

    if dynamic_batch:
        profile = builder.create_optimization_profile()
        min_shape = (1, 3, imgsz, imgsz)
        opt_shape = (4, 3, imgsz, imgsz)
        max_shape = (8, 3, imgsz, imgsz)

        # Bind the profile to the tensor name the parsed network actually
        # exposes instead of assuming it; fall back to the historical
        # "images" name if the network reports no inputs.
        if network.num_inputs > 0:
            input_name = network.get_input(0).name
        else:
            input_name = "images"
        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
        config.add_optimization_profile(profile)

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    serialized_engine = builder.build_serialized_network(network, config)
    # A failed build yields None; check before opening the output so that a
    # failure neither truncates an existing engine file nor ends in an
    # opaque TypeError from f.write(None).
    if serialized_engine is None:
        raise RuntimeError(
            "TensorRT 引擎构建失败 (build_serialized_network returned None)"
        )

    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    print(f"TensorRT引擎已保存: {engine_path}")
|
|
|
|
|
|
def export_onnx(model_path, onnx_path, imgsz=640):
    """Export a YOLO checkpoint to ONNX and place the file at ``onnx_path``.

    The model is exported with dynamic axes, graph simplification and
    opset 12. The exporter chooses its own output location, so the file it
    reports is moved to ``onnx_path`` when the two differ.

    Args:
        model_path: Path of the YOLO checkpoint to load.
        onnx_path: Desired location of the exported ONNX file.
        imgsz: Square input resolution passed to the exporter.

    Raises:
        RuntimeError: If the exporter does not report an existing file.
    """
    import shutil

    model = YOLO(model_path)
    exported = model.export(
        format="onnx",
        imgsz=[imgsz, imgsz],
        simplify=True,
        opset=12,
        dynamic=True,
    )

    # Rely on the path returned by the exporter rather than globbing the
    # current directory for a hard-coded model name: the export is written
    # next to the checkpoint, which is usually not the working directory.
    exported_path = str(exported) if exported else ""
    if not exported_path or not os.path.isfile(exported_path):
        raise RuntimeError(f"ONNX 导出失败: {model_path}")

    if os.path.abspath(exported_path) != os.path.abspath(onnx_path):
        target_dir = os.path.dirname(onnx_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        shutil.move(exported_path, onnx_path)

    print(f"ONNX模型已导出: {onnx_path}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: export the model to ONNX if no ONNX file
    # exists yet, then build the TensorRT engine from it.
    parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
    parser.add_argument("--model", type=str, default="models/yolo11n.pt",
                        help="YOLO模型路径")
    parser.add_argument("--engine", type=str, default="models/yolo11n.engine",
                        help="输出引擎路径")
    parser.add_argument("--onnx", type=str, default="models/yolo11n.onnx",
                        help="ONNX模型路径")
    parser.add_argument("--imgsz", type=int, default=640,
                        help="输入图像尺寸")
    # FP16 stays on by default; --fp16 is kept for compatibility and
    # --no-fp16 writes to the same destination so it can be switched off.
    parser.add_argument("--fp16", dest="fp16", action="store_true", default=True,
                        help="启用FP16")
    parser.add_argument("--no-fp16", dest="fp16", action="store_false",
                        help="禁用FP16")
    parser.add_argument("--no-dynamic", action="store_true",
                        help="禁用动态Batch")

    args = parser.parse_args()

    # Create the output directories. A bare file name has an empty dirname,
    # and os.makedirs("") raises, so skip paths without a directory part.
    for output_path in (args.engine, args.onnx):
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    # Reuse an existing ONNX file instead of re-exporting it.
    if not os.path.exists(args.onnx):
        export_onnx(args.model, args.onnx, args.imgsz)

    build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic, args.imgsz)
|