# NOTE: Moved all project files and directories (config, core, models, etc.)
# from edge_inference_service/ to the repository root ai_edge/. Updated model
# path in config/settings.py to reflect the new structure and revised usage
# paths in __init__.py documentation.
"""
|
||
TensorRT Engine 生成脚本 (8GB显存优化版)
|
||
|
||
功能:
|
||
- 将 YOLO11 模型转换为 TensorRT engine 文件
|
||
- 针对 8GB 显存环境优化
|
||
- 支持 FP16 精度、480x480 输入、动态 Batch(1-8)
|
||
- 启用所有高性能优化策略
|
||
|
||
使用方法:
|
||
python build_engine.py --download yolo11n --output models/yolo11n.engine [OPTIONS]
|
||
|
||
示例:
|
||
python build_engine.py --download yolo11n --output models/yolo11n.engine --fp16 --batch 8
|
||
|
||
优化命令参数说明:
|
||
--min-batch 最小Batch大小 (默认: 1)
|
||
--opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化
|
||
--max-batch 最大Batch大小 (默认: 8)
|
||
--workspace 工作空间大小MB (默认: 6144,即6GB)
|
||
--tactics 启用优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)
|
||
--best 全局最优搜索 (默认: 启用)
|
||
--preview 预览特性 (默认: +faster_dynamic_shapes_0805)
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
from pathlib import Path
|
||
|
||
try:
|
||
import tensorrt as trt
|
||
TRT_AVAILABLE = True
|
||
except ImportError:
|
||
TRT_AVAILABLE = False
|
||
trt = None
|
||
|
||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING) if TRT_AVAILABLE else None
|
||
|
||
|
||
def check_environment():
    """Print a summary of the local build environment.

    Reports the installed versions of TensorRT, PyTorch (plus CUDA/GPU
    details when available) and whether Ultralytics is importable. Purely
    informational; returns None and never raises on a missing package.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("环境检查")
    print(f"{banner}")

    # TensorRT
    try:
        import tensorrt as trt
    except ImportError:
        print(" TensorRT: 未安装")
    else:
        print(f" TensorRT: {trt.__version__}")

    # PyTorch / CUDA / GPU
    try:
        import torch
    except ImportError:
        print(" PyTorch: 未安装")
    else:
        print(f" PyTorch: {torch.__version__}")
        print(f" CUDA: {torch.version.cuda if hasattr(torch, 'version') else 'N/A'}")
        if torch.cuda.is_available():
            print(f" GPU: {torch.cuda.get_device_name(0)}")
            print(f" GPU显存: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")

    # Ultralytics (presence check only; the import is not used here)
    try:
        from ultralytics import YOLO  # noqa: F401
    except ImportError:
        print(" Ultralytics: 未安装")
    else:
        print(" Ultralytics: 已安装")
||
def download_model(model_type: str, save_dir: str = "./models") -> str:
    """Fetch a YOLO11 checkpoint (.pt) into *save_dir*, reusing any cached copy.

    Args:
        model_type: model variant name (yolo11n, yolo11s, yolo11m, yolo11l, yolo11x)
        save_dir: destination directory (created if it does not exist)

    Returns:
        Path of the local .pt file.
    """
    from ultralytics import YOLO

    os.makedirs(save_dir, exist_ok=True)

    weight_file = os.path.join(save_dir, f"{model_type}.pt")

    # Skip the download when the checkpoint is already on disk.
    if os.path.exists(weight_file):
        print(f"模型已存在: {weight_file}")
        return weight_file

    print(f"\n下载 YOLO11 {model_type} 模型...")

    # Ultralytics fetches the named weights automatically on construction;
    # persist them at the canonical location inside save_dir.
    YOLO(f"{model_type}.pt").save(weight_file)

    print(f"模型已保存: {weight_file}")
    return weight_file
