Files
Security_AI_integrated/inference/engine.py

407 lines
12 KiB
Python
Raw Normal View History

2026-01-20 17:42:18 +08:00
import os
2026-01-21 13:29:39 +08:00
os.environ["TENSORRT_DISABLE_MYELIN"] = "1"
2026-01-20 17:42:18 +08:00
import time
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
import tensorrt as trt
import torch
2026-01-21 13:29:39 +08:00
import onnxruntime as ort
2026-01-20 17:42:18 +08:00
from ultralytics import YOLO
2026-01-21 13:29:39 +08:00
from ultralytics.engine.results import Results, Boxes as UltralyticsBoxes
2026-01-20 17:42:18 +08:00
from config import get_config
2026-01-21 13:29:39 +08:00
class ONNXEngine:
    """ONNX Runtime inference engine for a YOLO-style detector.

    Loads an ONNX model, runs (optionally batched) inference and converts the
    raw network output into Ultralytics ``Results`` objects so downstream code
    can consume a uniform interface.
    """

    def __init__(self, onnx_path: Optional[str] = None, device: int = 0):
        """Initialise the engine and load the model immediately.

        Args:
            onnx_path: Path to the .onnx file; defaults to the configured path.
            device: GPU index; a negative value forces CPU-only execution.
        """
        config = get_config()
        self.onnx_path = onnx_path or config.model.onnx_path
        self.device = device
        self.imgsz = tuple(config.model.imgsz)
        self.conf_thresh = config.model.conf_threshold
        self.iou_thresh = config.model.iou_threshold
        self.session = None
        self.input_names = None
        self.output_names = None
        self.load_model()

    def load_model(self):
        """Create the ONNX Runtime session and cache input/output tensor names.

        Raises:
            FileNotFoundError: If the model file does not exist.
        """
        if not os.path.exists(self.onnx_path):
            raise FileNotFoundError(f"ONNX模型文件不存在: {self.onnx_path}")
        # Prefer CUDA when a GPU index was requested; ONNX Runtime silently
        # falls back to CPU if the CUDA provider is unavailable.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if self.device >= 0 else ['CPUExecutionProvider']
        self.session = ort.InferenceSession(self.onnx_path, providers=providers)
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        self.output_names = [out.name for out in self.session.get_outputs()]

    def preprocess(self, frame: np.ndarray) -> np.ndarray:
        """Convert a BGR frame into a normalized CHW float32 array.

        NOTE(review): cv2.resize expects (width, height); this assumes
        config.model.imgsz is ordered accordingly — confirm against config.
        """
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.imgsz)
        img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
        return img

    def postprocess(self, output: np.ndarray, orig_img: np.ndarray) -> List[Results]:
        """Turn one image's raw prediction into a single-element Results list.

        Args:
            output: Raw network output of shape (channels, num_anchors),
                transposed here to (num_anchors, channels). Column layout is
                assumed to be [x1, y1, x2, y2, score, class scores...] —
                TODO confirm against the exported model (YOLOv8 exports
                cx/cy/w/h with no separate objectness column).
            orig_img: The original BGR frame, used for coordinate rescaling.
        """
        output = output.T
        boxes = output[:, :4]
        scores = output[:, 4]
        classes = output[:, 5:].argmax(axis=1) if output.shape[1] > 5 else np.zeros(len(output), dtype=np.int32)

        # Confidence filtering before NMS keeps the NMS input small.
        mask = scores > self.conf_thresh
        boxes = boxes[mask]
        scores = scores[mask]
        classes = classes[mask]
        if len(boxes) == 0:
            return [Results(orig_img=orig_img, path="", names={0: "person"})]

        indices = cv2.dnn.NMSBoxes(
            boxes.tolist(),
            scores.tolist(),
            self.conf_thresh,
            self.iou_thresh,
        )
        orig_h, orig_w = orig_img.shape[:2]
        # Map from network input resolution back to the original frame.
        scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
        if len(indices) == 0:
            return [Results(orig_img=orig_img, path="", names={0: "person"})]

        filtered_boxes = []
        for idx in indices:
            if idx >= len(boxes):
                continue
            x1, y1, x2, y2 = boxes[idx]
            w, h = x2 - x1, y2 - y1
            # NOTE(review): rows are stored as [x, y, w, h, conf, cls] but
            # Ultralytics Boxes expects [x1, y1, x2, y2, conf, cls] in the
            # first four columns — verify downstream consumers before changing.
            filtered_boxes.append([
                float(x1 * scale_x),
                float(y1 * scale_y),
                float(w * scale_x),
                float(h * scale_y),
                float(scores[idx]),
                int(classes[idx]),
            ])

        if filtered_boxes:
            box_array = np.array(filtered_boxes, dtype=np.float32)
        else:
            box_array = np.zeros((0, 6), dtype=np.float32)

        boxes_obj = UltralyticsBoxes(
            torch.from_numpy(box_array),
            orig_shape=(orig_h, orig_w),
        )
        result = Results(
            orig_img=orig_img,
            path="",
            names={0: "person"},
            boxes=boxes_obj,
        )
        return [result]

    def inference(self, images: List[np.ndarray]) -> List[Results]:
        """Run batched inference; returns one Results object per input image."""
        if not images:
            return []
        batch = np.stack([self.preprocess(frame) for frame in images], axis=0)
        outputs = self.session.run(self.output_names, {self.input_names[0]: batch})
        output = outputs[0]
        results = []
        # A batch of 1 is just the general loop with a single iteration, so
        # no special case is needed.
        for i in range(output.shape[0]):
            results.extend(self.postprocess(output[i], images[i]))
        return results

    def inference_single(self, frame: np.ndarray) -> List[Results]:
        """Convenience wrapper for single-frame inference."""
        return self.inference([frame])

    def warmup(self, num_warmup: int = 10):
        """Run a few dummy inferences to trigger lazy provider initialisation."""
        dummy_frame = np.zeros((640, 640, 3), dtype=np.uint8)
        for _ in range(num_warmup):
            self.inference_single(dummy_frame)

    def __del__(self):
        # Best-effort cleanup: end_profiling raises when profiling was never
        # enabled, so failures are deliberately swallowed.
        if self.session:
            try:
                self.session.end_profiling()
            except Exception:
                pass
