GPU测试
This commit is contained in:
150
benchmark/inference_engine.py
Normal file
150
benchmark/inference_engine.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
TensorRT 推理引擎模块
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
import gc
|
||||
from typing import Tuple, Optional, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
def clear_gpu_memory():
    """Force-release cached GPU memory.

    Runs the Python garbage collector first so unreferenced tensors are
    actually dropped, then empties PyTorch's CUDA allocator cache and
    synchronizes the device. Safe no-op on CPU-only machines.
    """
    import torch

    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
||||
|
||||
|
||||
@dataclass
class InferenceTask:
    """One inference request tracked by the engine.

    Carries the input batch plus timing/completion bookkeeping so a
    caller can poll for the result after submitting the task.
    """

    task_id: int                  # unique, monotonically increasing id
    batch: np.ndarray             # input batch handed to the engine
    stream_id: int                # logical stream bucket for this task
    start_time: float             # perf_counter timestamp at submission
    end_time: float = 0.0         # perf_counter timestamp at completion (0.0 until done)
    result: Optional[Any] = None  # inference output, filled in on completion
    completed: bool = False       # True once the result is available
|
||||
|
||||
|
||||
class TRTInferenceEngine:
    """TensorRT inference engine wrapped via Ultralytics YOLO.

    Loads an Ultralytics-exported ``.engine`` file and exposes a small
    task-based API: :meth:`infer_async` registers and runs a task,
    :meth:`get_results` retrieves its result and latency.

    NOTE(review): despite the naming, inference currently runs
    synchronously on the caller's thread; the task table and lock exist
    so the interface can later be backed by real CUDA streams without
    changing callers.
    """

    def __init__(self, engine_path: str, num_streams: int = 2, device_id: int = 0):
        """Create the engine and load/warm up the TensorRT model.

        Args:
            engine_path: Path to the TensorRT ``.engine`` file.
            num_streams: Logical stream count used to bucket ``stream_id``.
            device_id: CUDA device index (kept for interface compatibility;
                not used directly here — Ultralytics manages the device).
        """
        self.engine_path = engine_path
        self.num_streams = num_streams
        self.device_id = device_id

        self._task_counter = 0
        self._task_lock = threading.Lock()
        self._tasks: Dict[int, InferenceTask] = {}
        self.model = None

        self._init_engine()

    def _init_engine(self):
        """Load the TensorRT engine through Ultralytics YOLO and warm it up."""
        # FIX: removed unused `import torch` that shadowed no usage here.
        from ultralytics import YOLO

        logger.info(f"加载 TensorRT Engine: {self.engine_path}")

        # Ultralytics transparently handles its own exported .engine files.
        self.model = YOLO(self.engine_path, task='detect')

        # Warm-up: one dummy prediction triggers lazy CUDA/engine setup so
        # later calls are not skewed by first-call overhead.
        logger.info("预热推理引擎...")
        dummy_input = np.random.randint(0, 255, (320, 320, 3), dtype=np.uint8)
        _ = self.model.predict(dummy_input, verbose=False)

        # FIX: was an f-string with no placeholders (F541).
        logger.info("TensorRT 引擎初始化完成 (Ultralytics 封装)")

    def infer_async(self, batch: np.ndarray, stream_id: int = 0) -> int:
        """Register an inference task, run it, and return its task id.

        Args:
            batch: Input batch; NCHW float in [0, 1] per the format this
                engine's ``_infer`` expects, or a single HWC frame.
            stream_id: Logical stream hint; folded into ``num_streams``.

        Returns:
            The integer task id to pass to :meth:`get_results`.
        """
        with self._task_lock:
            task_id = self._task_counter
            self._task_counter += 1

            task = InferenceTask(
                task_id=task_id,
                batch=batch,
                stream_id=stream_id % self.num_streams,
                start_time=time.perf_counter()
            )
            # FIX: register the task under the lock — the counter is already
            # guarded, so concurrent submitters must not race on the shared
            # task table either.
            self._tasks[task_id] = task

        self._infer(task)

        return task_id

    def _infer(self, task: InferenceTask):
        """Run inference for one task and mark it completed."""
        batch = task.batch

        # Batch format: (N, C, H, W), values in [0, 1].
        # Ultralytics predict() expects HWC uint8 frames, so convert.
        if batch.ndim == 4:
            # NCHW -> NHWC, rescale to 0-255.
            batch_nhwc = np.transpose(batch, (0, 2, 3, 1))
            batch_uint8 = (batch_nhwc * 255).astype(np.uint8)

            # Per-frame inference; keeps each frame's result separate.
            task.result = [
                self.model.predict(frame, verbose=False)
                for frame in batch_uint8
            ]
        else:
            # Single frame (or already-formatted) input: pass through as-is.
            task.result = self.model.predict(batch, verbose=False)

        task.end_time = time.perf_counter()
        task.completed = True

    def get_results(self, task_id: int, timeout: float = 10.0) -> Tuple[Any, float]:
        """Block until the task finishes; return ``(result, latency_ms)``.

        Args:
            task_id: Id returned by :meth:`infer_async`.
            timeout: Maximum seconds to wait for completion.

        Returns:
            Tuple of the inference result and the task latency in ms.

        Raises:
            ValueError: Unknown task id.
            TimeoutError: Task not completed within ``timeout`` seconds.
        """
        if task_id not in self._tasks:
            raise ValueError(f"无效的任务 ID: {task_id}")

        task = self._tasks[task_id]

        # Busy-wait with a short sleep. Tasks are normally already complete
        # because _infer runs synchronously inside infer_async; the wait
        # exists for a future truly-async backend.
        start_wait = time.time()
        while not task.completed:
            if time.time() - start_wait > timeout:
                raise TimeoutError(f"任务 {task_id} 超时")
            time.sleep(0.001)

        latency_ms = (task.end_time - task.start_time) * 1000
        result = task.result

        # Drop the bookkeeping entry (and its batch reference) once consumed.
        del self._tasks[task_id]

        return result, latency_ms

    def infer_sync(self, batch: np.ndarray) -> Tuple[Any, float]:
        """Convenience wrapper: submit a batch and wait for its result."""
        task_id = self.infer_async(batch)
        return self.get_results(task_id)

    def synchronize(self, stream_id: Optional[int] = None):
        """No-op; kept for interface compatibility with stream-based engines."""
        pass

    def cleanup(self):
        """Release the model and pending tasks, then force-free GPU memory."""
        if self.model is not None:
            del self.model
            self.model = None
        self._tasks.clear()
        clear_gpu_memory()
        logger.info("推理引擎资源已释放")
|
||||
Reference in New Issue
Block a user