更新TensorRT

This commit is contained in:
2026-01-21 13:28:42 +08:00
parent f9c7f9018e
commit 1e562798eb
2 changed files with 95 additions and 28 deletions

View File

@@ -10,37 +10,35 @@ sys.path.insert(0, project_root)
import torch
from ultralytics import YOLO
import tensorrt as trt
import onnx
def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True):
def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True, imgsz=640):
"""构建TensorRT引擎"""
from tensorrt import Builder, NetworkDefinitionLayer, Runtime
from tensorrt.parsers import onnxparser
logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network_flags = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(network_flags)
parser = onnxparser.create_onnx_parser(network)
parser.parse(onnx_path)
parser.report_status()
# 动态形状配置
if dynamic_batch:
profile = builder.create_optimization_profile()
min_shape = (1, 3, 480, 480)
opt_shape = (4, 3, 480, 480)
max_shape = (8, 3, 480, 480)
profile.set_shape("input", min_shape, opt_shape, max_shape)
network.get_input(0).set_dynamic_range(-1.0, 1.0)
network.set_precision_constraints(trt.PrecisionConstraints.PREFER)
parser = trt.OnnxParser(network, logger)
with open(onnx_path, 'rb') as f:
if not parser.parse(f.read()):
for error in range(parser.num_errors):
print(parser.get_error(error))
raise RuntimeError("ONNX 解析失败")
config = builder.create_builder_config()
config.set_memory_allocator(trt.MemoryAllocator())
config.max_workspace_size = 4 << 30 # 4GB
if dynamic_batch:
profile = builder.create_optimization_profile()
min_shape = (1, 3, imgsz, imgsz)
opt_shape = (4, 3, imgsz, imgsz)
max_shape = (8, 3, imgsz, imgsz)
profile.set_shape("images", min_shape, opt_shape, max_shape)
config.add_optimization_profile(profile)
if fp16:
config.set_flag(trt.BuilderFlag.FP16)
@@ -50,10 +48,10 @@ def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True):
with open(engine_path, "wb") as f:
f.write(serialized_engine)
print(f"TensorRT引擎已保存: {engine_path}")
print(f"TensorRT引擎已保存: {engine_path}")
def export_onnx(model_path, onnx_path, imgsz=480):
def export_onnx(model_path, onnx_path, imgsz=640):
"""导出ONNX模型"""
model = YOLO(model_path)
model.export(
@@ -63,17 +61,24 @@ def export_onnx(model_path, onnx_path, imgsz=480):
opset=12,
dynamic=True,
)
print(f"✅ ONNX模型已导出: {onnx_path}")
import shutil
import glob
onnx_files = glob.glob("yolo11n*.onnx")
if onnx_files:
shutil.move(onnx_files[0], onnx_path)
print(f"ONNX模型已导出: {onnx_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
parser.add_argument("--model", type=str, default="models/yolo11n.pt",
help="YOLO模型路径")
parser.add_argument("--engine", type=str, default="models/yolo11n_fp16_480.engine",
parser.add_argument("--engine", type=str, default="models/yolo11n.engine",
help="输出引擎路径")
parser.add_argument("--onnx", type=str, default="models/yolo11n_480.onnx",
help="临时ONNX路径")
parser.add_argument("--onnx", type=str, default="models/yolo11n.onnx",
help="ONNX模型路径")
parser.add_argument("--imgsz", type=int, default=640,
help="输入图像尺寸")
parser.add_argument("--fp16", action="store_true", default=True,
help="启用FP16")
parser.add_argument("--no-dynamic", action="store_true",
@@ -82,8 +87,10 @@ if __name__ == "__main__":
args = parser.parse_args()
os.makedirs(os.path.dirname(args.engine), exist_ok=True)
onnx_dir = os.path.dirname(args.onnx) if os.path.dirname(args.onnx) else '.'
os.makedirs(onnx_dir, exist_ok=True)
if not os.path.exists(args.onnx):
export_onnx(args.model, args.onnx)
export_onnx(args.model, args.onnx, args.imgsz)
build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic)
build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic, args.imgsz)