ROI选区01
This commit is contained in:
89
scripts/build_engine.py
Normal file
89
scripts/build_engine.py
Normal file
@@ -0,0 +1,89 @@
# TensorRT engine build script
# Usage: python scripts/build_engine.py
import argparse
import os
import shutil
import sys

# Make the project root importable before pulling in project packages.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

import torch
from ultralytics import YOLO
def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True):
    """Build and serialize a TensorRT engine from an ONNX model.

    Args:
        onnx_path: Path to the input ONNX file.
        engine_path: Destination path for the serialized ``.engine`` file.
        fp16: Enable FP16 precision (requires hardware FP16 support).
        dynamic_batch: Build with a dynamic batch dimension
            (min 1, optimal 4, max 8) for a fixed 3x480x480 input.

    Raises:
        RuntimeError: If ONNX parsing or engine building fails.
    """
    # Imported lazily so the rest of the script (e.g. --help, ONNX export)
    # still works on machines without TensorRT installed. The original
    # `from tensorrt.parsers import onnxparser` is the removed TRT<=6 API
    # and cannot coexist with build_serialized_network (TRT 8+).
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)

    # EXPLICIT_BATCH is mandatory for ONNX models in modern TensorRT.
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)

    # OnnxParser is the supported parser; parse_from_file handles reading
    # the file itself and reports per-error diagnostics on failure.
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_path):
        errors = "\n".join(
            str(parser.get_error(i)) for i in range(parser.num_errors)
        )
        raise RuntimeError(f"Failed to parse ONNX model {onnx_path}:\n{errors}")

    config = builder.create_builder_config()
    # 4 GB workspace. set_memory_pool_limit replaces the deprecated
    # max_workspace_size attribute (TensorRT 8.4+).
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)

    # Dynamic-shape configuration
    if dynamic_batch:
        profile = builder.create_optimization_profile()
        min_shape = (1, 3, 480, 480)
        opt_shape = (4, 3, 480, 480)
        max_shape = (8, 3, 480, 480)

        # Use the network's actual input tensor name rather than assuming
        # it is literally "input".
        input_name = network.get_input(0).name
        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
        # The profile must be registered with the config, otherwise the
        # dynamic-shape build fails. (The original code omitted this.)
        config.add_optimization_profile(profile)

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        # Without this check a failed build would silently write `None`.
        raise RuntimeError("TensorRT engine build failed (see log above)")

    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    print(f"✅ TensorRT引擎已保存: {engine_path}")
def export_onnx(model_path, onnx_path, imgsz=480):
    """Export a YOLO model to ONNX with a dynamic batch dimension.

    Args:
        model_path: Path to the YOLO ``.pt`` checkpoint.
        onnx_path: Desired path of the exported ONNX file.
        imgsz: Square input size in pixels (default 480).
    """
    model = YOLO(model_path)
    # Ultralytics writes the ONNX file next to model_path and returns the
    # path it actually used — the original ignored it, so onnx_path could
    # point at a missing/stale file. Move the export to the requested path.
    exported = model.export(
        format="onnx",
        imgsz=[imgsz, imgsz],
        simplify=True,
        opset=12,
        dynamic=True,
    )
    if exported and os.path.abspath(str(exported)) != os.path.abspath(onnx_path):
        shutil.move(str(exported), onnx_path)
    print(f"✅ ONNX模型已导出: {onnx_path}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
    parser.add_argument("--model", type=str, default="models/yolo11n.pt",
                        help="YOLO模型路径")
    parser.add_argument("--engine", type=str, default="models/yolo11n_fp16_480.engine",
                        help="输出引擎路径")
    parser.add_argument("--onnx", type=str, default="models/yolo11n_480.onnx",
                        help="临时ONNX路径")
    parser.add_argument("--fp16", action="store_true", default=True,
                        help="启用FP16")
    # `--fp16` with default=True could never be switched off; `--no-fp16`
    # supplies the missing opposite without breaking existing invocations.
    parser.add_argument("--no-fp16", dest="fp16", action="store_false",
                        help="禁用FP16")
    parser.add_argument("--no-dynamic", action="store_true",
                        help="禁用动态Batch")

    args = parser.parse_args()

    # dirname is "" for a bare filename, and makedirs("") raises — guard it.
    engine_dir = os.path.dirname(args.engine)
    if engine_dir:
        os.makedirs(engine_dir, exist_ok=True)

    # Reuse an existing ONNX export; otherwise generate it first.
    if not os.path.exists(args.onnx):
        export_onnx(args.model, args.onnx)

    build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic)
Reference in New Issue
Block a user