"""
|
|
TensorRT Engine 构建模块
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
from dataclasses import dataclass, asdict
|
|
from typing import Tuple, Optional, Dict, Any
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
|
|
from .utils import get_file_hash, ensure_dir, setup_logging
|
|
|
|
# Module-wide logger wired up by the project's shared logging setup helper.
logger = setup_logging()
|
|
|
|
|
|
@dataclass
class EngineMetadata:
    """Metadata describing a built TensorRT engine.

    Persisted as JSON next to the ``.engine`` file so a cached engine can be
    matched against the source model (via content hash) and build settings
    before being reused.
    """

    source_model_path: str                   # path of the model the engine was built from
    source_model_hash: str                   # content hash of the source model file
    engine_path: str                         # location of the built .engine file
    resolution: Tuple[int, int]              # (height, width) the engine was built for
    batch_size_range: Tuple[int, int, int]   # (min, opt, max) batch sizes
    precision: str                           # e.g. "fp16" or "fp32"
    tensorrt_version: str                    # TensorRT version used for the build
    cuda_version: str                        # CUDA version reported by torch ("N/A" if unknown)
    gpu_name: str                            # GPU the engine was built on ("N/A" if unknown)
    build_timestamp: str                     # ISO-8601 time of the build

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict (tuple fields converted to lists)."""
        data = asdict(self)
        data["resolution"] = list(data["resolution"])
        data["batch_size_range"] = list(data["batch_size_range"])
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EngineMetadata":
        """Build an instance from a dict produced by :meth:`to_dict`.

        Operates on a shallow copy so the caller's dict is not mutated.
        """
        # Fix: the original rewrote the caller's dict in place, silently
        # turning its list values into tuples as a side effect.
        data = dict(data)
        data["resolution"] = tuple(data["resolution"])
        data["batch_size_range"] = tuple(data["batch_size_range"])
        return cls(**data)

    def save(self, path: str):
        """Write the metadata to *path* as pretty-printed UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load(cls, path: str) -> "EngineMetadata":
        """Read metadata previously written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)
