GPU测试
This commit is contained in:
189
benchmark/engine_builder.py
Normal file
189
benchmark/engine_builder.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
TensorRT Engine 构建模块
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Tuple, Optional, Dict, Any
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
from .utils import get_file_hash, ensure_dir, setup_logging
|
||||
|
||||
# Module-level logger shared by everything in this module (configured by project utils).
logger = setup_logging()
|
||||
|
||||
|
||||
@dataclass
class EngineMetadata:
    """Metadata describing a built TensorRT engine.

    Records enough information (source model hash, resolution, precision,
    build environment) to decide later whether a cached engine is still
    valid for the current model and settings.
    """
    source_model_path: str                   # path of the model the engine was built from
    source_model_hash: str                   # content hash; used for cache invalidation
    engine_path: str                         # location of the .engine file
    resolution: Tuple[int, int]              # (imgsz, imgsz) the engine was built for
    batch_size_range: Tuple[int, int, int]   # (min, opt, max) batch sizes
    precision: str                           # e.g. "fp16"
    tensorrt_version: str
    cuda_version: str
    gpu_name: str
    build_timestamp: str                     # ISO-8601 build time

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dict (tuple fields converted to lists)."""
        data = asdict(self)
        data["resolution"] = list(data["resolution"])
        data["batch_size_range"] = list(data["batch_size_range"])
        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "EngineMetadata":
        """Build an instance from a dict produced by :meth:`to_dict`.

        Operates on a shallow copy so the caller's dict is not mutated
        (the previous implementation rewrote the tuple fields in place).
        """
        data = dict(data)
        data["resolution"] = tuple(data["resolution"])
        data["batch_size_range"] = tuple(data["batch_size_range"])
        return cls(**data)

    def save(self, path: str):
        """Write the metadata to *path* as UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

    @classmethod
    def load(cls, path: str) -> "EngineMetadata":
        """Read metadata previously written by :meth:`save`."""
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls.from_dict(data)
|
||||
|
||||
|
||||
class TRTEngineBuilder:
    """Builds TensorRT engines for a YOLO model and caches them on disk.

    Engines are exported via Ultralytics; a JSON metadata sidecar (see
    :class:`EngineMetadata`) is used to decide when a cached engine can
    be reused instead of rebuilt.
    """

    def __init__(self, model_path: str, output_dir: str = "./engines"):
        """Remember the source model, prepare the output dir, probe TensorRT.

        Raises:
            ImportError: if TensorRT is not installed (from _check_tensorrt).
        """
        self.model_path = model_path
        self.output_dir = ensure_dir(output_dir)
        # Hash of the source weights; empty string when the file is missing.
        self.model_hash = get_file_hash(model_path) if os.path.exists(model_path) else ""
        self._check_tensorrt()

    def _check_tensorrt(self):
        """Verify TensorRT is importable and record build-environment info."""
        self.trt_available = False

        try:
            import tensorrt as trt
            self.trt = trt
            self.trt_version = trt.__version__
            self.trt_available = True
            logger.info(f"TensorRT 版本: {self.trt_version}")
        except ImportError:
            raise ImportError(
                "TensorRT 未安装!请安装 TensorRT:\n"
                "  pip install tensorrt\n"
                "或从 NVIDIA 官网下载: https://developer.nvidia.com/tensorrt"
            )

        # CUDA/GPU info is best-effort: fall back to "N/A" without torch.
        try:
            import torch
            self.cuda_version = torch.version.cuda or "N/A"
            self.gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A"
        except ImportError:
            self.cuda_version = "N/A"
            self.gpu_name = "N/A"

    def _get_engine_filename(self, imgsz: int, precision: str) -> str:
        """Derive the engine file name from the source model's base name."""
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        return f"{model_name}_{imgsz}x{imgsz}_{precision}.engine"

    def _get_metadata_filename(self, imgsz: int, precision: str) -> str:
        """Derive the metadata sidecar file name (same stem, .json suffix)."""
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]
        return f"{model_name}_{imgsz}x{imgsz}_{precision}.json"

    def get_engine_path(self, imgsz: int, precision: str = "fp16") -> Optional[str]:
        """Return the cached engine path for this config, or None if absent."""
        engine_path = os.path.join(self.output_dir, self._get_engine_filename(imgsz, precision))
        return engine_path if os.path.exists(engine_path) else None

    def validate_engine(self, engine_path: str) -> bool:
        """Check an engine's metadata sidecar against the current model hash."""
        metadata_path = engine_path.replace(".engine", ".json")
        if not os.path.exists(metadata_path):
            return False
        try:
            metadata = EngineMetadata.load(metadata_path)
            return metadata.source_model_hash == self.model_hash
        except Exception as e:
            logger.warning(f"验证 Engine 元数据失败: {e}")
            return False

    def build_engine(
        self,
        imgsz: int,
        batch_size_range: Tuple[int, int, int] = (1, 8, 16),
        precision: str = "fp16"
    ) -> str:
        """Build (or reuse) a TensorRT engine via Ultralytics export.

        Args:
            imgsz: square input resolution the engine is built for.
            batch_size_range: (min, opt, max); max is passed as the export batch.
            precision: "fp16" enables half precision, anything else builds fp32.

        Returns:
            Path to the engine file (Ultralytics writes it next to the model).

        Raises:
            RuntimeError: if the export does not produce an engine file.
        """
        model_dir = os.path.dirname(os.path.abspath(self.model_path))
        model_name = os.path.splitext(os.path.basename(self.model_path))[0]

        # Ultralytics exports the engine next to the source model by default;
        # the file is used in place, never moved or renamed.
        engine_path = os.path.join(model_dir, f"{model_name}.engine")
        metadata_path = os.path.join(self.output_dir, self._get_metadata_filename(imgsz, precision))

        # Decide whether the cached engine can be reused.
        need_rebuild = True
        if os.path.exists(engine_path) and os.path.exists(metadata_path):
            try:
                metadata = EngineMetadata.load(metadata_path)
                # Reuse only when both the model hash and resolution match.
                if (metadata.source_model_hash == self.model_hash and
                        metadata.resolution == (imgsz, imgsz)):
                    logger.info(f"使用已有 Engine: {engine_path}")
                    need_rebuild = False
            except Exception as e:
                logger.warning(f"元数据验证失败: {e}")

        if need_rebuild:
            # Remove any stale engine so the export writes a fresh file.
            if os.path.exists(engine_path):
                os.remove(engine_path)

            logger.info(f"开始构建 TensorRT Engine: {imgsz}x{imgsz}, {precision}, batch={batch_size_range}")

            from ultralytics import YOLO
            model = YOLO(self.model_path)

            # Export with a dynamic batch axis up to the max of the range.
            export_result = model.export(
                format="engine",
                imgsz=imgsz,
                half=(precision == "fp16"),
                dynamic=True,
                batch=batch_size_range[2],
                workspace=4,
                verbose=False
            )

            if not export_result or not os.path.exists(str(export_result)):
                raise RuntimeError("Engine 导出失败")

            engine_path = str(export_result)
            logger.info(f"Engine 构建完成: {engine_path}")

            # Persist metadata so future runs can validate the cache.
            metadata = EngineMetadata(
                source_model_path=self.model_path,
                source_model_hash=self.model_hash,
                engine_path=engine_path,
                resolution=(imgsz, imgsz),
                batch_size_range=batch_size_range,
                precision=precision,
                tensorrt_version=self.trt_version,
                cuda_version=self.cuda_version,
                gpu_name=self.gpu_name,
                build_timestamp=datetime.now().isoformat()
            )
            metadata.save(metadata_path)

        return engine_path

    def build_all_engines(self, resolutions: Optional[list] = None, precision: str = "fp16") -> Dict[int, str]:
        """Build engines for each resolution; returns {imgsz: engine_path}.

        NOTE: the previous signature used a mutable list literal as the
        default argument; a None sentinel is used instead so state cannot
        leak between calls. Behavior is unchanged (default [320, 480]).
        """
        if resolutions is None:
            resolutions = [320, 480]
        engines = {}
        for imgsz in resolutions:
            engines[imgsz] = self.build_engine(imgsz, precision=precision)
        return engines
|
||||
Reference in New Issue
Block a user