"""
TensorRT推理引擎模块
固定 batch=4, FP16, 3×480×480
工业级实现Buffer Pool异步推理性能监控
2026-01-29 18:33:12 +08:00
"""
import logging
import threading
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np

try:
    import tensorrt as trt
    import pycuda.driver as cuda
    import pycuda.autoinit  # noqa: F401  (imported for its side effect: initializes the CUDA driver)

    TRT_AVAILABLE = True
except ImportError:
    TRT_AVAILABLE = False
    trt = None
    cuda = None
from config.settings import get_settings, InferenceConfig
from utils.logger import get_logger
logger = logging.getLogger(__name__)


class HostDeviceMem:
    """Host/device memory pair (building block of the production buffer pool)."""

    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __repr__(self):
        return f"Host:{self.host.shape}, Device:{int(self.device)}"


def pad_to_batch4(frames: List[np.ndarray]) -> np.ndarray:
    """
    Stack frames into an [N, 3, 480, 480] batch.

    The old behaviour of repeating the last frame to pad up to batch=4 is
    deprecated; the engine now runs with batch=1.

    Args:
        frames: list of [3, 480, 480] numpy arrays

    Returns:
        np.ndarray: [N, 3, 480, 480]
    """
    if len(frames) == 0:
        raise ValueError("Empty frames list")
    return np.stack(frames)
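
# Example (sketch, not part of the production path): stacking preprocessed
# frames into a batch. With the engine fixed at batch=1 you would pass a
# single frame:
#   frames = [np.zeros((3, 480, 480), dtype=np.float16)]
#   batch = pad_to_batch4(frames)  # batch.shape == (1, 3, 480, 480)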


class TensorRTEngine:
    """TensorRT engine (batch=1, FP16, 3×480×480).

    Features:
    - Buffer pool: bindings are allocated once, during initialization
    - Pinned memory: page-locked host memory speeds up H2D/D2H transfers
    - Async API: CUDA stream + async memcpy + execute_async_v2
    """

    BATCH_SIZE = 1
    INPUT_SHAPE = (3, 480, 480)

    def __init__(self, config: Optional[InferenceConfig] = None):
        if not TRT_AVAILABLE:
            raise RuntimeError("TensorRT is not installed; install the tensorrt package first")

        if config is None:
            settings = get_settings()
            config = settings.inference
        self.config = config

        self._engine = None
        self._context = None
        self._stream = None
        self._released = False
        self._cuda_context = None

        self._logger = get_logger("tensorrt")
        self._lock = threading.Lock()
        self._bindings: List[int] = []
        self._inputs: List[HostDeviceMem] = []
        self._outputs: List[HostDeviceMem] = []

        self._performance_stats = {
            "inference_count": 0,
            "total_inference_time_ms": 0.0,
            "avg_inference_time_ms": 0.0,
            "throughput_fps": 0.0,
            "last_inference_time_ms": 0.0,
        }
        self._logger.info(
            f"TensorRT engine initialized: "
            f"{config.model_path}, "
            f"batch={self.BATCH_SIZE}, "
            f"shape={self.INPUT_SHAPE}, "
            f"fp16={config.fp16_mode}"
        )

    def load_engine(self, engine_path: Optional[str] = None) -> bool:
        """Load a serialized TensorRT engine file."""
        if engine_path is None:
            engine_path = self.config.model_path
        with self._lock:
            try:
                if self._context is not None:
                    self._release_resources()
                self._cuda_context = cuda.Device(0).make_context()
                self._stream = cuda.Stream()

                TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
                with open(engine_path, "rb") as f:
                    runtime = trt.Runtime(TRT_LOGGER)
                    self._engine = runtime.deserialize_cuda_engine(f.read())
                self._context = self._engine.create_execution_context()
                self._allocate_buffers()

                self._logger.log_connection_event(
                    "load", "TensorRT", engine_path, True
                )
                self._logger.info(f"TensorRT engine loaded: {engine_path}")
                self._logger.info(
                    f"  inputs: {len(self._inputs)}, outputs: {len(self._outputs)}, batch={self.BATCH_SIZE}"
                )
                return True
            except Exception as e:
                self._logger.error(f"Failed to load TensorRT engine: {e}")
                return False

    def _allocate_buffers(self):
        """Buffer pool: allocate every binding once, during initialization.

        For dynamic-shape engines the leading (batch) dimension is pinned to
        BATCH_SIZE when sizing the buffers.
        """
        self._bindings = []
        self._inputs = []
        self._outputs = []
        for binding_idx in range(self._engine.num_bindings):
            shape = list(self._engine.get_binding_shape(binding_idx))
            dtype = trt.nptype(self._engine.get_binding_dtype(binding_idx))
            if shape[0] == -1:
                shape[0] = self.BATCH_SIZE
            # Any remaining dynamic dimensions fall back to a safe minimum of 1.
            shape = tuple(max(1, s) if s < 0 else s for s in shape)
            size = trt.volume(shape)
            try:
                host_mem = cuda.pagelocked_empty(size, dtype)
            except Exception as e:
                self._logger.warning(f"pagelocked memory allocation failed, falling back to plain numpy: {e}")
                host_mem = np.zeros(size, dtype=dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            self._bindings.append(int(device_mem))
            mem_pair = HostDeviceMem(host_mem, device_mem)
            if self._engine.binding_is_input(binding_idx):
                self._inputs.append(mem_pair)
            else:
                self._outputs.append(mem_pair)
        if len(self._inputs) == 0:
            raise RuntimeError("No input bindings found")
        if len(self._outputs) == 0:
            raise RuntimeError("No output bindings found")

    def infer(self, input_batch: np.ndarray) -> Tuple[List[np.ndarray], float]:
        """
        Run inference (production async path).

        Args:
            input_batch: numpy input, shape = [batch, 3, 480, 480], dtype = np.float16

        Returns:
            tuple: (list of output arrays, inference time in ms)
        """
        if self._engine is None or self._context is None:
            raise RuntimeError("Engine not loaded")
        if len(self._inputs) == 0:
            raise RuntimeError("Input buffers not allocated")

        batch_size = input_batch.shape[0]
        start_time = time.perf_counter()

        self._cuda_context.push()
        try:
            input_batch = np.ascontiguousarray(input_batch)
            input_name = self._engine.get_binding_name(0)
            actual_shape = list(input_batch.shape)
            self._context.set_input_shape(input_name, actual_shape)
            np.copyto(self._inputs[0].host, input_batch.ravel())
            cuda.memcpy_htod_async(
                self._inputs[0].device,
                self._inputs[0].host,
                self._stream
            )
            self._context.execute_async_v2(
                bindings=self._bindings,
                stream_handle=self._stream.handle
            )
            for out in self._outputs:
                cuda.memcpy_dtoh_async(
                    out.host,
                    out.device,
                    self._stream
                )
            # Wait for the async D2H copies to land before reading the host
            # buffers; copying before synchronize() would race the stream.
            self._stream.synchronize()
            results = [out.host.copy() for out in self._outputs]
            inference_time_ms = (time.perf_counter() - start_time) * 1000
            self._update_performance_stats(inference_time_ms, batch_size)
            return results, inference_time_ms
        finally:
            self._cuda_context.pop()

    def _update_performance_stats(self, inference_time_ms: float, batch_size: int):
        """Update performance statistics."""
        stats = self._performance_stats
        stats["inference_count"] += 1
        stats["total_inference_time_ms"] += inference_time_ms
        stats["last_inference_time_ms"] = inference_time_ms
        stats["avg_inference_time_ms"] = (
            stats["total_inference_time_ms"] / stats["inference_count"]
        )
        total_time = stats["total_inference_time_ms"] / 1000.0
        if total_time > 0:
            stats["throughput_fps"] = stats["inference_count"] / total_time

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return performance statistics, including current GPU memory usage."""
        stats = dict(self._performance_stats)
        stats["memory_usage"] = self.get_memory_usage()
        return stats

    def get_memory_usage(self) -> Dict[str, float]:
        """Return GPU memory usage in MiB."""
        try:
            import pynvml
            pynvml.nvmlInit()
            handle = pynvml.nvmlDeviceGetHandleByIndex(self.config.device_id)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            pynvml.nvmlShutdown()
            return {
                "total_mb": mem_info.total / (1024 ** 2),
                "used_mb": mem_info.used / (1024 ** 2),
                "free_mb": mem_info.free / (1024 ** 2),
            }
        except Exception as e:
            self._logger.warning(f"Failed to query GPU memory info: {e}")
            return {"total_mb": 0, "used_mb": 0, "free_mb": 0}

    def _release_resources(self):
        """Release CUDA/TensorRT resources."""
        if self._cuda_context:
            try:
                self._cuda_context.pop()
                self._cuda_context.detach()
            except Exception:
                pass
            self._cuda_context = None

        if self._stream:
            try:
                self._stream.synchronize()
            except Exception:
                pass
            self._stream = None

        self._context = None
        self._engine = None
        self._bindings = []
        self._inputs = []
        self._outputs = []

    def release(self):
        """Release engine resources (idempotent; safe to call more than once)."""
        with self._lock:
            if self._released:
                return
            self._released = True

            self._release_resources()
            self._logger.info("TensorRT engine resources released")

    def __del__(self):
        """Destructor."""
        self.release()


class EngineManager:
    """Manages a collection of named TensorRT engines."""

    def __init__(self):
        self._engines: Dict[str, TensorRTEngine] = {}
        self._logger = get_logger("tensorrt")
        self._lock = threading.Lock()

    def load_engine(
        self,
        engine_id: str,
        engine_path: str,
        config: Optional[InferenceConfig] = None
    ) -> bool:
        """Load an engine under the given id, replacing any existing one."""
        with self._lock:
            if engine_id in self._engines:
                self._engines[engine_id].release()
            engine = TensorRTEngine(config)
            if engine.load_engine(engine_path):
                self._engines[engine_id] = engine
                self._logger.info(f"Engine loaded: {engine_id}")
                return True
            else:
                return False

    def get_engine(self, engine_id: str) -> Optional[TensorRTEngine]:
        """Look up an engine by id."""
        return self._engines.get(engine_id)

    def release_engine(self, engine_id: str):
        """Release a single engine."""
        with self._lock:
            if engine_id in self._engines:
                self._engines[engine_id].release()
                del self._engines[engine_id]
                self._logger.info(f"Engine released: {engine_id}")

    def release_all(self):
        """Release all engines."""
        with self._lock:
            for engine_id in list(self._engines.keys()):
                self._engines[engine_id].release()
            self._engines.clear()
            self._logger.info("All engines released")

    def get_all_stats(self) -> Dict[str, Any]:
        """Collect performance statistics for every engine."""
        with self._lock:
            return {
                engine_id: engine.get_performance_stats()
                for engine_id, engine in self._engines.items()
            }
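
# Example (sketch): managing several engines through one EngineManager.
# The engine id and file path below are hypothetical placeholders.
#   manager = EngineManager()
#   if manager.load_engine("detector", "detector_fp16.engine"):
#       engine = manager.get_engine("detector")
#       outputs, elapsed_ms = engine.infer(batch)
#   manager.release_all()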


def create_tensorrt_engine(
    engine_path: str,
    config: Optional[InferenceConfig] = None
) -> TensorRTEngine:
    """Convenience function: create and load a TensorRT engine."""
    engine = TensorRTEngine(config)
    if engine.load_engine(engine_path):
        return engine
    else:
        raise RuntimeError(f"Failed to load TensorRT engine: {engine_path}")
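

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the production path. It assumes a
    # serialized engine exists at the hypothetical path "model_fp16.engine"
    # (or the path given as the first CLI argument) and that config.settings
    # and utils.logger are importable.
    import sys

    engine_file = sys.argv[1] if len(sys.argv) > 1 else "model_fp16.engine"
    engine = create_tensorrt_engine(engine_file)
    try:
        # Random FP16 input matching the fixed batch=1, 3x480x480 contract.
        dummy = np.random.rand(1, *TensorRTEngine.INPUT_SHAPE).astype(np.float16)
        outputs, elapsed_ms = engine.infer(dummy)
        print(f"outputs: {len(outputs)}, time: {elapsed_ms:.2f} ms")
        print(engine.get_performance_stats())
    finally:
        engine.release()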