2026-01-20 17:42:18 +08:00
class TensorRTEngine:
    """TensorRT inference engine for a YOLO-style detector.

    Deserializes a prebuilt .engine file, pre-allocates torch GPU tensors as
    I/O bindings and runs inference through ``execute_async_v3`` on a
    dedicated CUDA stream.
    """

    def __init__(self, engine_path: Optional[str] = None, device: int = 0):
        """Read thresholds/paths from config and load the engine.

        Args:
            engine_path: Path to the serialized engine; defaults to config.
            device: CUDA device index used for buffers and the stream.

        Raises:
            FileNotFoundError: If the engine file does not exist
                (raised from _load_engine).
        """
        config = get_config()
        self.engine_path = engine_path or config.model.engine_path
        self.device = device
        self.imgsz = tuple(config.model.imgsz)
        self.conf_thresh = config.model.conf_threshold
        self.iou_thresh = config.model.iou_threshold
        self.half = config.model.half
        self.logger = trt.Logger(trt.Logger.INFO)
        self.engine = None
        self.context = None
        # Dedicated CUDA stream for inference.
        # NOTE(review): this stream is immediately replaced inside
        # _load_engine — one of the two creations is redundant.
        self.stream = torch.cuda.Stream(device=self.device)
        self.input_buffer = None
        self.output_buffers = []
        self.input_name = None
        self.output_name = None
        self._load_engine()

    def _load_engine(self):
        """Deserialize the engine, allocate I/O buffers and bind addresses."""
        if not os.path.exists(self.engine_path):
            raise FileNotFoundError(f"TensorRT引擎文件不存在: {self.engine_path}")
        with open(self.engine_path, "rb") as f:
            serialized_engine = f.read()
        runtime = trt.Runtime(self.logger)
        self.engine = runtime.deserialize_cuda_engine(serialized_engine)
        self.context = self.engine.create_execution_context()
        # NOTE(review): overwrites the stream already created in __init__.
        self.stream = torch.cuda.Stream(device=self.device)
        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            dtype = self.engine.get_tensor_dtype(name)
            shape = list(self.engine.get_tensor_shape(name))
            # Pre-allocate a device tensor matching the binding's dtype.
            if dtype == trt.float16:
                buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
            else:
                buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                self.input_buffer = buffer
                self.input_name = name
            else:
                self.output_buffers.append(buffer)
                # Only the first output tensor's name is remembered.
                if self.output_name is None:
                    self.output_name = name
            # Bind the raw device pointer of this buffer to the engine tensor.
            self.context.set_tensor_address(name, buffer.data_ptr())
        # NOTE(review): the optimization profile is set on the *current*
        # stream, while execution later runs on self.stream — confirm this
        # mismatch is intended.
        stream_handle = torch.cuda.current_stream(self.device).cuda_stream
        self.context.set_optimization_profile_async(0, stream_handle)

    def preprocess(self, frame: np.ndarray) -> torch.Tensor:
        """BGR frame -> normalized NCHW tensor on the engine's device.

        NOTE(review): cv2.resize expects (width, height); assumes
        config.model.imgsz ordering matches — confirm against config.
        """
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, self.imgsz)
        img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
        if self.half:
            img = img.astype(np.float16)
        tensor = torch.from_numpy(img).unsqueeze(0).to(self.device)
        return tensor

    def inference(self, images: List[np.ndarray]) -> List[Results]:
        """Run the engine on a batch of BGR frames.

        Returns one Results object per surviving NMS detection; images with
        no detections contribute nothing, so the returned list's length can
        differ from ``len(images)``.
        """
        batch_size = len(images)
        if batch_size == 0:
            return []
        input_tensor = self.preprocess(images[0])
        if batch_size > 1:
            # Concatenate per-frame tensors along the batch dimension.
            for i in range(1, batch_size):
                input_tensor = torch.cat(
                    [input_tensor, self.preprocess(images[i])], dim=0
                )
        # Point the engine's input binding at this batch's device memory.
        self.context.set_tensor_address(
            self.input_name, input_tensor.contiguous().data_ptr()
        )
        # NOTE(review): torch.cuda.synchronize expects a *device* argument;
        # passing a Stream here is suspect — self.stream.synchronize() is
        # the usual form. Confirm before relying on this ordering.
        torch.cuda.synchronize(self.stream)
        self.context.execute_async_v3(self.stream.cuda_stream)
        torch.cuda.synchronize(self.stream)
        results = []
        for i in range(batch_size):
            pred = self.output_buffers[0][i].cpu().numpy()
            pred = pred.T  # transpose to (num_anchors, channels), e.g. (8400, 84)
            boxes = pred[:, :4]
            scores = pred[:, 4]
            # NOTE(review): column 5 is used directly as the class id here,
            # while ONNXEngine argmaxes columns 5: — confirm which matches
            # the exported model's output layout.
            classes = pred[:, 5].astype(np.int32)
            mask = scores > self.conf_thresh
            boxes = boxes[mask]
            scores = scores[mask]
            classes = classes[mask]
            indices = cv2.dnn.NMSBoxes(
                boxes.tolist(),
                scores.tolist(),
                self.conf_thresh,
                self.iou_thresh,
            )
            if len(indices) > 0:
                for idx in indices:
                    box = boxes[idx]
                    x1, y1, x2, y2 = box
                    w, h = x2 - x1, y2 - y1
                    conf = scores[idx]
                    cls = classes[idx]
                    orig_h, orig_w = images[i].shape[:2]
                    # Scale from network input size back to the original frame.
                    scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
                    # NOTE(review): stored as [x, y, w, h] while Ultralytics
                    # Boxes expects xyxy in the first four columns — verify.
                    box_orig = [
                        int(x1 * scale_x),
                        int(y1 * scale_y),
                        int(w * scale_x),
                        int(h * scale_y),
                    ]
                    result = Results(
                        orig_img=images[i],
                        path="",
                        names={0: "person"},
                        boxes=UltralyticsBoxes(
                            torch.tensor([box_orig + [conf, cls]]),
                            orig_shape=(orig_h, orig_w),
                        ),
                    )
                    results.append(result)
        return results

    def inference_single(self, frame: np.ndarray) -> List[Results]:
        """Convenience wrapper for single-frame inference."""
        return self.inference([frame])

    def warmup(self, num_warmup: int = 10):
        """Run dummy frames through the engine to amortize lazy init costs."""
        dummy_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        for _ in range(num_warmup):
            self.inference_single(dummy_frame)

    def __del__(self):
        # Best-effort teardown; either call may fail during interpreter
        # shutdown, so errors are swallowed.
        # NOTE(review): IExecutionContext appears to have no synchronize()
        # in the TensorRT Python API — this likely always raises and is
        # absorbed by the except; confirm.
        if self.context:
            try:
                self.context.synchronize()
            except Exception:
                pass
        if self.stream:
            try:
                self.stream.synchronize()
            except Exception:
                pass
