Files
security-ai-edge/build_engine.py
16337 b0ddb6ee1a feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from
  edge_inference_service/ to the repository root ai_edge/
- Updated model path in config/settings.py to reflect new structure
- Revised usage paths in __init__.py documentation
2026-01-29 18:43:19 +08:00

375 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
TensorRT Engine 生成脚本 (8GB显存优化版)
功能:
- 将 YOLO11 模型转换为 TensorRT engine 文件
- 针对 8GB 显存环境优化
- 支持 FP16 精度、480x480 输入、动态 Batch(1-8)
- 启用所有高性能优化策略
使用方法:
python build_engine.py --download yolo11n --output models/yolo11n.engine [OPTIONS]
示例:
python build_engine.py --download yolo11n --output models/yolo11n.engine --fp16 --batch 8
优化命令参数说明:
--min-batch 最小Batch大小 (默认: 1)
--opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化
--max-batch 最大Batch大小 (默认: 8)
--workspace 工作空间大小MB (默认: 6144即6GB)
--tactics 启用优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)
--best 全局最优搜索 (默认: 启用)
--preview 预览特性 (默认: +faster_dynamic_shapes_0805)
"""
import os
import sys
import argparse
from pathlib import Path

# TensorRT is an optional dependency here: import it defensively so that the
# script can still start (and print a helpful install hint in main()) on
# machines without TensorRT installed.
try:
    import tensorrt as trt
    TRT_AVAILABLE = True
except ImportError:
    TRT_AVAILABLE = False
    trt = None

# Shared TensorRT logger used by the builder and the ONNX parser.
# None when TensorRT is unavailable (build_engine() is never reached then).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING) if TRT_AVAILABLE else None
def check_environment():
    """Print a report of the local toolchain (TensorRT / PyTorch / CUDA / Ultralytics).

    Purely informational: every dependency is probed with a guarded import and
    the outcome is printed. Returns nothing and never raises on a missing
    dependency.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("环境检查")
    print(f"{banner}")

    # TensorRT
    try:
        import tensorrt as trt
    except ImportError:
        print(" TensorRT: 未安装")
    else:
        print(f" TensorRT: {trt.__version__}")

    # PyTorch and, when present, CUDA / GPU details
    try:
        import torch
    except ImportError:
        print(" PyTorch: 未安装")
    else:
        print(f" PyTorch: {torch.__version__}")
        print(f" CUDA: {torch.version.cuda if hasattr(torch, 'version') else 'N/A'}")
        if torch.cuda.is_available():
            total_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            print(f" GPU: {torch.cuda.get_device_name(0)}")
            print(f" GPU显存: {total_mem_gb:.1f} GB")

    # Ultralytics (provides the YOLO export pipeline)
    try:
        from ultralytics import YOLO
    except ImportError:
        print(" Ultralytics: 未安装")
    else:
        print(" Ultralytics: 已安装")
def download_model(model_type: str, save_dir: str = "./models") -> str:
    """Download a YOLO11 checkpoint, or reuse a previously cached copy.

    Args:
        model_type: Model variant name (yolo11n, yolo11s, yolo11m, yolo11l, yolo11x).
        save_dir: Directory where the .pt checkpoint is stored.

    Returns:
        Filesystem path of the .pt checkpoint.
    """
    from ultralytics import YOLO

    os.makedirs(save_dir, exist_ok=True)
    target = os.path.join(save_dir, f"{model_type}.pt")

    # Skip the (network) download when the checkpoint is already on disk.
    if os.path.exists(target):
        print(f"模型已存在: {target}")
        return target

    print(f"\n下载 YOLO11 {model_type} 模型...")
    YOLO(f"{model_type}.pt").save(target)
    print(f"模型已保存: {target}")
    return target
