feat: TensorRT 工业级重构
- 添加 HostDeviceMem 类(Buffer Pool)
- _allocate_buffers():在 init 阶段一次性分配
- infer() 使用 async API + CUDA stream
- 回退机制:pagelocked 失败时用普通 numpy
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""
|
||||
TensorRT推理引擎模块
|
||||
实现引擎加载、显存优化、异步推理、性能监控
|
||||
工业级实现:Buffer Pool、异步推理、性能监控
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
@@ -27,21 +27,29 @@ from utils.logger import get_logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TensorRTEngine:
|
||||
"""TensorRT引擎管理类
|
||||
class HostDeviceMem:
    """Pair a host-side buffer with its device-side allocation.

    ``host`` holds the CPU-side array and ``device`` holds the matching
    device allocation; both are stored exactly as passed in.
    """

    def __init__(self, host_mem, device_mem):
        """Keep references to both halves of the pair without copying them."""
        self.host, self.device = host_mem, device_mem

    def __repr__(self):
        """Show the host buffer's shape and the device handle as an integer."""
        host_shape = self.host.shape
        device_handle = int(self.device)
        return f"Host:{host_shape}, Device:{device_handle}"
|
||||
|
||||
|
||||
class TensorRTEngine:
|
||||
"""工业级 TensorRT 引擎
|
||||
|
||||
特性:
|
||||
- Buffer Pool: bindings 只在 init 阶段分配一次
|
||||
- Pinned Memory: 使用 pagelocked host memory 提升 H2D/D2H 性能
|
||||
- Async API: CUDA stream + async memcpy + execute_async_v2
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[InferenceConfig] = None):
|
||||
"""
|
||||
初始化TensorRT引擎
|
||||
|
||||
Args:
|
||||
config: 推理配置
|
||||
"""
|
||||
if not TRT_AVAILABLE:
|
||||
raise RuntimeError("TensorRT未安装,请先安装tensorrt库")
|
||||
raise RuntimeError("TensorRT 未安装,请先安装 tensorrt 库")
|
||||
|
||||
if config is None:
|
||||
settings = get_settings()
|
||||
@@ -50,8 +58,6 @@ class TensorRTEngine:
|
||||
self.config = config
|
||||
self._engine = None
|
||||
self._context = None
|
||||
self._input_binding = None
|
||||
self._output_bindings = []
|
||||
self._stream = None
|
||||
self._released = False
|
||||
self._cuda_context = None
|
||||
@@ -59,7 +65,10 @@ class TensorRTEngine:
|
||||
self._logger = get_logger("tensorrt")
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._memory_pool: Dict[str, np.ndarray] = {}
|
||||
self._bindings: List[int] = []
|
||||
self._inputs: List[HostDeviceMem] = []
|
||||
self._outputs: List[HostDeviceMem] = []
|
||||
self._binding_names: Dict[int, str] = {}
|
||||
|
||||
self._performance_stats = {
|
||||
"inference_count": 0,
|
||||
@@ -70,23 +79,15 @@ class TensorRTEngine:
|
||||
}
|
||||
|
||||
self._logger.info(
|
||||
f"TensorRT引擎初始化配置: "
|
||||
f"模型={config.model_path}, "
|
||||
f"输入尺寸={config.input_width}x{config.input_height}, "
|
||||
f"Batch={config.batch_size}, "
|
||||
f"FP16={config.fp16_mode}"
|
||||
f"TensorRT 引擎初始化: "
|
||||
f"{config.model_path}, "
|
||||
f"{config.input_width}x{config.input_height}, "
|
||||
f"batch={config.batch_size}, "
|
||||
f"fp16={config.fp16_mode}"
|
||||
)
|
||||
|
||||
def load_engine(self, engine_path: Optional[str] = None) -> bool:
|
||||
"""
|
||||
加载TensorRT engine文件
|
||||
|
||||
Args:
|
||||
engine_path: engine文件路径
|
||||
|
||||
Returns:
|
||||
是否加载成功
|
||||
"""
|
||||
"""加载 TensorRT engine 文件"""
|
||||
if engine_path is None:
|
||||
engine_path = self.config.model_path
|
||||
|
||||
@@ -95,9 +96,8 @@ class TensorRTEngine:
|
||||
if self._context is not None:
|
||||
self._release_resources()
|
||||
|
||||
if cuda is not None:
|
||||
self._cuda_context = cuda.Device(0).make_context()
|
||||
self._stream = cuda.Stream()
|
||||
self._cuda_context = cuda.Device(0).make_context()
|
||||
self._stream = cuda.Stream()
|
||||
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
@@ -107,131 +107,76 @@ class TensorRTEngine:
|
||||
|
||||
self._context = self._engine.create_execution_context()
|
||||
|
||||
self._setup_bindings()
|
||||
|
||||
self._allocate_memory_pool()
|
||||
self._allocate_buffers()
|
||||
|
||||
self._logger.log_connection_event(
|
||||
"load", "TensorRT", engine_path, True
|
||||
)
|
||||
self._logger.info(f"TensorRT引擎加载成功: {engine_path}")
|
||||
self._logger.info(f"TensorRT 引擎加载成功: {engine_path}")
|
||||
self._logger.info(f" 输入: {len(self._inputs)}, 输出: {len(self._outputs)}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"TensorRT引擎加载失败: {e}")
|
||||
self._logger.error(f"TensorRT 引擎加载失败: {e}")
|
||||
return False
|
||||
|
||||
def _setup_bindings(self):
|
||||
"""设置输入输出绑定"""
|
||||
self._input_binding = None
|
||||
self._output_bindings = []
|
||||
def _allocate_buffers(self):
|
||||
"""Buffer Pool: 初始化阶段一次性分配所有 bindings(工业级关键点)"""
|
||||
self._bindings = []
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._binding_names = {}
|
||||
|
||||
for i in range(self._engine.num_bindings):
|
||||
binding_name = self._engine.get_binding_name(i)
|
||||
binding_shape = self._engine.get_binding_shape(i)
|
||||
binding_dtype = self._engine.get_binding_dtype(i)
|
||||
for binding_idx in range(self._engine.num_bindings):
|
||||
name = self._engine.get_binding_name(binding_idx)
|
||||
dtype = trt.nptype(self._engine.get_binding_dtype(binding_idx))
|
||||
shape = self._engine.get_binding_shape(binding_idx)
|
||||
|
||||
if self._engine.binding_is_input(i):
|
||||
self._input_binding = {
|
||||
"name": binding_name,
|
||||
"shape": binding_shape,
|
||||
"dtype": binding_dtype,
|
||||
"index": i,
|
||||
}
|
||||
self._binding_names[binding_idx] = name
|
||||
|
||||
shape = tuple(max(1, s) if s < 0 else s for s in shape)
|
||||
size = trt.volume(shape)
|
||||
|
||||
try:
|
||||
host_mem = cuda.pagelocked_empty(size, dtype)
|
||||
device_mem = cuda.mem_alloc(host_mem.nbytes)
|
||||
except Exception as e:
|
||||
self._logger.warning(f"pagelocked memory 分配失败,回退到普通 numpy: {e}")
|
||||
host_mem = np.zeros(size, dtype=dtype)
|
||||
device_mem = cuda.mem_alloc(host_mem.nbytes)
|
||||
|
||||
self._bindings.append(int(device_mem))
|
||||
|
||||
mem_pair = HostDeviceMem(host_mem, device_mem)
|
||||
|
||||
if self._engine.binding_is_input(binding_idx):
|
||||
self._inputs.append(mem_pair)
|
||||
else:
|
||||
self._output_bindings.append({
|
||||
"name": binding_name,
|
||||
"shape": binding_shape,
|
||||
"dtype": binding_dtype,
|
||||
"index": i,
|
||||
})
|
||||
self._outputs.append(mem_pair)
|
||||
|
||||
if len(self._inputs) == 0:
|
||||
raise RuntimeError("No input bindings found")
|
||||
if len(self._outputs) == 0:
|
||||
raise RuntimeError("No output bindings found")
|
||||
|
||||
self._logger.debug(
|
||||
f"输入绑定: {self._input_binding}, "
|
||||
f"输出绑定: {len(self._output_bindings)}"
|
||||
f"Buffer Pool 分配完成: "
|
||||
f"inputs={[int(i.device) for i in self._inputs]}, "
|
||||
f"outputs={[int(o.device) for o in self._outputs]}"
|
||||
)
|
||||
|
||||
def _allocate_memory_pool(self):
|
||||
"""分配显存池"""
|
||||
self._memory_pool.clear()
|
||||
|
||||
if self._input_binding:
|
||||
shape = self._input_binding["shape"]
|
||||
shape = tuple(max(1, s) if s < 0 else s for s in shape)
|
||||
dtype = self._get_numpy_dtype(self._input_binding["dtype"])
|
||||
self._memory_pool["input"] = np.zeros(shape, dtype=dtype)
|
||||
|
||||
for output in self._output_bindings:
|
||||
shape = output["shape"]
|
||||
shape = tuple(max(1, s) if s < 0 else s for s in shape)
|
||||
dtype = self._get_numpy_dtype(output["dtype"])
|
||||
self._memory_pool[output["name"]] = np.zeros(shape, dtype=dtype)
|
||||
def _get_output_shape(self, binding_idx: int) -> Tuple[int, ...]:
|
||||
"""获取输出的 shape"""
|
||||
name = self._binding_names[binding_idx]
|
||||
return self._engine.get_binding_shape(name)
|
||||
|
||||
def _get_numpy_dtype(self, trt_dtype) -> np.dtype:
    """Translate a TensorRT data type into the matching NumPy scalar type.

    ``trt.float16``, ``trt.float32``, ``trt.int32`` and ``trt.int8`` map to
    the NumPy types of the same name; any other value falls back to
    ``np.float32``.
    """
    conversions = (
        (trt.float16, np.float16),
        (trt.float32, np.float32),
        (trt.int32, np.int32),
        (trt.int8, np.int8),
    )
    for candidate, numpy_type in conversions:
        if trt_dtype == candidate:
            return numpy_type
    # Unrecognised types are treated as 32-bit floats.
    return np.float32
|
||||
|
||||
def _allocate_device_memory(self, batch_size: int) -> Tuple[np.ndarray, List[np.ndarray]]:
|
||||
def infer(self, input_np: np.ndarray) -> Tuple[List[np.ndarray], float]:
|
||||
"""
|
||||
分配设备显存
|
||||
|
||||
Returns:
|
||||
tuple: (输入数据, 输出数据列表)
|
||||
"""
|
||||
input_shape = list(self._input_binding["shape"])
|
||||
input_shape[0] = batch_size
|
||||
|
||||
input_data = np.zeros(input_shape, dtype=np.float16 if self.config.fp16_mode else np.float32)
|
||||
|
||||
output_data_list = []
|
||||
for output in self._output_bindings:
|
||||
output_shape = list(output["shape"])
|
||||
output_shape[0] = batch_size
|
||||
output_data = np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
|
||||
output_data_list.append(output_data)
|
||||
|
||||
return input_data, output_data_list
|
||||
|
||||
def set_input_shape(self, batch_size: int, height: int, width: int):
|
||||
"""
|
||||
动态设置输入形状
|
||||
执行推理(工业级 async 模式)
|
||||
|
||||
Args:
|
||||
batch_size: 批次大小
|
||||
height: 输入高度
|
||||
width: 输入宽度
|
||||
"""
|
||||
if self._context is None:
|
||||
raise RuntimeError("引擎未加载")
|
||||
|
||||
self._context.set_input_shape(
|
||||
self._input_binding["name"],
|
||||
[batch_size, 3, height, width]
|
||||
)
|
||||
|
||||
self._logger.debug(f"输入形状已设置为: [{batch_size}, 3, {height}, {width}]")
|
||||
|
||||
def infer(
|
||||
self,
|
||||
input_data: np.ndarray,
|
||||
async_mode: bool = False
|
||||
) -> Tuple[List[np.ndarray], float]:
|
||||
"""
|
||||
执行推理
|
||||
|
||||
Args:
|
||||
input_data: 输入数据 (NCHW格式)
|
||||
async_mode: 是否使用异步模式
|
||||
input_np: numpy 输入,shape 必须与 engine 一致
|
||||
|
||||
Returns:
|
||||
tuple: (输出列表, 推理耗时ms)
|
||||
@@ -239,103 +184,59 @@ class TensorRTEngine:
|
||||
if self._engine is None or self._context is None:
|
||||
raise RuntimeError("引擎未加载")
|
||||
|
||||
if len(self._inputs) == 0:
|
||||
raise RuntimeError("未分配输入 buffer")
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
batch_size = input_data.shape[0]
|
||||
self._cuda_context.push()
|
||||
|
||||
input_data = input_data.astype(np.float16 if self.config.fp16_mode else np.float32)
|
||||
|
||||
self._context.set_input_shape(
|
||||
self._input_binding["name"],
|
||||
input_data.shape
|
||||
)
|
||||
|
||||
if cuda is not None and self._cuda_context is not None:
|
||||
self._cuda_context.push()
|
||||
try:
|
||||
input_np = np.ascontiguousarray(input_np)
|
||||
|
||||
try:
|
||||
input_data = np.ascontiguousarray(input_data)
|
||||
|
||||
input_ptr = cuda.mem_alloc(input_data.nbytes)
|
||||
cuda.memcpy_htod(input_ptr, input_data)
|
||||
|
||||
bindings = [int(input_ptr)]
|
||||
output_tensors = []
|
||||
|
||||
for output in self._output_bindings:
|
||||
output_shape = list(output["shape"])
|
||||
output_shape[0] = batch_size
|
||||
output_tensor = np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
|
||||
output_tensor = np.ascontiguousarray(output_tensor)
|
||||
output_ptr = cuda.mem_alloc(output_tensor.nbytes)
|
||||
cuda.memcpy_htod(output_ptr, output_tensor)
|
||||
bindings.append(int(output_ptr))
|
||||
output_tensors.append((output_tensor, output_ptr))
|
||||
|
||||
self._context.execute_v2(bindings=bindings)
|
||||
|
||||
for output_tensor, output_ptr in output_tensors:
|
||||
cuda.memcpy_dtoh(output_tensor, output_ptr)
|
||||
|
||||
inference_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
self._update_performance_stats(inference_time_ms, batch_size)
|
||||
|
||||
return [t[0] for t in output_tensors], inference_time_ms
|
||||
|
||||
finally:
|
||||
self._cuda_context.pop()
|
||||
else:
|
||||
input_tensor = input_data
|
||||
output_tensors = []
|
||||
input_name = self._binding_names[0]
|
||||
self._context.set_input_shape(input_name, input_np.shape)
|
||||
|
||||
for output in self._output_bindings:
|
||||
output_shape = list(output["shape"])
|
||||
output_shape[0] = batch_size
|
||||
output_tensor = np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
|
||||
output_tensors.append(output_tensor)
|
||||
np.copyto(self._inputs[0].host, input_np.ravel())
|
||||
|
||||
bindings = [int(input_tensor.ctypes.data)] + [int(t.ctypes.data) for t in output_tensors]
|
||||
cuda.memcpy_htod_async(
|
||||
self._inputs[0].device,
|
||||
self._inputs[0].host,
|
||||
self._stream
|
||||
)
|
||||
|
||||
self._context.execute_v2(bindings=bindings)
|
||||
self._context.execute_async_v2(
|
||||
bindings=self._bindings,
|
||||
stream_handle=self._stream.handle
|
||||
)
|
||||
|
||||
for out in self._outputs:
|
||||
cuda.memcpy_dtoh_async(
|
||||
out.host,
|
||||
out.device,
|
||||
self._stream
|
||||
)
|
||||
|
||||
self._stream.synchronize()
|
||||
|
||||
inference_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
batch_size = input_np.shape[0]
|
||||
self._update_performance_stats(inference_time_ms, batch_size)
|
||||
|
||||
return output_tensors, inference_time_ms
|
||||
|
||||
def infer_async(self, input_data: np.ndarray) -> Tuple[List[np.ndarray], float]:
    """Run inference through ``infer`` with ``async_mode`` switched on.

    Args:
        input_data: Input array, forwarded unchanged to ``infer``.

    Returns:
        Whatever ``infer`` returns; annotated as a tuple of the list of
        output arrays and the inference time in milliseconds.
    """
    result = self.infer(input_data, async_mode=True)
    return result