||
def export_onnx(
    model_path: str,
    output_path: str,
    input_size: int = 480,
    dynamic_batch: bool = True
) -> bool:
    """Export a YOLO checkpoint to ONNX.

    Args:
        model_path: source .pt checkpoint path
        output_path: expected location of the exported .onnx file
            (Ultralytics writes the file next to the checkpoint)
        input_size: square input resolution
        dynamic_batch: export with a dynamic batch axis

    Returns:
        True when the ONNX file exists after export, False otherwise.
    """
    from ultralytics import YOLO

    print(f"\n导出 ONNX 模型...")
    print(f" 源模型: {model_path}")
    print(f" 输出路径: {output_path}")
    print(f" 输入尺寸: {input_size}x{input_size}")
    print(f" 动态Batch: {'是' if dynamic_batch else '否'}")

    detector = YOLO(model_path)

    try:
        detector.export(
            format='onnx',
            imgsz=input_size,
            dynamic=dynamic_batch,
            simplify=True,
            opset=17,
        )

        # Ultralytics reports success via the returned path; verify the
        # expected file actually landed on disk before declaring victory.
        if not os.path.exists(output_path):
            print("错误: ONNX 导出失败")
            return False

        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f"\n✓ ONNX 导出成功!")
        print(f" 文件: {output_path}")
        print(f" 大小: {file_size:.2f} MB")
        return True

    except Exception as e:
        print(f"错误: ONNX 导出失败: {e}")
        return False
