"""
|
||
原生 TensorRT 推理引擎 - 高性能版本
|
||
直接使用 TensorRT API,避免 Ultralytics 封装的性能损失
|
||
"""
|
||
|
||
import gc
import shutil
import threading
import time
from dataclasses import dataclass
from typing import Tuple, Optional, Dict, Any, List

import numpy as np

try:
|
||
import tensorrt as trt
|
||
import pycuda.driver as cuda
|
||
import pycuda.autoinit
|
||
TRT_AVAILABLE = True
|
||
except ImportError:
|
||
TRT_AVAILABLE = False
|
||
|
||
from .utils import setup_logging
|
||
|
||
logger = setup_logging()
|
||
|
||
|
||
@dataclass
|
||
class TensorRTConfig:
|
||
"""TensorRT 配置"""
|
||
max_batch_size: int = 32
|
||
max_workspace_size: int = 1 << 30 # 1GB
|
||
fp16_mode: bool = True
|
||
int8_mode: bool = False
|
||
dla_core: Optional[int] = None
|
||
gpu_fallback: bool = True
|
||
strict_type_constraints: bool = False
|
||
|
||
|
||
class TensorRTEngine:
|
||
"""原生 TensorRT 推理引擎"""
|
||
|
||
def __init__(self, engine_path: str, config: TensorRTConfig = None):
|
||
if not TRT_AVAILABLE:
|
||
raise ImportError("TensorRT 不可用,请安装 tensorrt 和 pycuda")
|
||
|
||
self.engine_path = engine_path
|
||
self.config = config or TensorRTConfig()
|
||
|
||
# TensorRT 组件
|
||
self.logger = trt.Logger(trt.Logger.WARNING)
|
||
self.runtime = None
|
||
self.engine = None
|
||
self.context = None
|
||
|
||
# CUDA 相关
|
||
self.cuda_ctx = cuda.Device(0).make_context()
|
||
self.stream = cuda.Stream()
|
||
|
||
# 内存管理
|
||
self.inputs = []
|
||
self.outputs = []
|
||
self.bindings = []
|
||
self.host_inputs = []
|
||
self.host_outputs = []
|
||
self.cuda_inputs = []
|
||
self.cuda_outputs = []
|
||
|
||
# 性能统计
|
||
self.inference_times = []
|
||
self.batch_sizes = []
|
||
|
||
self._load_engine()
|
||
self._allocate_buffers()
|
||
|
||
def _load_engine(self):
|
||
"""加载 TensorRT 引擎"""
|
||
logger.info(f"加载 TensorRT 引擎: {self.engine_path}")
|
||
|
||
self.runtime = trt.Runtime(self.logger)
|
||
|
||
with open(self.engine_path, 'rb') as f:
|
||
engine_data = f.read()
|
||
|
||
self.engine = self.runtime.deserialize_cuda_engine(engine_data)
|
||
if not self.engine:
|
||
raise RuntimeError("无法反序列化 TensorRT 引擎")
|
||
|
||
self.context = self.engine.create_execution_context()
|
||
if not self.context:
|
||
raise RuntimeError("无法创建执行上下文")
|
||
|
||
logger.info(f"TensorRT 引擎加载成功")
|
||
logger.info(f" 输入数量: {self.engine.num_bindings // 2}")
|
||
logger.info(f" 最大批次大小: {self.engine.max_batch_size}")
|
||
|
||
def _allocate_buffers(self):
|
||
"""分配输入输出缓冲区"""
|
||
for binding in self.engine:
|
||
binding_idx = self.engine.get_binding_index(binding)
|
||
size = trt.volume(self.engine.get_binding_shape(binding)) * self.engine.max_batch_size
|
||
dtype = trt.nptype(self.engine.get_binding_dtype(binding))
|
||
|
||
# 分配主机内存
|
||
host_mem = cuda.pagelocked_empty(size, dtype)
|
||
# 分配设备内存
|
||
cuda_mem = cuda.mem_alloc(host_mem.nbytes)
|
||
|
||
# 添加到绑定列表
|
||
self.bindings.append(int(cuda_mem))
|
||
|
||
if self.engine.binding_is_input(binding):
|
||
self.inputs.append(binding)
|
||
self.host_inputs.append(host_mem)
|
||
self.cuda_inputs.append(cuda_mem)
|
||
else:
|
||
self.outputs.append(binding)
|
||
self.host_outputs.append(host_mem)
|
||
self.cuda_outputs.append(cuda_mem)
|
||
|
||
logger.info(f"缓冲区分配完成: {len(self.inputs)} 输入, {len(self.outputs)} 输出")
|
||
|
||
def infer_batch(self, batch_data: np.ndarray) -> Tuple[List[np.ndarray], float]:
|
||
"""批量推理"""
|
||
start_time = time.perf_counter()
|
||
|
||
batch_size = batch_data.shape[0]
|
||
|
||
# 设置动态批次大小
|
||
if hasattr(self.context, 'set_binding_shape'):
|
||
input_shape = list(batch_data.shape)
|
||
self.context.set_binding_shape(0, input_shape)
|
||
|
||
# 复制输入数据到主机内存
|
||
np.copyto(self.host_inputs[0][:batch_data.size], batch_data.ravel())
|
||
|
||
# 传输到 GPU
|
||
cuda.memcpy_htod_async(self.cuda_inputs[0], self.host_inputs[0], self.stream)
|
||
|
||
# 执行推理
|
||
self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle)
|
||
|
||
# 传输结果回主机
|
||
for i, cuda_output in enumerate(self.cuda_outputs):
|
||
cuda.memcpy_dtoh_async(self.host_outputs[i], cuda_output, self.stream)
|
||
|
||
# 同步流
|
||
self.stream.synchronize()
|
||
|
||
end_time = time.perf_counter()
|
||
inference_time = (end_time - start_time) * 1000 # 转换为毫秒
|
||
|
||
# 收集输出
|
||
outputs = []
|
||
for host_output in self.host_outputs:
|
||
# 根据实际输出形状重塑
|
||
output_shape = self.engine.get_binding_shape(len(self.inputs))
|
||
if output_shape[0] == -1: # 动态批次
|
||
output_shape = (batch_size,) + output_shape[1:]
|
||
|
||
output = host_output[:np.prod(output_shape)].reshape(output_shape)
|
||
outputs.append(output.copy())
|
||
|
||
# 记录性能统计
|
||
self.inference_times.append(inference_time)
|
||
self.batch_sizes.append(batch_size)
|
||
|
||
return outputs, inference_time
|
||
|
||
def infer_single(self, image: np.ndarray) -> Tuple[np.ndarray, float]:
|
||
"""单张图片推理"""
|
||
if len(image.shape) == 3:
|
||
batch = np.expand_dims(image, axis=0)
|
||
else:
|
||
batch = image
|
||
|
||
outputs, inference_time = self.infer_batch(batch)
|
||
return outputs[0] if outputs else None, inference_time
|
||
|
||
def get_performance_stats(self) -> Dict[str, Any]:
|
||
"""获取性能统计"""
|
||
if not self.inference_times:
|
||
return {}
|
||
|
||
times = np.array(self.inference_times)
|
||
batches = np.array(self.batch_sizes)
|
||
|
||
return {
|
||
"total_inferences": len(times),
|
||
"avg_inference_time_ms": float(np.mean(times)),
|
||
"min_inference_time_ms": float(np.min(times)),
|
||
"max_inference_time_ms": float(np.max(times)),
|
||
"p95_inference_time_ms": float(np.percentile(times, 95)),
|
||
"avg_batch_size": float(np.mean(batches)),
|
||
"total_frames_processed": int(np.sum(batches)),
|
||
"avg_fps": float(np.sum(batches) / (np.sum(times) / 1000)) if np.sum(times) > 0 else 0
|
||
}
|
||
|
||
def reset_stats(self):
|
||
"""重置性能统计"""
|
||
self.inference_times.clear()
|
||
self.batch_sizes.clear()
|
||
|
||
def cleanup(self):
|
||
"""清理资源"""
|
||
# 释放 CUDA 内存
|
||
for cuda_mem in self.cuda_inputs + self.cuda_outputs:
|
||
cuda_mem.free()
|
||
|
||
# 清理 TensorRT 对象
|
||
if self.context:
|
||
del self.context
|
||
if self.engine:
|
||
del self.engine
|
||
if self.runtime:
|
||
del self.runtime
|
||
|
||
# 清理 CUDA 上下文
|
||
if hasattr(self, 'cuda_ctx'):
|
||
self.cuda_ctx.pop()
|
||
|
||
# 强制垃圾回收
|
||
gc.collect()
|
||
|
||
logger.info("TensorRT 引擎资源已释放")
|
||
|
||
|
||
class MultiStreamTensorRTEngine:
|
||
"""多流 TensorRT 推理引擎"""
|
||
|
||
def __init__(self, engine_path: str, num_streams: int = 4, config: TensorRTConfig = None):
|
||
self.engine_path = engine_path
|
||
self.num_streams = num_streams
|
||
self.config = config or TensorRTConfig()
|
||
|
||
# 创建多个推理引擎实例
|
||
self.engines = []
|
||
self.stream_locks = []
|
||
|
||
for i in range(num_streams):
|
||
try:
|
||
engine = TensorRTEngine(engine_path, config)
|
||
self.engines.append(engine)
|
||
self.stream_locks.append(threading.Lock())
|
||
logger.info(f"创建推理流 {i+1}/{num_streams}")
|
||
except Exception as e:
|
||
logger.error(f"创建推理流 {i+1} 失败: {e}")
|
||
break
|
||
|
||
if not self.engines:
|
||
raise RuntimeError("无法创建任何推理流")
|
||
|
||
logger.info(f"多流 TensorRT 引擎初始化完成: {len(self.engines)} 个流")
|
||
|
||
# 流调度
|
||
self.current_stream = 0
|
||
self.schedule_lock = threading.Lock()
|
||
|
||
def infer_async(self, batch_data: np.ndarray) -> Tuple[List[np.ndarray], float, int]:
|
||
"""异步推理,返回结果、延迟和使用的流ID"""
|
||
# 选择可用的流
|
||
with self.schedule_lock:
|
||
stream_id = self.current_stream
|
||
self.current_stream = (self.current_stream + 1) % len(self.engines)
|
||
|
||
# 使用选定的流进行推理
|
||
engine = self.engines[stream_id]
|
||
with self.stream_locks[stream_id]:
|
||
outputs, inference_time = engine.infer_batch(batch_data)
|
||
|
||
return outputs, inference_time, stream_id
|
||
|
||
def get_combined_stats(self) -> Dict[str, Any]:
|
||
"""获取所有流的综合统计"""
|
||
all_stats = []
|
||
for i, engine in enumerate(self.engines):
|
||
stats = engine.get_performance_stats()
|
||
if stats:
|
||
stats['stream_id'] = i
|
||
all_stats.append(stats)
|
||
|
||
if not all_stats:
|
||
return {}
|
||
|
||
# 合并统计
|
||
total_inferences = sum(s['total_inferences'] for s in all_stats)
|
||
total_frames = sum(s['total_frames_processed'] for s in all_stats)
|
||
|
||
all_times = []
|
||
for engine in self.engines:
|
||
all_times.extend(engine.inference_times)
|
||
|
||
if not all_times:
|
||
return {}
|
||
|
||
times = np.array(all_times)
|
||
|
||
return {
|
||
"num_streams": len(self.engines),
|
||
"total_inferences": total_inferences,
|
||
"total_frames_processed": total_frames,
|
||
"avg_inference_time_ms": float(np.mean(times)),
|
||
"min_inference_time_ms": float(np.min(times)),
|
||
"max_inference_time_ms": float(np.max(times)),
|
||
"p95_inference_time_ms": float(np.percentile(times, 95)),
|
||
"combined_fps": float(total_frames / (np.sum(times) / 1000)) if np.sum(times) > 0 else 0,
|
||
"per_stream_stats": all_stats
|
||
}
|
||
|
||
def reset_all_stats(self):
|
||
"""重置所有流的统计"""
|
||
for engine in self.engines:
|
||
engine.reset_stats()
|
||
|
||
def cleanup(self):
|
||
"""清理所有资源"""
|
||
for engine in self.engines:
|
||
engine.cleanup()
|
||
self.engines.clear()
|
||
logger.info("多流 TensorRT 引擎已清理")
|
||
|
||
|
||
def create_optimized_engine(model_path: str, output_path: str, config: TensorRTConfig) -> bool:
|
||
"""创建优化的 TensorRT 引擎"""
|
||
if not TRT_AVAILABLE:
|
||
logger.error("TensorRT 不可用")
|
||
return False
|
||
|
||
try:
|
||
from ultralytics import YOLO
|
||
|
||
logger.info(f"开始构建优化的 TensorRT 引擎...")
|
||
logger.info(f" 模型: {model_path}")
|
||
logger.info(f" 输出: {output_path}")
|
||
logger.info(f" 最大批次: {config.max_batch_size}")
|
||
logger.info(f" FP16: {config.fp16_mode}")
|
||
|
||
# 使用 Ultralytics 导出,但配置更激进的优化参数
|
||
model = YOLO(model_path)
|
||
|
||
# 导出为 TensorRT 引擎
|
||
model.export(
|
||
format="engine",
|
||
imgsz=320, # 可以根据需要调整
|
||
half=config.fp16_mode,
|
||
dynamic=True,
|
||
batch=config.max_batch_size,
|
||
workspace=config.max_workspace_size // (1024**3), # 转换为 GB
|
||
verbose=True
|
||
)
|
||
|
||
logger.info("TensorRT 引擎构建完成")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"构建 TensorRT 引擎失败: {e}")
|
||
return False |