|
||||
|
||||
def infer_batch(
    self,
    batch_data: np.ndarray,
    batch_size: int
) -> Tuple[List[np.ndarray], float]:
    """Run inference on the first ``batch_size`` entries of ``batch_data``.

    Args:
        batch_data: Batched input array; its first axis is the batch axis.
        batch_size: Number of leading entries that should be inferred.

    Returns:
        Whatever ``infer`` returns; annotated as a tuple of the list of
        output arrays and the inference time in milliseconds.
    """
    # Slice only when the leading dimension differs from the requested size;
    # otherwise the array is handed over untouched.
    needs_trim = batch_data.shape[0] != batch_size
    selected = batch_data[:batch_size] if needs_trim else batch_data
    return self.infer(selected)
|
||||
output_shapes = []
|
||||
for i in range(len(self._inputs), self._engine.num_bindings):
|
||||
output_shapes.append(self._get_output_shape(i))
|
||||
|
||||
results = []
|
||||
for idx, out in enumerate(self._outputs):
|
||||
shape = output_shapes[idx] if idx < len(output_shapes) else out.host.shape
|
||||
results.append(out.host.reshape(shape))
|
||||
|
||||
return results, inference_time_ms
|
||||
|
||||
finally:
|
||||
self._cuda_context.pop()
|
||||
|
||||
def _update_performance_stats(self, inference_time_ms: float, batch_size: int):
|
||||
"""更新性能统计"""
|
||||
@@ -377,7 +278,7 @@ class TensorRTEngine:
|
||||
return {"total_mb": 0, "used_mb": 0, "free_mb": 0}
|
||||
|
||||
def _release_resources(self):
|
||||
"""释放资源(Python TensorRT 由 GC 管理,无需 destroy)"""
|
||||
"""释放资源"""
|
||||
if self._cuda_context:
|
||||
try:
|
||||
self._cuda_context.pop()
|
||||
@@ -393,13 +294,11 @@ class TensorRTEngine:
|
||||
pass
|
||||
self._stream = None
|
||||
|
||||
if self._context:
|
||||
self._context = None
|
||||
|
||||
if self._engine:
|
||||
self._engine = None
|
||||
|
||||
self._memory_pool.clear()
|
||||
self._context = None
|
||||
self._engine = None
|
||||
self._bindings = []
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
|
||||
def release(self):
|
||||
"""释放引擎资源(幂等调用)"""
|
||||
@@ -409,7 +308,7 @@ class TensorRTEngine:
|
||||
self._released = True
|
||||
|
||||
self._release_resources()
|
||||
self._logger.info("TensorRT引擎资源已释放")
|
||||
self._logger.info("TensorRT 引擎资源已释放")
|
||||
|
||||
def __del__(self):
|
||||
"""析构函数"""
|
||||
@@ -417,10 +316,7 @@ class TensorRTEngine:
|
||||
|
||||
|
||||
class EngineManager:
|
||||
"""引擎管理器类
|
||||
|
||||
管理多个TensorRT引擎实例
|
||||
"""
|
||||
"""引擎管理器类"""
|
||||
|
||||
def __init__(self):
|
||||
self._engines: Dict[str, TensorRTEngine] = {}
|
||||
@@ -433,17 +329,7 @@ class EngineManager:
|
||||
engine_path: str,
|
||||
config: Optional[InferenceConfig] = None
|
||||
) -> bool:
|
||||
"""
|
||||
加载引擎
|
||||
|
||||
Args:
|
||||
engine_id: 引擎标识
|
||||
engine_path: engine文件路径
|
||||
config: 推理配置
|
||||
|
||||
Returns:
|
||||
是否加载成功
|
||||
"""
|
||||
"""加载引擎"""
|
||||
with self._lock:
|
||||
if engine_id in self._engines:
|
||||
self._engines[engine_id].release()
|
||||
@@ -490,18 +376,9 @@ def create_tensorrt_engine(
|
||||
engine_path: str,
|
||||
config: Optional[InferenceConfig] = None
|
||||
) -> TensorRTEngine:
|
||||
"""
|
||||
创建TensorRT引擎的便捷函数
|
||||
|
||||
Args:
|
||||
engine_path: engine文件路径
|
||||
config: 推理配置
|
||||
|
||||
Returns:
|
||||
TensorRTEngine实例
|
||||
"""
|
||||
"""创建 TensorRT 引擎的便捷函数"""
|
||||
engine = TensorRTEngine(config)
|
||||
if engine.load_engine(engine_path):
|
||||
return engine
|
||||
else:
|
||||
raise RuntimeError(f"无法加载TensorRT引擎: {engine_path}")
|
||||
raise RuntimeError(f"无法加载 TensorRT 引擎: {engine_path}")
|
||||
|
||||
Reference in New Issue
Block a user