Files
security-ai-edge/core/tensorrt_engine.py
16337 b0ddb6ee1a feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from
  edge_inference_service/ to the repository root ai_edge/
- Updated model path in config/settings.py to reflect new structure
- Revised usage paths in __init__.py documentation
2026-01-29 18:43:19 +08:00

449 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
TensorRT推理引擎模块
实现引擎加载、显存优化、异步推理、性能监控
"""
import logging
import threading
import time
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
try:
import tensorrt as trt
TRT_AVAILABLE = True
except ImportError:
TRT_AVAILABLE = False
trt = None
from config.settings import get_settings, InferenceConfig
from utils.logger import get_logger
logger = logging.getLogger(__name__)
class TensorRTEngine:
    """TensorRT engine wrapper.

    Loads a serialized engine (plan) file, discovers I/O bindings,
    pre-allocates host-side buffers, runs inference and keeps running
    performance statistics.  Only load/release are guarded by
    ``self._lock``; ``infer`` itself is not lock-protected.
    """

    def __init__(self, config: Optional[InferenceConfig] = None):
        """Initialize the engine wrapper.

        Args:
            config: Inference configuration; when None, the global
                application settings (``get_settings().inference``) are used.

        Raises:
            RuntimeError: If the ``tensorrt`` package is not installed.
        """
        if not TRT_AVAILABLE:
            raise RuntimeError("TensorRT未安装请先安装tensorrt库")
        if config is None:
            settings = get_settings()
            config = settings.inference
        self.config = config
        self._engine = None          # deserialized ICudaEngine
        self._context = None         # IExecutionContext
        self._input_binding: Optional[Dict[str, Any]] = None
        self._output_bindings: List[Dict[str, Any]] = []
        self._stream = None          # reserved for an async CUDA stream; unused here
        self._logger = get_logger("tensorrt")
        self._lock = threading.Lock()
        # Host-side buffers pre-allocated per binding name.
        self._memory_pool: Dict[str, np.ndarray] = {}
        self._performance_stats = {
            "inference_count": 0,
            # Total frames processed (batch sizes summed across calls);
            # lets throughput_fps report frames rather than calls.
            "total_frames": 0,
            "total_inference_time_ms": 0.0,
            "avg_inference_time_ms": 0.0,
            "throughput_fps": 0.0,
            "last_inference_time_ms": 0.0,
        }
        self._logger.info(
            f"TensorRT引擎初始化配置: "
            f"模型={config.model_path}, "
            f"输入尺寸={config.input_width}x{config.input_height}, "
            f"Batch={config.batch_size}, "
            f"FP16={config.fp16_mode}"
        )

    def load_engine(self, engine_path: Optional[str] = None) -> bool:
        """Deserialize and activate a TensorRT engine file.

        Args:
            engine_path: Path to the serialized engine; defaults to
                ``self.config.model_path``.

        Returns:
            True if the engine was loaded and buffers allocated, else False.
        """
        if engine_path is None:
            engine_path = self.config.model_path
        with self._lock:
            try:
                # Reloading: drop any previously held engine/context first.
                if self._context is not None:
                    self._release_resources()
                trt_logger = trt.Logger(trt.Logger.WARNING)
                runtime = trt.Runtime(trt_logger)
                with open(engine_path, "rb") as f:
                    self._engine = runtime.deserialize_cuda_engine(f.read())
                # deserialize_cuda_engine returns None (without raising) on a
                # corrupt or incompatible plan file -- fail loudly instead of
                # deferring to an obscure AttributeError below.
                if self._engine is None:
                    raise RuntimeError("deserialize_cuda_engine returned None")
                self._context = self._engine.create_execution_context()
                self._setup_bindings()
                self._allocate_memory_pool()
                self._logger.log_connection_event(
                    "load", "TensorRT", engine_path, True
                )
                self._logger.info(f"TensorRT引擎加载成功: {engine_path}")
                return True
            except Exception as e:
                self._logger.error(f"TensorRT引擎加载失败: {e}")
                return False

    def _setup_bindings(self):
        """Discover input/output bindings from the engine.

        NOTE(review): uses the pre-TRT-10 binding API (``num_bindings``,
        ``get_binding_name``, ...), which was removed in TensorRT 10 --
        confirm the deployed TensorRT version still provides it.
        """
        self._input_binding = None
        self._output_bindings = []
        for i in range(self._engine.num_bindings):
            info = {
                "name": self._engine.get_binding_name(i),
                "shape": self._engine.get_binding_shape(i),
                "dtype": self._engine.get_binding_dtype(i),
                "index": i,
            }
            if self._engine.binding_is_input(i):
                # Assumes a single input tensor; a second input binding
                # would silently overwrite the first.
                self._input_binding = info
            else:
                self._output_bindings.append(info)
        self._logger.debug(
            f"输入绑定: {self._input_binding}, "
            f"输出绑定: {len(self._output_bindings)}"
        )

    def _allocate_memory_pool(self):
        """Pre-allocate zeroed host buffers for every binding."""
        self._memory_pool.clear()
        if self._input_binding:
            shape = self._input_binding["shape"]
            dtype = self._get_numpy_dtype(self._input_binding["dtype"])
            self._memory_pool["input"] = np.zeros(shape, dtype=dtype)
        for output in self._output_bindings:
            shape = output["shape"]
            dtype = self._get_numpy_dtype(output["dtype"])
            self._memory_pool[output["name"]] = np.zeros(shape, dtype=dtype)

    def _get_numpy_dtype(self, trt_dtype) -> np.dtype:
        """Map a TensorRT dtype to numpy; unknown types fall back to float32."""
        if trt_dtype == trt.float16:
            return np.float16
        elif trt_dtype == trt.float32:
            return np.float32
        elif trt_dtype == trt.int32:
            return np.int32
        elif trt_dtype == trt.int8:
            return np.int8
        else:
            return np.float32

    def _allocate_device_memory(self, batch_size: int) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Allocate per-call host buffers sized for ``batch_size``.

        NOTE(review): despite the name these are host (numpy) buffers, and
        this helper is not called anywhere in this file -- kept for
        compatibility with possible external callers.

        Returns:
            tuple: (input buffer, list of output buffers)
        """
        input_shape = list(self._input_binding["shape"])
        input_shape[0] = batch_size
        input_data = np.zeros(
            input_shape,
            dtype=np.float16 if self.config.fp16_mode else np.float32,
        )
        output_data_list = []
        for output in self._output_bindings:
            output_shape = list(output["shape"])
            output_shape[0] = batch_size
            output_data_list.append(
                np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
            )
        return input_data, output_data_list

    def set_input_shape(self, batch_size: int, height: int, width: int):
        """Set the dynamic input shape on the execution context.

        Assumes a 3-channel (NCHW) input tensor.

        Args:
            batch_size: Batch dimension.
            height: Input height in pixels.
            width: Input width in pixels.

        Raises:
            RuntimeError: If no engine has been loaded.
        """
        if self._context is None:
            raise RuntimeError("引擎未加载")
        self._context.set_input_shape(
            self._input_binding["name"],
            [batch_size, 3, height, width]
        )
        self._logger.debug(f"输入形状已设置为: [{batch_size}, 3, {height}, {width}]")

    def infer(
        self,
        input_data: np.ndarray,
        async_mode: bool = False
    ) -> Tuple[List[np.ndarray], float]:
        """Run one inference call.

        Args:
            input_data: Input batch in NCHW layout.
            async_mode: Accepted for API compatibility; execution is
                currently always synchronous (``execute_v2``).

        Returns:
            tuple: (list of output arrays, elapsed time in ms)

        Raises:
            RuntimeError: If no engine has been loaded.
        """
        if self._engine is None or self._context is None:
            raise RuntimeError("引擎未加载")
        start_time = time.perf_counter()
        batch_size = input_data.shape[0]
        # Cast to the precision the engine was built with.
        input_data = input_data.astype(
            np.float16 if self.config.fp16_mode else np.float32
        )
        self._context.set_input_shape(
            self._input_binding["name"],
            input_data.shape
        )
        output_tensors = []
        for output in self._output_bindings:
            output_shape = list(output["shape"])
            output_shape[0] = batch_size
            output_tensors.append(
                np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
            )
        # NOTE(review): execute_v2 expects a list of device-memory pointers;
        # passing numpy arrays directly is unlikely to work against real
        # TensorRT -- confirm how device buffers are wired up in deployment.
        bindings = [input_data] + output_tensors
        self._context.execute_v2(bindings=bindings)
        inference_time_ms = (time.perf_counter() - start_time) * 1000
        self._update_performance_stats(inference_time_ms, batch_size)
        return output_tensors, inference_time_ms

    def infer_async(self, input_data: np.ndarray) -> Tuple[List[np.ndarray], float]:
        """Asynchronous-inference entry point.

        Currently delegates to the synchronous ``infer`` (the async flag
        is not yet acted upon).

        Args:
            input_data: Input batch in NCHW layout.

        Returns:
            tuple: (list of output arrays, elapsed time in ms)
        """
        return self.infer(input_data, async_mode=True)

    def infer_batch(
        self,
        batch_data: np.ndarray,
        batch_size: int
    ) -> Tuple[List[np.ndarray], float]:
        """Run inference on (a prefix of) a batch.

        Args:
            batch_data: Batched input data.
            batch_size: Number of valid frames; extra rows are trimmed.

        Returns:
            tuple: (list of output arrays, elapsed time in ms)
        """
        if batch_data.shape[0] != batch_size:
            batch_data = batch_data[:batch_size]
        return self.infer(batch_data)

    def _update_performance_stats(self, inference_time_ms: float, batch_size: int):
        """Update running performance counters after one inference call.

        Args:
            inference_time_ms: Wall-clock duration of the call in ms.
            batch_size: Number of frames processed by the call.
        """
        stats = self._performance_stats
        stats["inference_count"] += 1
        # Fix: batch_size was previously ignored, so throughput_fps
        # under-reported by the batch factor for batched inference.
        stats["total_frames"] = stats.get("total_frames", 0) + batch_size
        stats["total_inference_time_ms"] += inference_time_ms
        stats["last_inference_time_ms"] = inference_time_ms
        stats["avg_inference_time_ms"] = (
            stats["total_inference_time_ms"] / stats["inference_count"]
        )
        total_time_s = stats["total_inference_time_ms"] / 1000.0
        if total_time_s > 0:
            stats["throughput_fps"] = stats["total_frames"] / total_time_s

    def get_performance_stats(self) -> Dict[str, Any]:
        """Return a copy of the performance counters plus memory usage."""
        stats = dict(self._performance_stats)
        stats["memory_usage"] = self.get_memory_usage()
        return stats

    def get_memory_usage(self) -> Dict[str, float]:
        """Query device memory usage via NVML.

        Returns:
            Dict with ``total_mb``/``used_mb``/``free_mb``; all zeros if
            NVML is unavailable or the query fails.
        """
        try:
            import pynvml
            pynvml.nvmlInit()
            try:
                handle = pynvml.nvmlDeviceGetHandleByIndex(self.config.device_id)
                mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            finally:
                # Always shut NVML down once initialized; the original
                # leaked the NVML session when the query itself raised.
                pynvml.nvmlShutdown()
            return {
                "total_mb": mem_info.total / (1024 ** 2),
                "used_mb": mem_info.used / (1024 ** 2),
                "free_mb": mem_info.free / (1024 ** 2),
            }
        except Exception as e:
            self._logger.warning(f"获取显存信息失败: {e}")
            return {"total_mb": 0, "used_mb": 0, "free_mb": 0}

    def _release_resources(self):
        """Release engine/context/stream; caller must hold ``self._lock``."""
        if self._stream:
            try:
                self._stream.synchronize()
            except Exception:
                pass
        # NOTE(review): destroy() is deprecated/removed in newer TensorRT
        # releases (objects are freed via `del`); this matches the legacy
        # binding API used above -- confirm against the deployed version.
        if self._context:
            self._context.destroy()
            self._context = None
        if self._engine:
            self._engine.destroy()
            self._engine = None
        self._memory_pool.clear()

    def release(self):
        """Release all engine resources (thread-safe)."""
        with self._lock:
            self._release_resources()
        self._logger.info("TensorRT引擎资源已释放")

    def __del__(self):
        """Best-effort cleanup; a finalizer must never raise."""
        try:
            self.release()
        except Exception:
            # __init__ may have failed before _lock/_logger existed, and
            # logging may already be torn down at interpreter shutdown.
            pass
