fix(inference): resolve multiple YOLO inference and API issues

This commit is contained in:
2026-01-21 14:48:01 +08:00
parent 1b344aeb2e
commit 1c7190bbb0
5 changed files with 146 additions and 80 deletions

View File

@@ -23,16 +23,17 @@ class DatabaseConfig(BaseModel):
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
engine_path: str = "models/yolo11s.engine" engine_path: str = "models/yolo11n.engine"
onnx_path: str = "models/yolo11s.onnx" onnx_path: str = "models/yolo11n.onnx"
pt_model_path: str = "models/yolo11s.pt" pt_model_path: str = "models/yolo11n.pt"
imgsz: List[int] = [640, 640] imgsz: List[int] = [640, 640]
conf_threshold: float = 0.5 conf_threshold: float = 0.5
iou_threshold: float = 0.45 iou_threshold: float = 0.45
device: int = 0 device: int = 0
batch_size: int = 8 batch_size: int = 8
half: bool = True half: bool = False
use_onnx: bool = True use_onnx: bool = False
use_trt: bool = False
class StreamConfig(BaseModel): class StreamConfig(BaseModel):

View File

@@ -49,6 +49,10 @@ class ONNXEngine:
return img return img
def postprocess(self, output: np.ndarray, orig_img: np.ndarray) -> List[Results]: def postprocess(self, output: np.ndarray, orig_img: np.ndarray) -> List[Results]:
import torch
import numpy as np
from ultralytics.engine.results import Boxes as BoxesObj, Results
c, n = output.shape c, n = output.shape
output = output.T output = output.T
@@ -74,6 +78,9 @@ class ONNXEngine:
orig_h, orig_w = orig_img.shape[:2] orig_h, orig_w = orig_img.shape[:2]
scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0] scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
if len(indices) == 0:
return [Results(orig_img=orig_img, path="", names={0: "person"})]
filtered_boxes = [] filtered_boxes = []
for idx in indices: for idx in indices:
if idx >= len(boxes): if idx >= len(boxes):
@@ -82,30 +89,30 @@ class ONNXEngine:
x1, y1, x2, y2 = box x1, y1, x2, y2 = box
w, h = x2 - x1, y2 - y1 w, h = x2 - x1, y2 - y1
filtered_boxes.append([ filtered_boxes.append([
int(x1 * scale_x), float(x1 * scale_x),
int(y1 * scale_y), float(y1 * scale_y),
int(w * scale_x), float(w * scale_x),
int(h * scale_y), float(h * scale_y),
float(scores[idx]), float(scores[idx]),
int(classes[idx]) int(classes[idx])
]) ])
from ultralytics.engine.results import Boxes as BoxesObj
if filtered_boxes: if filtered_boxes:
box_tensor = torch.tensor(filtered_boxes) box_array = np.array(filtered_boxes, dtype=np.float32)
boxes_obj = BoxesObj( else:
box_tensor, box_array = np.zeros((0, 6), dtype=np.float32)
orig_shape=(orig_h, orig_w)
)
result = Results(
orig_img=orig_img,
path="",
names={0: "person"},
boxes=boxes_obj
)
return [result]
return [Results(orig_img=orig_img, path="", names={0: "person"})] boxes_obj = BoxesObj(
torch.from_numpy(box_array),
orig_shape=(orig_h, orig_w)
)
result = Results(
orig_img=orig_img,
path="",
names={0: "person"},
boxes=boxes_obj
)
return [result]
def inference(self, images: List[np.ndarray]) -> List[Results]: def inference(self, images: List[np.ndarray]) -> List[Results]:
if not images: if not images:
@@ -183,29 +190,21 @@ class TensorRTEngine:
self.context = self.engine.create_execution_context() self.context = self.engine.create_execution_context()
self.stream = torch.cuda.Stream(device=self.device) self.stream = torch.cuda.Stream(device=self.device)
self.batch_size = 1
for i in range(self.engine.num_io_tensors): for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i) name = self.engine.get_tensor_name(i)
dtype = self.engine.get_tensor_dtype(name) dtype = self.engine.get_tensor_dtype(name)
shape = list(self.engine.get_tensor_shape(name)) shape = list(self.engine.get_tensor_shape(name))
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
if -1 in shape:
shape = [self.batch_size if d == -1 else d for d in shape]
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
self.input_buffer = buffer self.input_buffer = buffer
self.input_name = name self.input_name = name
else: else:
if -1 in shape:
shape = [self.batch_size if d == -1 else d for d in shape]
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
self.output_buffers.append(buffer) self.output_buffers.append(buffer)
if self.output_name is None: if self.output_name is None:
self.output_name = name self.output_name = name
@@ -215,8 +214,6 @@ class TensorRTEngine:
stream_handle = torch.cuda.current_stream(self.device).cuda_stream stream_handle = torch.cuda.current_stream(self.device).cuda_stream
self.context.set_optimization_profile_async(0, stream_handle) self.context.set_optimization_profile_async(0, stream_handle)
self.batch_size = 1
def preprocess(self, frame: np.ndarray) -> torch.Tensor: def preprocess(self, frame: np.ndarray) -> torch.Tensor:
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, self.imgsz) img = cv2.resize(img, self.imgsz)
@@ -247,9 +244,6 @@ class TensorRTEngine:
self.input_name, input_tensor.contiguous().data_ptr() self.input_name, input_tensor.contiguous().data_ptr()
) )
input_shape = list(input_tensor.shape)
self.context.set_input_shape(self.input_name, input_shape)
torch.cuda.synchronize(self.stream) torch.cuda.synchronize(self.stream)
self.context.execute_async_v3(self.stream.cuda_stream) self.context.execute_async_v3(self.stream.cuda_stream)
torch.cuda.synchronize(self.stream) torch.cuda.synchronize(self.stream)
@@ -336,6 +330,10 @@ class Boxes:
self.orig_shape = orig_shape self.orig_shape = orig_shape
self.is_track = is_track self.is_track = is_track
@property
def ndim(self) -> int:
    """Return the number of dimensions reported by the wrapped ``data``.

    This simply forwards to ``self.data.ndim``; no validation or
    conversion is performed.
    """
    return self.data.ndim
