Files
Test_AI/benchmark/engine_builder.py
2026-01-20 10:54:30 +08:00

190 lines
7.0 KiB
Python

"""
TensorRT Engine 构建模块
"""
import os
import json
from dataclasses import dataclass, asdict
from typing import Tuple, Optional, Dict, Any
from pathlib import Path
from datetime import datetime
from .utils import get_file_hash, ensure_dir, setup_logging
logger = setup_logging()
@dataclass
class EngineMetadata:
    """Metadata describing one built TensorRT engine.

    Attributes:
        source_model_path: path of the model the engine was built from.
        source_model_hash: content hash of the source model, used to detect
            stale engines when the model file changes.
        engine_path: path of the built .engine file.
        resolution: target inference resolution as a 2-tuple; the builder
            always passes a square (imgsz, imgsz).
        batch_size_range: dynamic batch sizes as a 3-tuple (the builder uses
            index 2 as the export batch limit).
        precision: numeric precision tag, e.g. "fp16".
        tensorrt_version: TensorRT version used for the build.
        cuda_version: CUDA version reported by torch ("N/A" when unknown).
        gpu_name: name of the GPU the engine was built on ("N/A" when unknown).
        build_timestamp: ISO-8601 timestamp of the build.
    """
    source_model_path: str
    source_model_hash: str
    engine_path: str
    resolution: Tuple[int, int]
    batch_size_range: Tuple[int, int, int]
    precision: str
    tensorrt_version: str
    cuda_version: str
    gpu_name: str
    build_timestamp: str

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict (tuples converted to lists)."""
        data = asdict(self)
        data["resolution"] = list(data["resolution"])
        data["batch_size_range"] = list(data["batch_size_range"])
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EngineMetadata":
        """Build an instance from a plain dict (inverse of to_dict).

        Copies the mapping first so the caller's dict is not mutated
        (the original implementation rewrote the caller's list values
        into tuples in place).
        """
        data = dict(data)
        data["resolution"] = tuple(data["resolution"])
        data["batch_size_range"] = tuple(data["batch_size_range"])
        return cls(**data)

    def save(self, path: str):
        """Write this metadata to *path* as pretty-printed UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load(cls, path: str) -> "EngineMetadata":
        """Read metadata previously written by save()."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)
class TRTEngineBuilder:
    """TensorRT engine builder.

    Exports a source model to a TensorRT engine via Ultralytics and records
    per-build metadata (model hash, resolution, precision, ...) so that an
    unchanged model is not rebuilt on subsequent calls.
    """

    def __init__(self, model_path: str, output_dir: str = "./engines"):
        """
        Args:
            model_path: path to the source model file (e.g. a .pt file).
            output_dir: directory for metadata files; created if missing.

        Raises:
            ImportError: when TensorRT is not installed (see _check_tensorrt).
        """
        self.model_path = model_path
        self.output_dir = ensure_dir(output_dir)
        # Hash the source model so stale engines can be detected later;
        # empty string when the model file does not (yet) exist.
        self.model_hash = get_file_hash(model_path) if os.path.exists(model_path) else ""
        self._check_tensorrt()

    def _check_tensorrt(self):
        """Check TensorRT availability and record build-environment info.

        Sets trt/trt_version/trt_available, plus cuda_version and gpu_name
        (best-effort via torch) for the metadata. Raises ImportError when
        TensorRT cannot be imported.
        """
        self.trt_available = False
        try:
            import tensorrt as trt
            self.trt = trt
            self.trt_version = trt.__version__
            self.trt_available = True
            logger.info(f"TensorRT 版本: {self.trt_version}")
        except ImportError:
            raise ImportError(
                "TensorRT 未安装!请安装 TensorRT:\n"
                " pip install tensorrt\n"
                "或从 NVIDIA 官网下载: https://developer.nvidia.com/tensorrt"
            )
        # CUDA / GPU info is informational only — failure to import torch
        # must not prevent building, hence the "N/A" fallbacks.
        try:
            import torch
            self.cuda_version = torch.version.cuda or "N/A"
            self.gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
        except ImportError:
            self.cuda_version = "N/A"
            self.gpu_name = "N/A"

    def _get_engine_filename(self, imgsz: int, precision: str) -> str:
        """Engine file name derived from the source model name,
        e.g. 'yolo_320x320_fp16.engine'."""
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        return f"{model_name}_{imgsz}x{imgsz}_{precision}.engine"

    def _get_metadata_filename(self, imgsz: int, precision: str) -> str:
        """Sidecar JSON file name matching _get_engine_filename."""
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        return f"{model_name}_{imgsz}x{imgsz}_{precision}.json"

    def get_engine_path(self, imgsz: int, precision: str = "fp16") -> Optional[str]:
        """Return the engine path under output_dir, or None when absent."""
        engine_path = os.path.join(self.output_dir, self._get_engine_filename(imgsz, precision))
        return engine_path if os.path.exists(engine_path) else None

    def validate_engine(self, engine_path: str) -> bool:
        """Return True when the engine's sidecar metadata exists and its
        recorded model hash matches the current source model hash."""
        # splitext is safer than str.replace(".engine", ".json"), which
        # could also rewrite an ".engine" substring earlier in the path.
        metadata_path = os.path.splitext(engine_path)[0] + ".json"
        if not os.path.exists(metadata_path):
            return False
        try:
            metadata = EngineMetadata.load(metadata_path)
            return metadata.source_model_hash == self.model_hash
        except Exception as e:
            logger.warning(f"验证 Engine 元数据失败: {e}")
            return False

    def build_engine(
        self,
        imgsz: int,
        batch_size_range: Tuple[int, int, int] = (1, 8, 16),
        precision: str = "fp16"
    ) -> str:
        """Build a TensorRT engine, using the Ultralytics export in place.

        Args:
            imgsz: square input resolution (imgsz x imgsz).
            batch_size_range: 3-tuple of batch sizes; index 2 is passed to
                the exporter as the dynamic batch limit.
            precision: "fp16" enables half-precision export.

        Returns:
            Path to the engine file (Ultralytics' default export location,
            next to the source model).

        Raises:
            RuntimeError: when the export produced no engine file.
        """
        model_dir = os.path.dirname(os.path.abspath(self.model_path))
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        # Ultralytics exports next to the source model by default.
        engine_path = os.path.join(model_dir, f"{model_name}.engine")
        metadata_path = os.path.join(self.output_dir, self._get_metadata_filename(imgsz, precision))
        # Rebuild unless both files exist AND the recorded model hash and
        # resolution still match the current request.
        need_rebuild = True
        if os.path.exists(engine_path) and os.path.exists(metadata_path):
            try:
                metadata = EngineMetadata.load(metadata_path)
                if (metadata.source_model_hash == self.model_hash and
                        metadata.resolution == (imgsz, imgsz)):
                    logger.info(f"使用已有 Engine: {engine_path}")
                    need_rebuild = False
            except Exception as e:
                logger.warning(f"元数据验证失败: {e}")
        if need_rebuild:
            # Remove the stale engine so the exporter writes a fresh one.
            if os.path.exists(engine_path):
                os.remove(engine_path)
            logger.info(f"开始构建 TensorRT Engine: {imgsz}x{imgsz}, {precision}, batch={batch_size_range}")
            from ultralytics import YOLO
            model = YOLO(self.model_path)
            # Export in place — the exported file is used as-is, never
            # moved or renamed.
            export_result = model.export(
                format="engine",
                imgsz=imgsz,
                half=(precision == "fp16"),
                dynamic=True,
                batch=batch_size_range[2],  # dynamic-batch upper bound
                workspace=4,
                verbose=False
            )
            if not export_result or not os.path.exists(str(export_result)):
                raise RuntimeError("Engine 导出失败")
            engine_path = str(export_result)
            logger.info(f"Engine 构建完成: {engine_path}")
            # Persist metadata only after a real build so a cache hit does
            # not clobber the original build timestamp.
            metadata = EngineMetadata(
                source_model_path=self.model_path,
                source_model_hash=self.model_hash,
                engine_path=engine_path,
                resolution=(imgsz, imgsz),
                batch_size_range=batch_size_range,
                precision=precision,
                tensorrt_version=self.trt_version,
                cuda_version=self.cuda_version,
                gpu_name=self.gpu_name,
                build_timestamp=datetime.now().isoformat()
            )
            metadata.save(metadata_path)
        return engine_path

    def build_all_engines(self, resolutions: Optional[list] = None, precision: str = "fp16") -> Dict[int, str]:
        """Build an engine for every resolution in *resolutions*.

        Args:
            resolutions: list of square resolutions; defaults to [320, 480].
                (A None sentinel replaces the original mutable-list default,
                which was shared across calls.)
            precision: precision tag forwarded to build_engine.

        Returns:
            Mapping of resolution -> engine path.
        """
        if resolutions is None:
            resolutions = [320, 480]
        return {imgsz: self.build_engine(imgsz, precision=precision) for imgsz in resolutions}