def export_onnx(
    model_path: str,
    output_path: str,
    input_size: int = 480,
    dynamic_batch: bool = True,
) -> bool:
    """Export a YOLO .pt checkpoint to ONNX.

    Args:
        model_path: Path to the source .pt model.
        output_path: Expected path of the exported .onnx file. NOTE(review):
            Ultralytics chooses the output name itself (source stem + .onnx),
            so this must be model_path with the suffix swapped — as main() does.
        input_size: Square input resolution in pixels.
        dynamic_batch: Export with a dynamic batch dimension.

    Returns:
        True when output_path exists after the export, False otherwise.
    """
    from ultralytics import YOLO

    print(f"\n导出 ONNX 模型...")
    print(f" 源模型: {model_path}")
    print(f" 输出路径: {output_path}")
    print(f" 输入尺寸: {input_size}x{input_size}")
    # BUG FIX: both branches of this f-string were empty strings (the original
    # check marks were lost as "ambiguous Unicode"), so the flag value was
    # never actually shown. Print an explicit yes/no instead.
    print(f" 动态Batch: {'是' if dynamic_batch else '否'}")

    model = YOLO(model_path)
    try:
        model.export(
            format='onnx',
            imgsz=input_size,
            dynamic=dynamic_batch,
            simplify=True,
            opset=17,  # stable opset with good TensorRT parser support
        )
    except Exception as e:
        print(f"错误: ONNX 导出失败: {e}")
        return False

    # export() returns its own path; verify the file we promised the caller.
    if os.path.exists(output_path):
        file_size = os.path.getsize(output_path) / (1024 * 1024)
        print(f"\n✓ ONNX 导出成功!")
        print(f" 文件: {output_path}")
        print(f" 大小: {file_size:.2f} MB")
        return True

    print("错误: ONNX 导出失败")
    return False