2026-01-20 17:42:18 +08:00
class Boxes:
    """Minimal container for detection (and tracking) boxes.

    Mirrors the Ultralytics ``Boxes`` accessors. Each row of ``data`` is
    ``[x1, y1, x2, y2, conf, cls]`` for plain detections, or
    ``[x1, y1, x2, y2, track_id, conf, cls]`` when ``is_track`` is True.
    """

    def __init__(
        self,
        data: torch.Tensor,
        orig_shape: Tuple[int, int],
        is_track: bool = False,
    ):
        """Store the raw box tensor plus metadata.

        Args:
            data: (N, 6) detection rows or (N, 7) tracking rows.
            orig_shape: (height, width) of the source image.
            is_track: Whether rows carry a track-id column at index 4.
        """
        self.data = data
        self.orig_shape = orig_shape
        self.is_track = is_track

    @property
    def ndim(self) -> int:
        """Number of dimensions of the underlying tensor."""
        return self.data.ndim

    @property
    def xyxy(self):
        """Corner-format coordinates, shape (N, 4); same for both layouts."""
        return self.data[:, :4]

    @property
    def conf(self):
        """Confidence scores, shape (N,).

        Negative indexing makes tracking rows (extra track-id column at
        index 4) resolve to the correct confidence column, matching
        Ultralytics' own Boxes; for plain 6-column rows, -2 == 4.
        """
        return self.data[:, -2]

    @property
    def cls(self):
        """Class ids, shape (N,); the last column in both layouts."""
        return self.data[:, -1]

    @property
    def id(self):
        """Track ids, shape (N,), or None when not tracking."""
        return self.data[:, -3] if self.is_track else None