||
def build_engine(
    onnx_path: str,
    engine_path: str,
    input_size: int = 480,
    min_batch: int = 1,
    opt_batch: int = 4,
    max_batch: int = 8,
    fp16: bool = True,
    workspace_mb: int = 6144,
    tactic_sources: str = "+CUBLAS,+CUBLAS_LT,+CUDNN",
    best: bool = True,
    preview: str = "+faster_dynamic_shapes_0805"
) -> bool:
    """Build a TensorRT engine from an ONNX model (tuned for 8 GB GPUs).

    Args:
        onnx_path: input ONNX model path
        engine_path: output engine path (parent directory created if needed)
        input_size: square input resolution
        min_batch: smallest batch size in the optimization profile
        opt_batch: batch size TensorRT specifically tunes kernels for
        max_batch: largest batch size in the optimization profile
        fp16: enable FP16 kernels
        workspace_mb: builder workspace limit in MB
        tactic_sources: comma list of +NAME/-NAME toggles applied to the
            default tactic-source mask (names from trt.TacticSource, e.g.
            CUBLAS, CUBLAS_LT, CUDNN); unknown names are skipped with a warning
        best: spend extra build time searching for the fastest tactics
        preview: comma list of +NAME/-NAME preview-feature toggles
            (names from trt.PreviewFeature)

    Returns:
        True on success, False otherwise.
    """
    print(f"\n{'='*60}")
    print("TensorRT Engine 构建 (8GB显存优化版)")
    print(f"{'='*60}")

    if not os.path.exists(onnx_path):
        print(f"错误: ONNX 模型不存在: {onnx_path}")
        return False

    # BUG FIX: os.makedirs("") raises FileNotFoundError when engine_path has
    # no directory component (e.g. "model.engine"); only create a real dir.
    engine_dir = os.path.dirname(engine_path)
    if engine_dir:
        os.makedirs(engine_dir, exist_ok=True)

    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)

    with trt.Builder(TRT_LOGGER) as builder:
        with builder.create_network(EXPLICIT_BATCH) as network:
            with trt.OnnxParser(network, TRT_LOGGER) as parser:
                print(f"加载 ONNX 模型: {onnx_path}")
                with open(onnx_path, 'rb') as f:
                    if not parser.parse(f.read()):
                        print("错误: ONNX 解析失败")
                        for error in range(parser.num_errors):
                            print(f" {parser.get_error(error)}")
                        return False

                input_tensor = network.get_input(0)
                input_name = input_tensor.name
                input_shape = input_tensor.shape

                print(f"\n构建参数:")
                print(f" 输入名称: {input_name}")
                print(f" 输入形状: {input_shape}")
                print(f" 最小Batch: {min_batch}")
                print(f" 优化Batch: {opt_batch} <-- TensorRT会针对此尺寸专门优化!")
                print(f" 最大Batch: {max_batch}")
                print(f" 精度: {'FP16' if fp16 else 'FP32'}")
                print(f" 工作空间: {workspace_mb} MB")
                print(f" 优化策略: {tactic_sources}")
                print(f" 全局最优: {'是' if best else '否'}")
                print(f" 预览特性: {preview}")

                # Dynamic-batch optimization profile: NCHW with fixed 3
                # channels and square spatial size.
                profile = builder.create_optimization_profile()
                profile.set_shape(
                    input_name,
                    (min_batch, 3, input_size, input_size),
                    (opt_batch, 3, input_size, input_size),
                    (max_batch, 3, input_size, input_size)
                )

                config = builder.create_builder_config()
                config.add_optimization_profile(profile)

                if fp16:
                    config.set_flag(trt.BuilderFlag.FP16)

                # TF32 is harmless on pre-Ampere GPUs and free accuracy/speed
                # on Ampere+.
                config.set_flag(trt.BuilderFlag.TF32)

                # BUG FIX: the previous code called int("CUBLAS"), which
                # raises ValueError. Tactic sources are a bitmask keyed by
                # the trt.TacticSource enum; apply the +/- toggles to the
                # builder's default mask and set it once.
                mask = config.get_tactic_sources()
                for token in tactic_sources.split(','):
                    token = token.strip()
                    if len(token) < 2 or token[0] not in '+-':
                        continue
                    source = getattr(trt.TacticSource, token[1:].upper(), None)
                    if source is None:
                        print(f"警告: 未知优化策略 {token},已忽略")
                        continue
                    bit = 1 << int(source)
                    mask = (mask | bit) if token[0] == '+' else (mask & ~bit)
                config.set_tactic_sources(mask)

                if best:
                    # BUG FIX: trt.BuilderFlag has no BENCHMARK member (it
                    # raised AttributeError). Request the most thorough
                    # tactic search instead (builder_optimization_level is
                    # available in TensorRT >= 8.6).
                    if hasattr(config, 'builder_optimization_level'):
                        config.builder_optimization_level = 5

                # BUG FIX: `preview` was accepted but silently ignored.
                # Apply the requested preview features when this TensorRT
                # version supports them.
                for token in preview.split(','):
                    token = token.strip()
                    if len(token) < 2 or token[0] not in '+-':
                        continue
                    feature = getattr(trt, 'PreviewFeature', None)
                    feature = getattr(feature, token[1:].upper(), None) if feature else None
                    if feature is not None and hasattr(config, 'set_preview_feature'):
                        config.set_preview_feature(feature, token[0] == '+')

                config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_mb * 1024 * 1024)

                print(f"\n正在构建 Engine,请稍候 (可能需要3-5分钟)...")

                serialized_engine = builder.build_serialized_network(network, config)

                if serialized_engine is None:
                    print("错误: Engine 构建失败")
                    return False

                with open(engine_path, 'wb') as f:
                    f.write(serialized_engine)

                if os.path.exists(engine_path):
                    file_size = os.path.getsize(engine_path) / (1024 * 1024)
                    print(f"\n✓ Engine 构建成功!")
                    print(f" 文件: {engine_path}")
                    print(f" 大小: {file_size:.2f} MB")
                    return True
                else:
                    print("错误: Engine 保存失败")
                    return False