def build_engine(
    onnx_path: str,
    engine_path: str,
    input_size: int = 480,
    min_batch: int = 1,
    opt_batch: int = 4,
    max_batch: int = 8,
    fp16: bool = True,
    workspace_mb: int = 6144,
    tactic_sources: str = "+CUBLAS,+CUBLAS_LT,+CUDNN",
    best: bool = True,
    preview: str = "+faster_dynamic_shapes_0805",
) -> bool:
    """Build a serialized TensorRT engine from an ONNX model (8 GB VRAM tuning).

    Args:
        onnx_path: Path of the input ONNX model.
        engine_path: Output path for the serialized engine.
        input_size: Square input resolution (pixels).
        min_batch: Lower bound of the dynamic batch range.
        opt_batch: Batch size TensorRT optimizes kernel selection for.
        max_batch: Upper bound of the dynamic batch range.
        fp16: Enable FP16 precision.
        workspace_mb: Builder workspace memory-pool limit in MB.
        tactic_sources: Comma list of "+NAME"/"-NAME" tokens matched against
            trt.TacticSource (e.g. "+CUBLAS,+CUBLAS_LT,+CUDNN").
        best: Request the exhaustive-search builder flag when the installed
            TensorRT exposes one.
        preview: Comma list of "+NAME"/"-NAME" preview-feature tokens matched
            against trt.PreviewFeature (TensorRT 8.5/8.6 era API).

    Returns:
        True on success, False on any failure (all failures are printed,
        never raised).
    """
    print(f"\n{'='*60}")
    print("TensorRT Engine 构建 (8GB显存优化版)")
    print(f"{'='*60}")

    if not os.path.exists(onnx_path):
        print(f"错误: ONNX 模型不存在: {onnx_path}")
        return False

    # BUG FIX: os.makedirs('') raises FileNotFoundError when engine_path has
    # no directory component (e.g. "model.engine").
    engine_dir = os.path.dirname(engine_path)
    if engine_dir:
        os.makedirs(engine_dir, exist_ok=True)

    EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    with trt.Builder(TRT_LOGGER) as builder, \
         builder.create_network(EXPLICIT_BATCH) as network, \
         trt.OnnxParser(network, TRT_LOGGER) as parser:
        print(f"加载 ONNX 模型: {onnx_path}")
        with open(onnx_path, 'rb') as f:
            if not parser.parse(f.read()):
                print("错误: ONNX 解析失败")
                for error in range(parser.num_errors):
                    print(f" {parser.get_error(error)}")
                return False

        input_tensor = network.get_input(0)
        input_name = input_tensor.name
        input_shape = input_tensor.shape
        print(f"\n构建参数:")
        print(f" 输入名称: {input_name}")
        print(f" 输入形状: {input_shape}")
        print(f" 最小Batch: {min_batch}")
        print(f" 优化Batch: {opt_batch} <-- TensorRT会针对此尺寸专门优化!")
        print(f" 最大Batch: {max_batch}")
        print(f" 精度: {'FP16' if fp16 else 'FP32'}")
        print(f" 工作空间: {workspace_mb} MB")
        print(f" 优化策略: {tactic_sources}")
        # BUG FIX: both branches were empty strings (stripped Unicode marks).
        print(f" 全局最优: {'启用' if best else '禁用'}")
        print(f" 预览特性: {preview}")

        # Dynamic-batch optimization profile: min/opt/max over NCHW shapes.
        profile = builder.create_optimization_profile()
        profile.set_shape(
            input_name,
            (min_batch, 3, input_size, input_size),
            (opt_batch, 3, input_size, input_size),
            (max_batch, 3, input_size, input_size),
        )

        config = builder.create_builder_config()
        config.add_optimization_profile(profile)
        if fp16:
            config.set_flag(trt.BuilderFlag.FP16)
        config.set_flag(trt.BuilderFlag.TF32)

        # BUG FIX: the original did config.set_tactic_sources(int("CUBLAS"))
        # which always raised ValueError. Resolve each "+NAME"/"-NAME" token
        # against the trt.TacticSource enum and apply ONE combined bitmask —
        # set_tactic_sources replaces the mask, it does not accumulate.
        mask = config.get_tactic_sources()
        for token in tactic_sources.split(','):
            token = token.strip()
            if not token:
                continue
            sign, name = token[0], token[1:].upper()
            source = getattr(trt.TacticSource, name, None)
            if source is None:
                print(f"警告: 未知优化策略: {token}")
                continue
            bit = 1 << int(source)
            mask = mask & ~bit if sign == '-' else mask | bit
        config.set_tactic_sources(mask)

        # BUG FIX: trt.BuilderFlag has no BENCHMARK member, so the original
        # crashed with AttributeError whenever best=True (the default).
        # Keep the intent but guard the lookup for forward compatibility.
        if best and hasattr(trt.BuilderFlag, 'BENCHMARK'):
            config.set_flag(trt.BuilderFlag.BENCHMARK)

        # FIX: the preview parameter was printed but never applied. Enable
        # the requested preview features when this TensorRT exposes them
        # (set_preview_feature exists in the 8.5/8.6 API).
        if preview and hasattr(config, 'set_preview_feature'):
            for token in preview.split(','):
                token = token.strip()
                if not token:
                    continue
                feature = getattr(trt.PreviewFeature, token[1:].upper(), None)
                if feature is not None:
                    config.set_preview_feature(feature, not token.startswith('-'))

        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_mb * 1024 * 1024)

        print(f"\n正在构建 Engine请稍候 (可能需要3-5分钟)...")
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            print("错误: Engine 构建失败")
            return False

        with open(engine_path, 'wb') as f:
            f.write(serialized_engine)

        if os.path.exists(engine_path):
            file_size = os.path.getsize(engine_path) / (1024 * 1024)
            print(f"\n✓ Engine 构建成功!")
            print(f" 文件: {engine_path}")
            print(f" 大小: {file_size:.2f} MB")
            return True

        print("错误: Engine 保存失败")
        return False