2026-01-21 13:29:39 +08:00
def _check_pt_file_valid(pt_path: str) -> bool:
try:
with open(pt_path, 'rb') as f:
header = f.read(10)
return len(header) == 10
except Exception:
return False
2026-01-20 17:42:18 +08:00
class YOLOEngine:
    """Top-level model wrapper that loads a PyTorch YOLO checkpoint.

    The TensorRT/ONNX backends are currently disabled: ``use_trt`` is
    accepted for interface compatibility but ignored, and only the .pt
    path is attempted.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: int = 0,
        use_trt: bool = False,
    ):
        """Load the PyTorch model or raise.

        Args:
            model_path: Optional explicit .pt path; defaults to the
                configured pt_model_path.
            device: Torch device index the model is moved to.
            use_trt: Ignored; kept for backward compatibility with callers.

        Raises:
            RuntimeError: If the checkpoint is missing, invalid or fails
                to load (original cause chained).
        """
        # TensorRT/ONNX paths intentionally disabled; PyTorch only.
        self.use_trt = False
        self.onnx_engine = None
        self.trt_engine = None
        self.model = None
        self.device = device
        config = get_config()
        self.config = config
        try:
            pt_path = model_path or config.model.pt_model_path
            if os.path.exists(pt_path) and _check_pt_file_valid(pt_path):
                self.model = YOLO(pt_path)
                self.model.to(device)
                print(f"PyTorch模型加载成功: {pt_path}")
            else:
                raise FileNotFoundError(f"PT文件无效或不存在: {pt_path}")
        except Exception as e:
            print(f"PyTorch加载失败: {e}")
            # Chain the original exception so the real cause survives in
            # the traceback instead of being silently discarded.
            raise RuntimeError("无法加载模型") from e

    def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
        """Run detection on one frame; returns [] when inference fails."""
        if self.model is not None:
            try:
                return self.model(
                    frame,
                    imgsz=self.config.model.imgsz,
                    conf=self.config.model.conf_threshold,
                    iou=self.config.model.iou_threshold,
                    **kwargs,
                )
            except Exception as e:
                print(f"PyTorch推理失败: {e}")
        print("警告: 模型不可用,返回空结果")
        return []

    def __del__(self):
        # Explicitly drop backend engines (no-ops while backends are disabled).
        if self.trt_engine:
            del self.trt_engine
        if self.onnx_engine:
            del self.onnx_engine