"""TensorRT engine build module.

Defines ``EngineMetadata`` (a JSON-serializable record describing a built
engine) and ``TRTEngineBuilder`` (builds engines through the Ultralytics
exporter and reuses an existing engine when its metadata still matches).
"""

import os
import json
from dataclasses import dataclass, asdict
from typing import Tuple, Optional, Dict, Any
from pathlib import Path
from datetime import datetime

from .utils import get_file_hash, ensure_dir, setup_logging

logger = setup_logging()


@dataclass
class EngineMetadata:
    """Metadata describing a built TensorRT engine."""

    source_model_path: str
    source_model_hash: str
    engine_path: str
    resolution: Tuple[int, int]
    batch_size_range: Tuple[int, int, int]
    precision: str
    tensorrt_version: str
    cuda_version: str
    gpu_name: str
    build_timestamp: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the metadata as a dict with tuple fields converted to lists."""
        data = asdict(self)
        data["resolution"] = list(data["resolution"])
        data["batch_size_range"] = list(data["batch_size_range"])
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EngineMetadata":
        """Build an instance from a dict such as the one ``to_dict`` produces.

        The input mapping is left untouched; the list-valued ``resolution``
        and ``batch_size_range`` entries are converted to tuples on a copy.

        Raises:
            KeyError: if ``resolution`` or ``batch_size_range`` is missing.
            TypeError: if the keys do not match the dataclass fields.
        """
        # Work on a shallow copy so the caller's dict is not mutated.
        fields = dict(data)
        fields["resolution"] = tuple(fields["resolution"])
        fields["batch_size_range"] = tuple(fields["batch_size_range"])
        return cls(**fields)

    def save(self, path: str):
        """Write the metadata to ``path`` as UTF-8 encoded, indented JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load(cls, path: str) -> "EngineMetadata":
        """Read a JSON metadata file from ``path`` and return an instance."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)


class TRTEngineBuilder:
    """Builder for TensorRT engines derived from a single source model."""

    def __init__(self, model_path: str, output_dir: str = "./engines"):
        """Record the model location and verify that TensorRT can be imported.

        Args:
            model_path: Path to the source model file.
            output_dir: Directory used for metadata files. It is passed
                through ``ensure_dir`` (presumably creates the directory and
                returns its path -- confirm against ``utils``).

        Raises:
            ImportError: if the ``tensorrt`` package cannot be imported.
        """
        self.model_path = model_path
        self.output_dir = ensure_dir(output_dir)
        # A missing model file yields an empty hash instead of an error.
        self.model_hash = get_file_hash(model_path) if os.path.exists(model_path) else ""
        self._check_tensorrt()

    def _check_tensorrt(self):
        """Check that TensorRT is available and record environment details.

        Sets ``trt``, ``trt_version`` and ``trt_available`` from the
        ``tensorrt`` package, and ``cuda_version`` / ``gpu_name`` from
        ``torch`` (both fall back to ``"N/A"`` when torch is not installed).

        Raises:
            ImportError: if ``tensorrt`` cannot be imported.
        """
        self.trt_available = False
        try:
            import tensorrt as trt
        except ImportError as err:
            # Chain the original error so the underlying cause stays visible.
            raise ImportError(
                "TensorRT 未安装!请安装 TensorRT:\n"
                " pip install tensorrt\n"
                "或从 NVIDIA 官网下载: https://developer.nvidia.com/tensorrt"
            ) from err
        self.trt = trt
        self.trt_version = trt.__version__
        self.trt_available = True
        logger.info(f"TensorRT 版本: {self.trt_version}")

        try:
            import torch
        except ImportError:
            self.cuda_version = "N/A"
            self.gpu_name = "N/A"
        else:
            self.cuda_version = torch.version.cuda or "N/A"
            self.gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"

    def _get_engine_filename(self, imgsz: int, precision: str) -> str:
        """Return ``<model>_<imgsz>x<imgsz>_<precision>.engine``.

        ``<model>`` is the source model's file name without its extension.
        """
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        return f"{model_name}_{imgsz}x{imgsz}_{precision}.engine"

    def _get_metadata_filename(self, imgsz: int, precision: str) -> str:
        """Return ``<model>_<imgsz>x<imgsz>_<precision>.json``."""
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        return f"{model_name}_{imgsz}x{imgsz}_{precision}.json"

    def get_engine_path(self, imgsz: int, precision: str = "fp16") -> Optional[str]:
        """Return the engine path inside ``output_dir`` if that file exists.

        Returns:
            The path, or ``None`` when no such file is present.
        """
        # NOTE(review): build_engine returns the path reported by the
        # exporter and does not copy engines into output_dir, so this lookup
        # may never find a file produced by build_engine -- confirm intent.
        engine_path = os.path.join(self.output_dir, self._get_engine_filename(imgsz, precision))
        return engine_path if os.path.exists(engine_path) else None

    def validate_engine(self, engine_path: str) -> bool:
        """Report whether the metadata next to ``engine_path`` matches the model.

        The metadata file is looked up at the engine path with its extension
        replaced by ``.json``.

        Returns:
            ``True`` only if the metadata file exists, loads, and its
            ``source_model_hash`` equals this builder's model hash. Any load
            failure is logged as a warning and reported as ``False``.
        """
        # Swap only the final extension; str.replace would also rewrite
        # ".engine" occurring earlier in the path (e.g. in a directory name).
        metadata_path = os.path.splitext(engine_path)[0] + ".json"
        if not os.path.exists(metadata_path):
            return False
        try:
            metadata = EngineMetadata.load(metadata_path)
        except Exception as e:
            logger.warning(f"验证 Engine 元数据失败: {e}")
            return False
        return metadata.source_model_hash == self.model_hash

    def build_engine(
        self,
        imgsz: int,
        batch_size_range: Tuple[int, int, int] = (1, 8, 16),
        precision: str = "fp16",
    ) -> str:
        """Build a TensorRT engine, using the Ultralytics export file as-is.

        An existing engine is reused when its metadata file matches the
        current model hash, the requested resolution and the requested batch
        size range. Otherwise the model is exported again and fresh metadata
        is written to ``output_dir``.

        Args:
            imgsz: Square input size passed to the exporter.
            batch_size_range: Three-element batch range; only the third
                element is passed to the exporter (as ``batch``).
            precision: ``"fp16"`` exports with ``half=True``; any other value
                exports with ``half=False``.

        Returns:
            Path of the engine file.

        Raises:
            RuntimeError: if the export returns nothing or the reported file
                does not exist.
            ImportError: if ``ultralytics`` cannot be imported when a build
                is required.
        """
        model_dir = os.path.dirname(os.path.abspath(self.model_path))
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        # Expected export location: next to the source model.
        # NOTE(review): this path does not depend on imgsz or precision, so
        # building another resolution removes and replaces the previous
        # engine (unless the exporter reports a different path) -- confirm
        # whether per-resolution engine files are intended.
        engine_path = os.path.join(model_dir, f"{model_name}.engine")
        metadata_path = os.path.join(
            self.output_dir, self._get_metadata_filename(imgsz, precision)
        )

        # Decide whether the existing engine can be reused.
        need_rebuild = True
        if os.path.exists(engine_path) and os.path.exists(metadata_path):
            try:
                metadata = EngineMetadata.load(metadata_path)
                # The batch range is compared too: the exporter is given
                # batch_size_range[2], so a different range needs a rebuild.
                if (
                    metadata.source_model_hash == self.model_hash
                    and metadata.resolution == (imgsz, imgsz)
                    and tuple(metadata.batch_size_range) == tuple(batch_size_range)
                ):
                    logger.info(f"使用已有 Engine: {engine_path}")
                    need_rebuild = False
            except Exception as e:
                # Unreadable metadata is treated as "rebuild required".
                logger.warning(f"元数据验证失败: {e}")

        if not need_rebuild:
            return engine_path

        # Remove the stale engine file, if any, before exporting.
        if os.path.exists(engine_path):
            os.remove(engine_path)

        logger.info(f"开始构建 TensorRT Engine: {imgsz}x{imgsz}, {precision}, batch={batch_size_range}")

        # Imported lazily so the module loads without ultralytics installed.
        from ultralytics import YOLO

        model = YOLO(self.model_path)
        # The exported file is used where it is; it is not moved or renamed.
        export_result = model.export(
            format="engine",
            imgsz=imgsz,
            half=(precision == "fp16"),
            dynamic=True,
            batch=batch_size_range[2],
            workspace=4,
            verbose=False,
        )
        if not export_result or not os.path.exists(str(export_result)):
            raise RuntimeError("Engine 导出失败")

        engine_path = str(export_result)
        logger.info(f"Engine 构建完成: {engine_path}")

        # Persist metadata describing the engine that was just built.
        metadata = EngineMetadata(
            source_model_path=self.model_path,
            source_model_hash=self.model_hash,
            engine_path=engine_path,
            resolution=(imgsz, imgsz),
            batch_size_range=batch_size_range,
            precision=precision,
            tensorrt_version=self.trt_version,
            cuda_version=self.cuda_version,
            gpu_name=self.gpu_name,
            build_timestamp=datetime.now().isoformat(),
        )
        metadata.save(metadata_path)
        return engine_path

    def build_all_engines(
        self, resolutions: Optional[list] = None, precision: str = "fp16"
    ) -> Dict[int, str]:
        """Build an engine for every requested resolution.

        Args:
            resolutions: Input sizes to build; defaults to ``[320, 480]``.
            precision: Precision forwarded to ``build_engine``.

        Returns:
            Mapping from each resolution to the path ``build_engine`` returned.
        """
        # None stands in for the default list to avoid a mutable default.
        if resolutions is None:
            resolutions = [320, 480]
        return {
            imgsz: self.build_engine(imgsz, precision=precision)
            for imgsz in resolutions
        }