"""TensorRT inference engine module.

Wraps an Ultralytics-packaged TensorRT ``.engine`` file behind a small
synchronous/asynchronous inference API.
"""
|
|
|
|
import time
|
|
import threading
|
|
import gc
|
|
from typing import Tuple, Optional, Dict, Any
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
|
|
from .utils import setup_logging
|
|
|
|
logger = setup_logging()
|
|
|
|
|
|
def clear_gpu_memory():
    """Force-release cached GPU memory (no-op when CUDA is unavailable)."""
    import torch

    # Collect Python garbage first so tensors become unreferenced
    # before the CUDA caching allocator is asked to release them.
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
|
|
|
|
|
|
@dataclass
class InferenceTask:
    """A single inference request tracked by the engine.

    Created by ``infer_async`` and consumed (then discarded) by
    ``get_results``.
    """

    task_id: int                   # monotonically increasing id assigned under the engine's lock
    batch: np.ndarray              # input; NCHW float in [0, 1] or a single HWC image (per _infer)
    stream_id: int                 # logical stream slot, already reduced modulo num_streams
    start_time: float              # time.perf_counter() at task creation
    end_time: float = 0.0          # time.perf_counter() when inference finished (0.0 until then)
    result: Optional[Any] = None   # model output(s) once inference has run
    completed: bool = False        # set True by _infer after result/end_time are filled in
|
|
|
|
|
|
class TRTInferenceEngine:
    """TensorRT 推理引擎 - 使用 Ultralytics 封装

    Loads an Ultralytics-packaged TensorRT ``.engine`` file and exposes a
    small sync/async inference API.  Note that ``infer_async`` actually
    runs the model synchronously in the calling thread (Ultralytics
    manages its own CUDA execution); the task-id indirection exists for
    API compatibility with a truly asynchronous engine.
    """

    def __init__(self, engine_path: str, num_streams: int = 2, device_id: int = 0):
        """
        Args:
            engine_path: Path to the Ultralytics-exported TensorRT engine.
            num_streams: Logical stream count; only used to bucket stream ids.
            device_id: CUDA device index (currently informational only).
        """
        self.engine_path = engine_path
        self.num_streams = num_streams
        self.device_id = device_id

        self._task_counter = 0
        # Guards both _task_counter and _tasks — the original code only
        # protected the counter, leaving the dict unsynchronized.
        self._task_lock = threading.Lock()
        self._tasks: Dict[int, InferenceTask] = {}
        self.model = None

        self._init_engine()

    def _init_engine(self):
        """初始化推理引擎 - 使用 Ultralytics YOLO 加载 TensorRT engine"""
        from ultralytics import YOLO

        logger.info(f"加载 TensorRT Engine: {self.engine_path}")

        # Ultralytics 会自动处理其封装的 .engine 文件
        self.model = YOLO(self.engine_path, task='detect')

        # 预热: run one dummy frame so lazy CUDA/TensorRT initialization
        # does not inflate the latency of the first real request.
        logger.info("预热推理引擎...")
        dummy_input = np.random.randint(0, 255, (320, 320, 3), dtype=np.uint8)
        _ = self.model.predict(dummy_input, verbose=False)

        logger.info("TensorRT 引擎初始化完成 (Ultralytics 封装)")

    def infer_async(self, batch: np.ndarray, stream_id: int = 0) -> int:
        """异步推理 (兼容接口).

        Registers a task, runs inference immediately in the calling
        thread, and returns the task id for :meth:`get_results`.

        Args:
            batch: NCHW float array in [0, 1], or a single image.
            stream_id: Logical stream hint; reduced modulo num_streams.

        Returns:
            The task id to pass to :meth:`get_results`.
        """
        # FIX: create and register the task entirely under the lock —
        # previously the _tasks dict was mutated outside it.
        with self._task_lock:
            task_id = self._task_counter
            self._task_counter += 1
            task = InferenceTask(
                task_id=task_id,
                batch=batch,
                stream_id=stream_id % self.num_streams,
                start_time=time.perf_counter(),
            )
            self._tasks[task_id] = task

        self._infer(task)
        return task_id

    def _infer(self, task: InferenceTask):
        """执行推理: run the model on the task's batch and record timing."""
        batch = task.batch

        # batch 格式: (N, C, H, W), 值范围 0-1
        # Ultralytics expects HWC uint8 frames, so convert per batch.
        if batch.ndim == 4:
            # 转换 NCHW -> NHWC 并缩放到 0-255.
            # FIX: round and clip instead of truncating — a bare
            # astype(np.uint8) maps e.g. 0.999 * 255 = 254.7 down to 254
            # (systematic darkening), and wraps out-of-range values.
            batch_nhwc = np.transpose(batch, (0, 2, 3, 1))
            batch_uint8 = np.clip(np.rint(batch_nhwc * 255), 0, 255).astype(np.uint8)

            # 逐帧推理
            task.result = [
                self.model.predict(batch_uint8[i], verbose=False)
                for i in range(batch_uint8.shape[0])
            ]
        else:
            # Non-batched input is handed to the model as-is.
            task.result = self.model.predict(batch, verbose=False)

        task.end_time = time.perf_counter()
        task.completed = True

    def get_results(self, task_id: int, timeout: float = 10.0) -> Tuple[Any, float]:
        """获取推理结果.

        Args:
            task_id: Id returned by :meth:`infer_async`.
            timeout: Seconds to wait for completion.

        Returns:
            ``(result, latency_ms)`` — model output and wall-clock latency.

        Raises:
            ValueError: Unknown task id.
            TimeoutError: Task did not complete within ``timeout`` seconds.
        """
        with self._task_lock:
            if task_id not in self._tasks:
                raise ValueError(f"无效的任务 ID: {task_id}")
            task = self._tasks[task_id]

        start_wait = time.time()
        while not task.completed:
            if time.time() - start_wait > timeout:
                # FIX: drop the entry so a timed-out task (and its batch
                # array) does not leak in _tasks forever.
                with self._task_lock:
                    self._tasks.pop(task_id, None)
                raise TimeoutError(f"任务 {task_id} 超时")
            time.sleep(0.001)

        latency_ms = (task.end_time - task.start_time) * 1000
        result = task.result

        with self._task_lock:
            self._tasks.pop(task_id, None)

        return result, latency_ms

    def infer_sync(self, batch: np.ndarray) -> Tuple[Any, float]:
        """同步推理: submit a batch and wait for its result."""
        task_id = self.infer_async(batch)
        return self.get_results(task_id)

    def synchronize(self, stream_id: Optional[int] = None):
        """同步 (兼容接口) — no-op; inference already runs synchronously."""
        pass

    def cleanup(self):
        """清理资源,释放显存"""
        if self.model is not None:
            del self.model
            self.model = None
        with self._task_lock:
            self._tasks.clear()
        clear_gpu_memory()
        logger.info("推理引擎资源已释放")
|