class EngineManager:
    """Registry managing multiple named TensorRTEngine instances."""

    def __init__(self):
        self._engines: Dict[str, TensorRTEngine] = {}
        self._logger = get_logger("tensorrt")
        self._lock = threading.Lock()

    def load_engine(
        self,
        engine_id: str,
        engine_path: str,
        config: Optional[InferenceConfig] = None
    ) -> bool:
        """Load an engine file and register it under ``engine_id``.

        Args:
            engine_id: Registry key for the engine.
            engine_path: Path to the serialized engine file.
            config: Optional inference configuration.

        Returns:
            True if the engine loaded and was registered.
        """
        with self._lock:
            # Release AND remove any previous engine under this id.  The
            # original released it but left it registered, so a failed
            # reload left a dead engine reachable via get_engine().
            old = self._engines.pop(engine_id, None)
            if old is not None:
                old.release()
            engine = TensorRTEngine(config)
            if engine.load_engine(engine_path):
                self._engines[engine_id] = engine
                self._logger.info(f"引擎已加载: {engine_id}")
                return True
            return False

    def get_engine(self, engine_id: str) -> Optional[TensorRTEngine]:
        """Return the engine registered under ``engine_id``, or None."""
        return self._engines.get(engine_id)

    def release_engine(self, engine_id: str):
        """Release and unregister one engine; no-op if unknown."""
        with self._lock:
            if engine_id in self._engines:
                self._engines[engine_id].release()
                del self._engines[engine_id]
                self._logger.info(f"引擎已释放: {engine_id}")

    def release_all(self):
        """Release and unregister every engine."""
        with self._lock:
            for engine_id in list(self._engines.keys()):
                self._engines[engine_id].release()
            self._engines.clear()
            self._logger.info("所有引擎已释放")

    def get_all_stats(self) -> Dict[str, Any]:
        """Return performance stats keyed by engine id."""
        with self._lock:
            return {
                engine_id: engine.get_performance_stats()
                for engine_id, engine in self._engines.items()
            }
def create_tensorrt_engine(
    engine_path: str,
    config: Optional[InferenceConfig] = None
) -> TensorRTEngine:
    """Convenience factory: construct a TensorRTEngine and load an engine file.

    Args:
        engine_path: Path to the serialized engine file.
        config: Optional inference configuration.

    Returns:
        A loaded TensorRTEngine instance.

    Raises:
        RuntimeError: If the engine could not be loaded.
    """
    engine = TensorRTEngine(config)
    if engine.load_engine(engine_path):
        return engine
    raise RuntimeError(f"无法加载TensorRT引擎: {engine_path}")