Files
Test_AI/benchmark/inference_engine.py
2026-01-20 10:54:30 +08:00

151 lines
4.4 KiB
Python

"""
TensorRT 推理引擎模块
"""
import time
import threading
import gc
from typing import Tuple, Optional, Dict, Any
from dataclasses import dataclass
import numpy as np
from .utils import setup_logging
logger = setup_logging()
def clear_gpu_memory():
    """Force-release GPU memory: collect Python garbage, then drop CUDA caches.

    Safe to call on CPU-only machines — the CUDA steps are skipped when no
    device is available.
    """
    import torch

    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
@dataclass
class InferenceTask:
    """A single inference request, tracked from submission to completion."""
    task_id: int  # monotonically increasing id assigned by the engine
    batch: np.ndarray  # input batch; assumed (N, C, H, W) floats in [0, 1] per the engine's conversion code — confirm with callers
    stream_id: int  # logical stream bucket (submitted stream_id % num_streams)
    start_time: float  # time.perf_counter() timestamp taken at submission
    end_time: float = 0.0  # time.perf_counter() timestamp set when inference finishes
    result: Optional[Any] = None  # model output(s); filled in by the engine
    completed: bool = False  # True once result and end_time are valid
class TRTInferenceEngine:
    """TensorRT inference engine, wrapped via Ultralytics YOLO.

    Despite the async-style API (`infer_async` / `get_results`), inference
    currently runs synchronously in the caller's thread; the task bookkeeping
    keeps the interface compatible with a genuinely asynchronous backend.
    """

    def __init__(self, engine_path: str, num_streams: int = 2, device_id: int = 0):
        """Load a TensorRT engine file and warm it up.

        Args:
            engine_path: Path to an Ultralytics-exported ``.engine`` file.
            num_streams: Logical stream count; submitted stream ids are
                bucketed modulo this value.
            device_id: GPU device index (stored for interface compatibility;
                not used directly in this class).
        """
        self.engine_path = engine_path
        self.num_streams = num_streams
        self.device_id = device_id
        self._task_counter = 0
        self._task_lock = threading.Lock()  # guards _task_counter and _tasks
        self._tasks: Dict[int, InferenceTask] = {}
        self.model = None
        self._init_engine()

    def _init_engine(self):
        """Load the TensorRT engine through Ultralytics YOLO and warm it up."""
        import torch  # noqa: F401 — ensures the CUDA runtime is initialized
        from ultralytics import YOLO

        logger.info(f"加载 TensorRT Engine: {self.engine_path}")
        # Ultralytics transparently handles its own exported .engine files.
        self.model = YOLO(self.engine_path, task='detect')
        # Warm-up: the first predict() call pays one-time initialization cost.
        logger.info("预热推理引擎...")
        dummy_input = np.random.randint(0, 255, (320, 320, 3), dtype=np.uint8)
        _ = self.model.predict(dummy_input, verbose=False)
        logger.info("TensorRT 引擎初始化完成 (Ultralytics 封装)")

    def infer_async(self, batch: np.ndarray, stream_id: int = 0) -> int:
        """Submit a batch for inference and return a task id.

        The returned id is consumed by `get_results`. Inference itself runs
        synchronously before this method returns (see class docstring).

        Args:
            batch: Input array; (N, C, H, W) floats in [0, 1], or any other
                array accepted directly by the underlying model.
            stream_id: Logical stream hint, bucketed modulo `num_streams`.
        """
        with self._task_lock:
            task_id = self._task_counter
            self._task_counter += 1
            task = InferenceTask(
                task_id=task_id,
                batch=batch,
                stream_id=stream_id % self.num_streams,
                start_time=time.perf_counter(),
            )
            self._tasks[task_id] = task
        # Run outside the lock so bookkeeping never blocks on inference.
        self._infer(task)
        return task_id

    def _infer(self, task: InferenceTask):
        """Run inference for one task, then mark it completed."""
        batch = task.batch
        # batch format: (N, C, H, W) with values in [0, 1]; Ultralytics
        # expects HWC uint8 frames, so convert layout and rescale.
        if batch.ndim == 4:
            batch_nhwc = np.transpose(batch, (0, 2, 3, 1))
            batch_uint8 = (batch_nhwc * 255).astype(np.uint8)
            # Frame-by-frame inference; predict() is given one image at a time.
            results = []
            for i in range(batch_uint8.shape[0]):
                frame = batch_uint8[i]
                result = self.model.predict(frame, verbose=False)
                results.append(result)
            task.result = results
        else:
            task.result = self.model.predict(batch, verbose=False)
        task.end_time = time.perf_counter()
        task.completed = True

    def get_results(self, task_id: int, timeout: float = 10.0) -> Tuple[Any, float]:
        """Wait for a task to complete and return ``(result, latency_ms)``.

        Args:
            task_id: Id returned by `infer_async`.
            timeout: Maximum seconds to wait for completion.

        Raises:
            ValueError: Unknown task id.
            TimeoutError: Task did not complete within `timeout` seconds.
                The task is discarded so it does not leak.
        """
        with self._task_lock:
            task = self._tasks.get(task_id)
        if task is None:
            raise ValueError(f"无效的任务 ID: {task_id}")
        start_wait = time.time()
        while not task.completed:
            if time.time() - start_wait > timeout:
                # Drop the stale task: the original code left it in _tasks
                # forever, leaking one entry per timeout.
                with self._task_lock:
                    self._tasks.pop(task_id, None)
                raise TimeoutError(f"任务 {task_id} 超时")
            time.sleep(0.001)
        latency_ms = (task.end_time - task.start_time) * 1000
        result = task.result
        with self._task_lock:
            self._tasks.pop(task_id, None)
        return result, latency_ms

    def infer_sync(self, batch: np.ndarray) -> Tuple[Any, float]:
        """Synchronous inference: submit a batch and wait for its result."""
        task_id = self.infer_async(batch)
        return self.get_results(task_id)

    def synchronize(self, stream_id: Optional[int] = None):
        """Stream synchronization (compatibility no-op; inference is synchronous)."""
        pass

    def cleanup(self):
        """Release the model reference and force GPU memory cleanup."""
        if self.model is not None:
            del self.model
            self.model = None
        self._tasks.clear()
        clear_gpu_memory()
        logger.info("推理引擎资源已释放")