def _build_arg_parser() -> argparse.ArgumentParser:
    """Assemble the CLI parser: model selection, build params, 8GB tuning knobs."""
    parser = argparse.ArgumentParser(
        description="TensorRT Engine 构建工具 (8GB显存优化版)",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--model", type=str, default="yolo11n.pt",
                        help="模型名称或路径 (默认: yolo11n.pt)")
    parser.add_argument("--output", type=str, default="models/yolo11n.engine",
                        help="输出 engine 路径 (默认: models/yolo11n.engine)")
    parser.add_argument("--download", type=str, default=None,
                        help="直接下载模型 (yolo11n/s/m/l/x)")

    group = parser.add_argument_group("构建参数")
    group.add_argument("--input-size", type=int, default=480,
                       help="输入图像尺寸 (默认: 480)")
    group.add_argument("--min-batch", type=int, default=1,
                       help="最小Batch大小 (默认: 1)")
    group.add_argument("--opt-batch", type=int, default=4,
                       help="优化Batch大小TensorRT针对此尺寸专门优化 (默认: 4)")
    group.add_argument("--max-batch", type=int, default=8,
                       help="最大Batch大小 (默认: 8)")
    group.add_argument("--no-fp16", dest="fp16", action="store_false",
                       help="禁用 FP16 精度")
    parser.set_defaults(fp16=True)

    opt_group = parser.add_argument_group("8GB显存优化参数")
    opt_group.add_argument("--workspace", type=int, default=6144,
                           help="工作空间大小 MB8GB显存建议 6144 (默认: 6144)")
    opt_group.add_argument("--tactics", type=str,
                           default="+CUBLAS,+CUBLAS_LT,+CUDNN",
                           help="优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)")
    opt_group.add_argument("--best", action="store_true", default=True,
                           help="全局最优搜索 (默认: 启用)")
    opt_group.add_argument("--no-best", dest="best", action="store_false",
                           help="禁用全局最优搜索")
    opt_group.add_argument("--preview", type=str,
                           default="+faster_dynamic_shapes_0805",
                           help="预览特性 (默认: +faster_dynamic_shapes_0805)")
    return parser


def main():
    """CLI entry point: resolve the model, export ONNX if needed, build the engine.

    Returns:
        0 on success, 1 on any failure (missing dependency, missing model,
        failed ONNX export, failed engine build).
    """
    args = _build_arg_parser().parse_args()
    check_environment()

    # Hard dependencies: bail out early with install hints.
    if not TRT_AVAILABLE:
        print("\n错误: TensorRT 未安装")
        print("请安装: pip install tensorrt")
        return 1
    try:
        from ultralytics import YOLO
    except ImportError:
        print("\n错误: Ultralytics 未安装")
        print("请安装: pip install ultralytics")
        return 1

    # Resolve the source model: explicit download, existing local file,
    # or a bare variant name treated as a download request.
    model_path = args.model
    if args.download:
        model_path = download_model(args.download)
    elif not os.path.exists(model_path):
        if model_path.endswith(('.pt', '.onnx')):
            print(f"错误: 模型文件不存在: {model_path}")
            return 1
        model_path = download_model(model_path)
    else:
        print(f"\n使用本地模型: {model_path}")

    # A .pt checkpoint needs an ONNX intermediate; a .onnx input is used as-is.
    onnx_path = model_path
    if model_path.endswith('.pt'):
        onnx_path = model_path.replace('.pt', '.onnx')
        if not export_onnx(model_path, onnx_path, args.input_size):
            return 1

    built = build_engine(
        onnx_path=onnx_path,
        engine_path=args.output,
        input_size=args.input_size,
        min_batch=args.min_batch,
        opt_batch=args.opt_batch,
        max_batch=args.max_batch,
        fp16=args.fp16,
        workspace_mb=args.workspace,
        tactic_sources=args.tactics,
        best=args.best,
        preview=args.preview,
    )
    if not built:
        return 1

    banner = "=" * 60
    print(f"\n{banner}")
    print("构建完成!")
    print(f"{banner}")
    print(f"Engine 文件: {args.output}")
    print(f"\n推荐使用 trtexec 测试性能:")
    print(f" trtexec --loadEngine={args.output} --streams=8 --iterations=100")
    return 0
# Script entry point: main() returns a process exit code (0 success, 1 failure).
if __name__ == "__main__":
    sys.exit(main())