|
class TRTEngineBuilder:
    """Builds and caches TensorRT engines for a given source model.

    Engines are exported through Ultralytics and identified by
    (resolution, precision); JSON metadata lets a cached engine be reused
    when the source model hash and build settings still match.
    """

    def __init__(self, model_path: str, output_dir: str = "./engines"):
        """
        Args:
            model_path: Path to the source model (e.g. a ``.pt`` file).
            output_dir: Directory for engine metadata; created if missing.

        Raises:
            ImportError: If TensorRT is not installed.
        """
        self.model_path = model_path
        self.output_dir = ensure_dir(output_dir)
        # Hash ties cached engines to the exact source model contents;
        # empty string when the model file does not exist yet.
        self.model_hash = get_file_hash(model_path) if os.path.exists(model_path) else ""
        self._check_tensorrt()

    def _check_tensorrt(self):
        """Verify TensorRT availability and capture build-environment info."""
        self.trt_available = False

        try:
            import tensorrt as trt
            self.trt = trt
            self.trt_version = trt.__version__
            self.trt_available = True
            logger.info(f"TensorRT 版本: {self.trt_version}")
        except ImportError:
            raise ImportError(
                "TensorRT 未安装!请安装 TensorRT:\n"
                " pip install tensorrt\n"
                "或从 NVIDIA 官网下载: https://developer.nvidia.com/tensorrt"
            )

        # Best-effort environment capture for the metadata; torch is optional.
        try:
            import torch
            self.cuda_version = torch.version.cuda or "N/A"
            self.gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
        except ImportError:
            self.cuda_version = "N/A"
            self.gpu_name = "N/A"

    def _model_stem(self) -> str:
        # Source model filename without extension; shared stem for
        # engine and metadata filenames.
        return os.path.splitext(os.path.basename(self.model_path))[0]

    def _get_engine_filename(self, imgsz: int, precision: str) -> str:
        """Engine filename derived from the model name and build settings."""
        return f"{self._model_stem()}_{imgsz}x{imgsz}_{precision}.engine"

    def _get_metadata_filename(self, imgsz: int, precision: str) -> str:
        """Metadata filename matching :meth:`_get_engine_filename` (``.json``)."""
        return f"{self._model_stem()}_{imgsz}x{imgsz}_{precision}.json"

    def get_engine_path(self, imgsz: int, precision: str = "fp16") -> Optional[str]:
        """Return the cached engine path for (imgsz, precision), or None if absent."""
        engine_path = os.path.join(self.output_dir, self._get_engine_filename(imgsz, precision))
        return engine_path if os.path.exists(engine_path) else None

    def validate_engine(self, engine_path: str) -> bool:
        """Return True if the engine's metadata exists and matches the current model hash."""
        # Fix: swap only the final extension. The original used
        # str.replace(".engine", ".json"), which would also rewrite an
        # ".engine" substring appearing earlier in the path.
        metadata_path = os.path.splitext(engine_path)[0] + ".json"
        if not os.path.exists(metadata_path):
            return False
        try:
            metadata = EngineMetadata.load(metadata_path)
            return metadata.source_model_hash == self.model_hash
        except Exception as e:
            logger.warning(f"验证 Engine 元数据失败: {e}")
            return False

    def build_engine(
        self,
        imgsz: int,
        batch_size_range: Tuple[int, int, int] = (1, 8, 16),
        precision: str = "fp16"
    ) -> str:
        """Build a TensorRT engine via Ultralytics export, reusing a cached one when valid.

        Args:
            imgsz: Square input resolution (imgsz x imgsz).
            batch_size_range: (min, opt, max) batch sizes; max is passed to export.
            precision: "fp16" enables half precision; anything else builds fp32.

        Returns:
            Path to the engine file (Ultralytics' default export location,
            next to the source model).

        Raises:
            RuntimeError: If the Ultralytics export produced no engine file.
        """
        model_dir = os.path.dirname(os.path.abspath(self.model_path))
        model_name = self._model_stem()

        # Ultralytics' default export path: next to the source model.
        engine_path = os.path.join(model_dir, f"{model_name}.engine")
        metadata_path = os.path.join(self.output_dir, self._get_metadata_filename(imgsz, precision))

        # Decide whether a rebuild is needed by checking the cached metadata.
        need_rebuild = True
        if os.path.exists(engine_path) and os.path.exists(metadata_path):
            try:
                metadata = EngineMetadata.load(metadata_path)
                # Reuse only when both the model hash and resolution match.
                if (metadata.source_model_hash == self.model_hash and
                        metadata.resolution == (imgsz, imgsz)):
                    logger.info(f"使用已有 Engine: {engine_path}")
                    need_rebuild = False
            except Exception as e:
                logger.warning(f"元数据验证失败: {e}")

        if need_rebuild:
            # Remove the stale engine file (if any) before re-exporting.
            if os.path.exists(engine_path):
                os.remove(engine_path)

            logger.info(f"开始构建 TensorRT Engine: {imgsz}x{imgsz}, {precision}, batch={batch_size_range}")

            from ultralytics import YOLO
            model = YOLO(self.model_path)

            # Ultralytics export — use the exported file as-is, no move/rename.
            export_result = model.export(
                format="engine",
                imgsz=imgsz,
                half=(precision == "fp16"),
                dynamic=True,
                batch=batch_size_range[2],
                workspace=4,
                verbose=False
            )

            if not export_result or not os.path.exists(str(export_result)):
                raise RuntimeError("Engine 导出失败")

            engine_path = str(export_result)
            logger.info(f"Engine 构建完成: {engine_path}")

            # Fix: write metadata only after an actual rebuild. The original
            # rewrote it on every call, stamping a fresh build_timestamp even
            # when a cached engine was reused.
            metadata = EngineMetadata(
                source_model_path=self.model_path,
                source_model_hash=self.model_hash,
                engine_path=engine_path,
                resolution=(imgsz, imgsz),
                batch_size_range=batch_size_range,
                precision=precision,
                tensorrt_version=self.trt_version,
                cuda_version=self.cuda_version,
                gpu_name=self.gpu_name,
                build_timestamp=datetime.now().isoformat()
            )
            metadata.save(metadata_path)

        return engine_path

    def build_all_engines(self, resolutions: Optional[list] = None, precision: str = "fp16") -> Dict[int, str]:
        """Build engines for each resolution; returns {imgsz: engine_path}.

        Args:
            resolutions: Square input sizes to build; defaults to [320, 480].
            precision: Passed through to :meth:`build_engine`.
        """
        # Fix: the original used a mutable default argument ([320, 480]),
        # shared across calls; use None-sentinel instead.
        if resolutions is None:
            resolutions = [320, 480]
        engines = {}
        for imgsz in resolutions:
            engines[imgsz] = self.build_engine(imgsz, precision=precision)
        return engines