""" TensorRT 推理引擎模块 """ import time import threading import gc from typing import Tuple, Optional, Dict, Any from dataclasses import dataclass import numpy as np from .utils import setup_logging logger = setup_logging() def clear_gpu_memory(): """强制清理 GPU 显存""" import torch gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() @dataclass class InferenceTask: """推理任务""" task_id: int batch: np.ndarray stream_id: int start_time: float end_time: float = 0.0 result: Optional[Any] = None completed: bool = False class TRTInferenceEngine: """TensorRT 推理引擎 - 使用 Ultralytics 封装""" def __init__(self, engine_path: str, num_streams: int = 2, device_id: int = 0): self.engine_path = engine_path self.num_streams = num_streams self.device_id = device_id self._task_counter = 0 self._task_lock = threading.Lock() self._tasks: Dict[int, InferenceTask] = {} self.model = None self._init_engine() def _init_engine(self): """初始化推理引擎 - 使用 Ultralytics YOLO 加载 TensorRT engine""" import torch from ultralytics import YOLO logger.info(f"加载 TensorRT Engine: {self.engine_path}") # Ultralytics 会自动处理其封装的 .engine 文件 self.model = YOLO(self.engine_path, task='detect') # 预热 logger.info("预热推理引擎...") dummy_input = np.random.randint(0, 255, (320, 320, 3), dtype=np.uint8) _ = self.model.predict(dummy_input, verbose=False) logger.info(f"TensorRT 引擎初始化完成 (Ultralytics 封装)") def infer_async(self, batch: np.ndarray, stream_id: int = 0) -> int: """异步推理""" with self._task_lock: task_id = self._task_counter self._task_counter += 1 task = InferenceTask( task_id=task_id, batch=batch, stream_id=stream_id % self.num_streams, start_time=time.perf_counter() ) self._tasks[task_id] = task self._infer(task) return task_id def _infer(self, task: InferenceTask): """执行推理""" batch = task.batch # batch 格式: (N, C, H, W), 值范围 0-1 # 需要转换为 Ultralytics 期望的格式 if batch.ndim == 4: # 转换 NCHW -> NHWC 并缩放到 0-255 batch_nhwc = np.transpose(batch, (0, 2, 3, 1)) batch_uint8 = (batch_nhwc * 255).astype(np.uint8) # 逐帧推理 results = [] for i in range(batch_uint8.shape[0]): frame = batch_uint8[i] result = self.model.predict(frame, verbose=False) results.append(result) task.result = results else: task.result = self.model.predict(batch, verbose=False) task.end_time = time.perf_counter() task.completed = True def get_results(self, task_id: int, timeout: float = 10.0) -> Tuple[Any, float]: """获取推理结果""" if task_id not in self._tasks: raise ValueError(f"无效的任务 ID: {task_id}") task = self._tasks[task_id] start_wait = time.time() while not task.completed: if time.time() - start_wait > timeout: raise TimeoutError(f"任务 {task_id} 超时") time.sleep(0.001) latency_ms = (task.end_time - task.start_time) * 1000 result = task.result del self._tasks[task_id] return result, latency_ms def infer_sync(self, batch: np.ndarray) -> Tuple[Any, float]: """同步推理""" task_id = self.infer_async(batch) return self.get_results(task_id) def synchronize(self, stream_id: Optional[int] = None): """同步 (兼容接口)""" pass def cleanup(self): """清理资源,释放显存""" if self.model is not None: del self.model self.model = None self._tasks.clear() clear_gpu_memory() logger.info("推理引擎资源已释放")