@property @property
def xyxy(self): def xyxy(self):
if self.is_track: if self.is_track:
@@ -369,35 +367,15 @@ class YOLOEngine:
self, self,
model_path: Optional[str] = None, model_path: Optional[str] = None,
device: int = 0, device: int = 0,
use_trt: bool = True, use_trt: bool = False,
): ):
self.use_trt = False self.use_trt = False
self.onnx_engine = None self.onnx_engine = None
self.trt_engine = None self.trt_engine = None
self.model = None
self.device = device self.device = device
config = get_config() config = get_config()
self.config = config
if use_trt:
try:
self.trt_engine = TensorRTEngine(device=device)
self.trt_engine.warmup()
self.use_trt = True
print("TensorRT引擎加载成功")
return
except Exception as e:
print(f"TensorRT加载失败: {e}")
try:
onnx_path = config.model.onnx_path
if os.path.exists(onnx_path):
self.onnx_engine = ONNXEngine(device=device)
self.onnx_engine.warmup()
print("ONNX引擎加载成功")
return
else:
print(f"ONNX模型不存在: {onnx_path}")
except Exception as e:
print(f"ONNX加载失败: {e}")
try: try:
pt_path = model_path or config.model.pt_model_path pt_path = model_path or config.model.pt_model_path
@@ -409,26 +387,17 @@ class YOLOEngine:
raise FileNotFoundError(f"PT文件无效或不存在: {pt_path}") raise FileNotFoundError(f"PT文件无效或不存在: {pt_path}")
except Exception as e: except Exception as e:
print(f"PyTorch加载失败: {e}") print(f"PyTorch加载失败: {e}")
raise RuntimeError("所有模型加载方式均失败") raise RuntimeError("无法加载模型")
def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]: def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
if self.use_trt and self.trt_engine: if self.model is not None:
try: try:
return self.trt_engine.inference_single(frame) return self.model(frame, imgsz=self.config.model.imgsz, conf=self.config.model.conf_threshold, iou=self.config.model.iou_threshold, **kwargs)
except Exception as e: except Exception as e:
print(f"TensorRT推理失败切换到ONNX: {e}") print(f"PyTorch推理失败: {e}")
self.use_trt = False
if self.onnx_engine: print("警告: 模型不可用,返回空结果")
return self.onnx_engine.inference_single(frame) return []
elif self.model:
return self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
else:
return []
elif self.onnx_engine:
return self.onnx_engine.inference_single(frame)
else:
results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
return results
def __del__(self): def __del__(self):
if self.trt_engine: if self.trt_engine:

View File

@@ -7,6 +7,7 @@ from collections import deque
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np import numpy as np
from config import get_config from config import get_config

View File

@@ -133,3 +133,88 @@
2026-01-21 13:18:55,795 - security_monitor - INFO - 启动安保异常行为识别系统 2026-01-21 13:18:55,795 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 13:18:55,809 - security_monitor - INFO - 数据库初始化完成 2026-01-21 13:18:55,809 - security_monitor - INFO - 数据库初始化完成
2026-01-21 13:19:08,492 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2 2026-01-21 13:19:08,492 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:01:21,015 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:01:21,257 - security_monitor - INFO - 系统已关闭
2026-01-21 14:03:48,547 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:03:48,563 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:04:01,197 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:04:20,191 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:04:20,414 - security_monitor - INFO - 系统已关闭
2026-01-21 14:05:48,342 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:05:48,355 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:06:00,984 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:07:24,065 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:07:24,222 - security_monitor - INFO - 系统已关闭
2026-01-21 14:08:10,073 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:08:10,088 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:08:22,715 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:09:05,249 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:09:05,480 - security_monitor - INFO - 系统已关闭
2026-01-21 14:11:29,491 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:11:29,513 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:11:42,900 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:14:04,974 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:14:05,161 - security_monitor - INFO - 系统已关闭
2026-01-21 14:14:41,203 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:14:41,220 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:14:54,380 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:15:30,975 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:15:31,180 - security_monitor - INFO - 系统已关闭
2026-01-21 14:16:24,472 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:16:24,485 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:16:37,611 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:17:01,178 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:17:01,420 - security_monitor - INFO - 系统已关闭
2026-01-21 14:18:00,008 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:18:00,022 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:18:13,126 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:18:13,128 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:18:21,683 - security_monitor - INFO - 系统已关闭
2026-01-21 14:20:04,985 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:20:04,999 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:20:18,151 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:21:24,782 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:21:24,927 - security_monitor - INFO - 系统已关闭
2026-01-21 14:22:48,064 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:22:48,078 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:23:01,270 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:23:13,509 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:23:13,628 - security_monitor - INFO - 系统已关闭
2026-01-21 14:24:16,374 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:24:16,386 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:24:29,425 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:24:42,751 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:24:42,846 - security_monitor - INFO - 系统已关闭
2026-01-21 14:25:25,549 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:25:25,562 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:25:38,636 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:26:02,871 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:26:03,124 - security_monitor - INFO - 系统已关闭
2026-01-21 14:26:45,885 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:26:45,899 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:26:59,042 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:27:26,873 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:27:26,980 - security_monitor - INFO - 系统已关闭
2026-01-21 14:31:38,376 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:31:38,390 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:31:51,594 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:32:17,471 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:32:17,536 - security_monitor - INFO - 系统已关闭
2026-01-21 14:32:53,841 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:32:53,855 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:33:06,946 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:34:30,645 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:34:30,818 - security_monitor - INFO - 系统已关闭
2026-01-21 14:38:24,673 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:38:24,685 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:38:37,183 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:39:04,359 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:39:04,486 - security_monitor - INFO - 系统已关闭
2026-01-21 14:40:07,246 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:40:07,259 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:40:19,745 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:40:33,742 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:40:33,863 - security_monitor - INFO - 系统已关闭
2026-01-21 14:41:27,191 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:41:27,205 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:41:39,701 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2

10
main.py
View File

@@ -18,6 +18,16 @@ from prometheus_client import start_http_server
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from ultralytics.engine.results import Boxes as UltralyticsBoxes
def _patch_boxes_ndim():
    """Attach a read-only ``ndim`` property to ``UltralyticsBoxes`` if missing.

    The injected property returns ``self.data.ndim``. When the class already
    exposes an ``ndim`` attribute the function does nothing, so an existing
    definition is never overridden.
    """
    if not hasattr(UltralyticsBoxes, 'ndim'):
        @property
        def ndim(self):
            return self.data.ndim

        UltralyticsBoxes.ndim = ndim


# Apply the patch once, at import time.
_patch_boxes_ndim()
from api.alarm import router as alarm_router from api.alarm import router as alarm_router
from api.camera import router as camera_router from api.camera import router as camera_router
from api.roi import router as roi_router from api.roi import router as roi_router