||
def main():
    """CLI entry point: resolve the source model, export ONNX, build the engine.

    Returns:
        Process exit code: 0 on success, 1 on any failure.
    """
    parser = argparse.ArgumentParser(
        description="TensorRT Engine 构建工具 (8GB显存优化版)",
        formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("--model", type=str, default="yolo11n.pt",
                        help="模型名称或路径 (默认: yolo11n.pt)")
    parser.add_argument("--output", type=str, default="models/yolo11n.engine",
                        help="输出 engine 路径 (默认: models/yolo11n.engine)")
    parser.add_argument("--download", type=str, default=None,
                        help="直接下载模型 (yolo11n/s/m/l/x)")

    group = parser.add_argument_group("构建参数")
    group.add_argument("--input-size", type=int, default=480,
                       help="输入图像尺寸 (默认: 480)")
    group.add_argument("--min-batch", type=int, default=1,
                       help="最小Batch大小 (默认: 1)")
    group.add_argument("--opt-batch", type=int, default=4,
                       help="优化Batch大小,TensorRT针对此尺寸专门优化 (默认: 4)")
    group.add_argument("--max-batch", type=int, default=8,
                       help="最大Batch大小 (默认: 8)")
    group.add_argument("--no-fp16", dest="fp16", action="store_false",
                       help="禁用 FP16 精度")
    parser.set_defaults(fp16=True)

    opt_group = parser.add_argument_group("8GB显存优化参数")
    opt_group.add_argument("--workspace", type=int, default=6144,
                           help="工作空间大小 MB,8GB显存建议 6144 (默认: 6144)")
    opt_group.add_argument("--tactics", type=str,
                           default="+CUBLAS,+CUBLAS_LT,+CUDNN",
                           help="优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)")
    opt_group.add_argument("--best", action="store_true", default=True,
                           help="全局最优搜索 (默认: 启用)")
    opt_group.add_argument("--no-best", dest="best", action="store_false",
                           help="禁用全局最优搜索")
    opt_group.add_argument("--preview", type=str,
                           default="+faster_dynamic_shapes_0805",
                           help="预览特性 (默认: +faster_dynamic_shapes_0805)")

    args = parser.parse_args()

    check_environment()

    # Fail fast on the two hard dependencies before touching any files.
    if not TRT_AVAILABLE:
        print("\n错误: TensorRT 未安装")
        print("请安装: pip install tensorrt")
        return 1

    try:
        from ultralytics import YOLO  # noqa: F401 -- presence check only
    except ImportError:
        print("\n错误: Ultralytics 未安装")
        print("请安装: pip install ultralytics")
        return 1

    model_path = args.model

    if args.download:
        model_path = download_model(args.download)
    elif not os.path.exists(model_path):
        # A missing explicit .pt/.onnx path is a hard error; a bare model
        # name (e.g. "yolo11s") is downloaded instead.
        if model_path.endswith(('.pt', '.onnx')):
            print(f"错误: 模型文件不存在: {model_path}")
            return 1
        else:
            model_path = download_model(model_path)
    else:
        print(f"\n使用本地模型: {model_path}")

    onnx_path = model_path
    if model_path.endswith('.pt'):
        # BUG FIX: str.replace('.pt', '.onnx') swapped the FIRST ".pt"
        # occurrence anywhere in the path (e.g. "ai.pt_models/x.pt");
        # only the final suffix must change.
        onnx_path = str(Path(model_path).with_suffix('.onnx'))
        if not export_onnx(model_path, onnx_path, args.input_size):
            return 1

    success = build_engine(
        onnx_path=onnx_path,
        engine_path=args.output,
        input_size=args.input_size,
        min_batch=args.min_batch,
        opt_batch=args.opt_batch,
        max_batch=args.max_batch,
        fp16=args.fp16,
        workspace_mb=args.workspace,
        tactic_sources=args.tactics,
        best=args.best,
        preview=args.preview
    )

    if success:
        print(f"\n{'='*60}")
        print("构建完成!")
        print(f"{'='*60}")
        print(f"Engine 文件: {args.output}")
        print(f"\n推荐使用 trtexec 测试性能:")
        print(f" trtexec --loadEngine={args.output} --streams=8 --iterations=100")
        return 0
    else:
        return 1
||
if __name__ == "__main__":
|
||
sys.exit(main())
|