Files
security-ai-edge/core/postprocessor.py
16337 7a5ddef2f6 feat: 支持按算法独立配置置信度阈值
通过 ALGO_CONF_{ALGO_CODE} 环境变量为每个算法设置独立的 conf_threshold,
未配置的算法回退到全局 CONF_THRESHOLD。推理过程零改动,仅后处理过滤阶段
按 bind.algo_code 使用对应阈值。

当前配置:离岗=0.4(降低漏检),入侵=0.5(减少误报)。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-08 17:31:12 +08:00

954 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
后处理与业务逻辑模块
实现NMS算法、坐标映射、ROI区域判定、告警状态机
"""
import logging
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
import numpy as np
try:
import torch
from torchvision.ops import nms as torch_nms
_HAS_TORCH = True
except ImportError:
_HAS_TORCH = False
from config.settings import get_settings
from config.config_models import ROIInfo, ROIType, AlertInfo, AlertLevel
from utils.logger import get_logger
from utils.common import generate_unique_id
logger = logging.getLogger(__name__)
class NMSProcessor:
    """Class-agnostic non-maximum suppression (NMS) for axis-aligned boxes.

    Two back ends are available: a NumPy implementation that vectorises the
    IoU computation for every kept box, and an optional GPU path that
    delegates to ``torchvision.ops.nms``. ``class_ids`` are carried through to
    the output but never influence which boxes suppress each other.
    """

    def __init__(self, nms_threshold: float = 0.5, use_gpu: bool = False):
        """
        Args:
            nms_threshold: IoU above which a lower-scored box is suppressed.
            use_gpu: Request the torchvision GPU path. The request is only
                honoured when torch/torchvision imported successfully AND a
                CUDA device is available; otherwise the CPU path is used.
        """
        self.nms_threshold = nms_threshold
        # Importing torch is not enough: without a CUDA device the `.cuda()`
        # calls in `_process_gpu` would raise on every invocation.
        self.use_gpu = bool(use_gpu and _HAS_TORCH and torch.cuda.is_available())
        self._logger = get_logger("postprocessor")
        if self.use_gpu:
            self._logger.debug("NMS 使用 GPU 加速 (torchvision.ops.nms)")

    def process(
        self,
        boxes: np.ndarray,
        scores: np.ndarray,
        class_ids: Optional[np.ndarray] = None,
        max_output_size: int = 300
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run NMS and report which boxes survive.

        Args:
            boxes: Array [N, 4] of (x1, y1, x2, y2).
            scores: Array [N] of confidences.
            class_ids: Optional array [N] of class ids.
            max_output_size: Maximum number of boxes to keep.

        Returns:
            tuple: (indices of kept boxes, their scores, their class ids).
            The class-id element is an empty array when ``class_ids`` is None.
        """
        if len(boxes) == 0:
            return np.array([], dtype=np.int32), np.array([]), np.array([])
        if self.use_gpu and _HAS_TORCH:
            return self._process_gpu(boxes, scores, class_ids, max_output_size)
        return self._process_cpu(boxes, scores, class_ids, max_output_size)

    def _process_gpu(
        self,
        boxes: np.ndarray,
        scores: np.ndarray,
        class_ids: Optional[np.ndarray],
        max_output_size: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """NMS through ``torchvision.ops.nms`` on the CUDA device."""
        boxes_t = torch.from_numpy(boxes).cuda()
        scores_t = torch.from_numpy(scores).cuda()
        keep = torch_nms(boxes_t, scores_t, iou_threshold=self.nms_threshold)
        keep_np = keep.cpu().numpy()
        if len(keep_np) > max_output_size:
            # Retain only the highest-scoring survivors.
            top_k = np.argsort(scores[keep_np])[::-1][:max_output_size]
            keep_np = keep_np[top_k]
        return (
            keep_np.astype(np.int32),
            scores[keep_np],
            class_ids[keep_np] if class_ids is not None else np.array([])
        )

    def _process_cpu(
        self,
        boxes: np.ndarray,
        scores: np.ndarray,
        class_ids: Optional[np.ndarray],
        max_output_size: int
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Greedy NMS in NumPy; the IoU against all remaining boxes is vectorised."""
        order = np.argsort(scores)[::-1]
        # Box areas never change, so compute them once instead of per kept box.
        # NOTE(review): this path uses the inclusive-pixel (+1) convention;
        # confirm whether that is intended to match the GPU path.
        areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
        suppressed = np.zeros(len(boxes), dtype=bool)
        keep = []
        for rank, idx in enumerate(order):
            if len(keep) >= max_output_size:
                break
            if suppressed[idx]:
                continue
            keep.append(idx)
            # Only lower-scored, not-yet-suppressed boxes can still be removed.
            remaining = order[rank + 1:]
            remaining = remaining[~suppressed[remaining]]
            if len(remaining) == 0:
                break
            xx1 = np.maximum(boxes[idx, 0], boxes[remaining, 0])
            yy1 = np.maximum(boxes[idx, 1], boxes[remaining, 1])
            xx2 = np.minimum(boxes[idx, 2], boxes[remaining, 2])
            yy2 = np.minimum(boxes[idx, 3], boxes[remaining, 3])
            inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
            # The epsilon guards against a zero union.
            iou = inter / (areas[idx] + areas[remaining] - inter + 1e-6)
            suppressed[remaining[iou > self.nms_threshold]] = True
        if not keep:
            return (
                np.array([], dtype=np.int32),
                np.array([]),
                np.array([])
            )
        keep_indices = np.array(keep, dtype=np.int32)
        return (
            keep_indices,
            scores[keep_indices],
            class_ids[keep_indices] if class_ids is not None else np.array([])
        )

    def process_with_confidence_filter(
        self,
        boxes: np.ndarray,
        scores: np.ndarray,
        class_ids: np.ndarray,
        conf_threshold: float = 0.45,
        max_output_size: int = 300
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Drop boxes scoring below ``conf_threshold``, then run NMS.

        Args:
            boxes: Array [N, 4] of (x1, y1, x2, y2).
            scores: Array [N] of confidences.
            class_ids: Array [N] of class ids.
            conf_threshold: Minimum score (inclusive) for a box to be considered.
            max_output_size: Maximum number of boxes to keep.

        Returns:
            tuple: (kept boxes, their scores, their class ids). Unlike
            ``process`` the first element holds box coordinates, not indices.
        """
        mask = scores >= conf_threshold
        boxes = boxes[mask]
        scores = scores[mask]
        class_ids = class_ids[mask]
        if len(boxes) == 0:
            return np.array([]).reshape(0, 4), np.array([]), np.array([])
        keep_indices, filtered_scores, filtered_class_ids = self.process(
            boxes, scores, class_ids, max_output_size
        )
        return boxes[keep_indices], filtered_scores, filtered_class_ids
class CoordinateMapper:
    """Coordinate conversion helpers.

    Converts boxes from letterboxed model-input space back to the source
    image, and from ROI-crop space back to the full frame.
    """

    def __init__(self):
        self._logger = get_logger("postprocessor")

    def map_from_letterbox(
        self,
        box: List[float],
        scale_info: Tuple[float, float, float, float],
        original_size: Tuple[int, int]
    ) -> List[float]:
        """Undo letterbox padding/scaling and clamp the box to the image.

        Args:
            box: [x1, y1, x2, y2] in letterbox space.
            scale_info: (scale, pad_x, pad_y, <unused>); the 4th item is ignored.
            original_size: (width, height) of the un-letterboxed image.

        Returns:
            [x1, y1, x2, y2] in the original image. The top-left corner is
            clamped to [0, size - 1], the bottom-right corner to [0, size].
        """
        scale, pad_x, pad_y, _ = scale_info
        orig_width, orig_height = original_size
        left = (box[0] - pad_x) / scale
        top = (box[1] - pad_y) / scale
        right = (box[2] - pad_x) / scale
        bottom = (box[3] - pad_y) / scale
        return [
            max(0, min(left, orig_width - 1)),
            max(0, min(top, orig_height - 1)),
            max(0, min(right, orig_width)),
            max(0, min(bottom, orig_height)),
        ]

    def map_roi_to_frame(
        self,
        box: List[float],
        roi_crop_info: Dict[str, Any]
    ) -> List[float]:
        """Translate (and optionally scale) a box from ROI space to frame space.

        Args:
            box: [x1, y1, x2, y2] in ROI space.
            roi_crop_info: Reads "offset_x"/"offset_y" (default 0) and, when
                present, "scale", which is applied before the offset.

        Returns:
            [x1, y1, x2, y2] in frame coordinates.
        """
        x1, y1, x2, y2 = box
        dx = roi_crop_info.get("offset_x", 0)
        dy = roi_crop_info.get("offset_y", 0)
        if "scale" not in roi_crop_info:
            return [x1 + dx, y1 + dy, x2 + dx, y2 + dy]
        factor = roi_crop_info["scale"]
        return [
            x1 * factor + dx,
            y1 * factor + dy,
            x2 * factor + dx,
            y2 * factor + dy,
        ]

    def get_box_center(self, box: List[float]) -> Tuple[float, float]:
        """Return the (x, y) centre of the box."""
        center_x = (box[0] + box[2]) / 2
        center_y = (box[1] + box[3]) / 2
        return center_x, center_y

    def get_box_bottom_center(self, box: List[float]) -> Tuple[float, float]:
        """Return the midpoint of the box's bottom edge."""
        center_x = (box[0] + box[2]) / 2
        return center_x, box[3]

    def get_box_area(self, box: List[float]) -> float:
        """Return width * height of the box."""
        return (box[2] - box[0]) * (box[3] - box[1])
class ROIAnalyzer:
    """Decides whether detections fall inside a configured ROI."""

    def __init__(self):
        self._logger = get_logger("postprocessor")

    def is_point_in_roi(
        self,
        point: Tuple[float, float],
        roi: ROIInfo
    ) -> bool:
        """Return whether ``point`` (x, y) lies inside ``roi``.

        The geometric test itself is delegated to ``roi.is_point_inside``.
        """
        return roi.is_point_inside(list(point))

    def is_detection_in_roi(
        self,
        box: List[float],
        roi: ROIInfo,
        check_mode: str = "center"
    ) -> bool:
        """Return whether a detection box counts as being inside ``roi``.

        Args:
            box: [x1, y1, x2, y2].
            roi: ROI configuration.
            check_mode: "bottom" tests the midpoint of the bottom edge. Every
                other value — including "center" and "any" — tests the box
                centre; no separate "any" behaviour is implemented.

        Returns:
            True when the chosen reference point is inside the ROI.
        """
        if check_mode == "bottom":
            point = self._get_box_bottom_center(box)
        else:
            point = self._get_box_center(box)
        return self.is_point_in_roi(point, roi)

    def _get_box_center(self, box: List[float]) -> Tuple[float, float]:
        """Return the (x, y) centre of the box."""
        return (box[0] + box[2]) / 2, (box[1] + box[3]) / 2

    def _get_box_bottom_center(self, box: List[float]) -> Tuple[float, float]:
        """Return the midpoint of the box's bottom edge."""
        return (box[0] + box[2]) / 2, box[3]

    def filter_detections_in_roi(
        self,
        boxes: np.ndarray,
        roi: ROIInfo,
        check_mode: str = "center"
    ) -> np.ndarray:
        """Return the indices of the boxes that lie inside ``roi``.

        Args:
            boxes: Array of [x1, y1, x2, y2] rows.
            roi: ROI configuration.
            check_mode: See ``is_detection_in_roi``.

        Returns:
            Integer index array, safe to use for NumPy indexing even when no
            box matches.
        """
        valid_indices = [
            i for i, box in enumerate(boxes)
            if self.is_detection_in_roi(box.tolist(), roi, check_mode)
        ]
        # An explicit integer dtype is required: np.array([]) defaults to
        # float64, and indexing with a float array raises IndexError.
        return np.array(valid_indices, dtype=np.intp)

    def get_roi_crop_info(
        self,
        frame_width: int,
        frame_height: int,
        roi: ROIInfo
    ) -> Dict[str, Any]:
        """Describe the bounding rectangle that encloses ``roi``.

        Args:
            frame_width: Frame width.
            frame_height: Frame height.
            roi: ROI configuration.

        Returns:
            Dict with "offset_x", "offset_y", "crop_width", "crop_height",
            "frame_width" and "frame_height". A rectangle uses its first two
            coordinates as opposite corners, a polygon uses the bounding box of
            its vertices, and anything else (including a polygon without
            vertices) falls back to the whole frame.
        """
        if roi.roi_type == ROIType.RECTANGLE and len(roi.coordinates) >= 2:
            x1 = int(roi.coordinates[0][0])
            y1 = int(roi.coordinates[0][1])
            x2 = int(roi.coordinates[1][0])
            y2 = int(roi.coordinates[1][1])
            offset_x = min(x1, x2)
            offset_y = min(y1, y2)
            crop_width = abs(x2 - x1)
            crop_height = abs(y2 - y1)
        elif roi.roi_type == ROIType.POLYGON and len(roi.coordinates) > 0:
            # The length guard keeps min()/max() from raising on an empty polygon.
            xs = [c[0] for c in roi.coordinates]
            ys = [c[1] for c in roi.coordinates]
            offset_x = min(xs)
            offset_y = min(ys)
            crop_width = max(xs) - offset_x
            crop_height = max(ys) - offset_y
        else:
            offset_x = offset_y = 0
            crop_width = frame_width
            crop_height = frame_height
        return {
            "offset_x": offset_x,
            "offset_y": offset_y,
            "crop_width": crop_width,
            "crop_height": crop_height,
            "frame_width": frame_width,
            "frame_height": frame_height,
        }
@dataclass
class AlarmState:
    """Mutable per-ROI bookkeeping used by the alarm state machine."""

    # Identifier of the ROI this record belongs to.
    roi_id: str
    # Running detection counter.
    detection_count: int = 0
    # Timestamp of the most recent detection, or None if there was none yet.
    last_detection_time: Optional[datetime] = None
    # True while the ROI is in the alarming state.
    is_alarming: bool = False
    # True once an alert has been marked as sent for the current alarm.
    alert_sent: bool = False
    # Recent detection timestamps; each instance gets its own deque.
    frame_buffer: deque = field(default_factory=deque)
class AlarmStateMachine:
    """Per-ROI alarm state machine.

    An alert is raised only after the detection counter reaches
    ``alert_threshold``, which filters out momentary false positives. All
    access to the state table is serialised with a re-entrant lock.
    """

    def __init__(self, alert_threshold: int = 3, alert_cooldown: int = 300):
        """
        Args:
            alert_threshold: Detection count required to raise an alert.
            alert_cooldown: Cool-down after a sent alert, in seconds.
        """
        self.alert_threshold = alert_threshold
        self.alert_cooldown = alert_cooldown
        self._states: Dict[str, AlarmState] = {}
        # Re-entrant because update() calls get_state() while holding the lock.
        self._lock = threading.RLock()
        self._logger = get_logger("postprocessor")

    def get_state(self, roi_id: str) -> AlarmState:
        """Return the state record for ``roi_id``, creating it on first use."""
        with self._lock:
            if roi_id not in self._states:
                self._states[roi_id] = AlarmState(roi_id=roi_id)
            return self._states[roi_id]

    def update(
        self,
        roi_id: str,
        has_target: bool,
        current_time: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Feed one frame's observation into the state machine.

        Args:
            roi_id: ROI identifier.
            has_target: Whether a target was detected in this frame.
            current_time: Observation time; defaults to ``datetime.now()``.

        Returns:
            Dict with "should_alert", "reason" ("cooldown",
            "threshold_reached" or "waiting") and "state". "cooldown" results
            also carry "cooldown_remaining"; the other two carry
            "detection_count".
        """
        if current_time is None:
            current_time = datetime.now()
        with self._lock:
            state = self.get_state(roi_id)

            # Cool-down only applies once some caller has flagged the alert as
            # sent; this class itself never sets `alert_sent` to True.
            if state.is_alarming and state.alert_sent:
                if state.last_detection_time:
                    cooldown_elapsed = (
                        current_time - state.last_detection_time
                    ).total_seconds()
                else:
                    cooldown_elapsed = float('inf')
                if cooldown_elapsed < self.alert_cooldown:
                    return {
                        "should_alert": False,
                        "reason": "cooldown",
                        "cooldown_remaining": self.alert_cooldown - cooldown_elapsed,
                        "state": state,
                    }
                # Cool-down expired: start counting from scratch.
                state.alert_sent = False
                state.is_alarming = False
                state.detection_count = 0

            if has_target:
                state.detection_count += 1
                state.last_detection_time = current_time
                # Keep at most `alert_threshold` timestamps. The emptiness
                # check prevents popleft() from raising on an empty deque when
                # the threshold is zero or negative.
                while (state.frame_buffer and
                       len(state.frame_buffer) >= self.alert_threshold):
                    state.frame_buffer.popleft()
                state.frame_buffer.append(current_time)
                if (state.detection_count >= self.alert_threshold and
                        not state.is_alarming):
                    state.is_alarming = True
                    self._logger.log_alert(
                        "detection_threshold_reached",
                        "",
                        roi_id,
                        None
                    )
                    return {
                        "should_alert": True,
                        "reason": "threshold_reached",
                        "detection_count": state.detection_count,
                        "state": state,
                    }
            else:
                # A miss decays the counter by one and clears the alarm flag.
                if state.detection_count > 0:
                    state.detection_count = max(0, state.detection_count - 1)
                if state.is_alarming:
                    state.is_alarming = False

            return {
                "should_alert": False,
                "reason": "waiting",
                "detection_count": state.detection_count,
                "state": state,
            }

    def reset(self, roi_id: str):
        """Replace the state of ``roi_id`` with a fresh record, if it exists."""
        with self._lock:
            if roi_id in self._states:
                self._states[roi_id] = AlarmState(roi_id=roi_id)
                self._logger.info(f"告警状态已重置: {roi_id}")

    def reset_all(self):
        """Forget the state of every ROI."""
        with self._lock:
            self._states.clear()
            self._logger.info("所有告警状态已重置")

    def get_all_states(self) -> Dict[str, AlarmState]:
        """Return a shallow copy of the ROI-id -> state mapping."""
        with self._lock:
            return dict(self._states)
class PostProcessor:
    """Post-processing facade.

    Combines model-output parsing, NMS, coordinate mapping, ROI filtering and
    the alarm state machine.
    """

    # Rows expected in one model output: 4 box values (x, y, w, h) followed by
    # 80 per-class scores; there is no separate objectness row.
    _OUTPUT_ROWS = 84
    # Upper bound of the cheap score pre-filter applied before NMS.
    _PREFILTER_THRESHOLD = 0.3

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Args:
            config: Optional dict with "nms_threshold", "conf_threshold",
                "alert_threshold" and "alert_cooldown". When omitted, the two
                inference thresholds are read from the application settings.
        """
        if config is None:
            settings = get_settings()
            config = {
                "nms_threshold": settings.inference.nms_threshold,
                "conf_threshold": settings.inference.conf_threshold,
                "alert_threshold": 3,
                "alert_cooldown": 300,
            }
        self.nms_threshold = config.get("nms_threshold", 0.5)
        self.conf_threshold = config.get("conf_threshold", 0.45)
        self._nms = NMSProcessor(self.nms_threshold, use_gpu=True)
        self._mapper = CoordinateMapper()
        self._roi_analyzer = ROIAnalyzer()
        self._alarm_state_machine = AlarmStateMachine(
            alert_threshold=config.get("alert_threshold", 3),
            alert_cooldown=config.get("alert_cooldown", 300)
        )
        self._logger = get_logger("postprocessor")
        self._logger.info(
            f"后处理器初始化完成: NMS阈值={self.nms_threshold}, "
            f"置信度阈值={self.conf_threshold}"
        )

    def process_detections(
        self,
        raw_outputs: List[np.ndarray],
        conf_threshold: Optional[float] = None,
        nms_threshold: Optional[float] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Turn one raw model output into filtered detections.

        Args:
            raw_outputs: Model outputs; only the first element is used.
            conf_threshold: Confidence threshold (defaults to the instance value).
            nms_threshold: NMS IoU threshold (defaults to the instance value).

        Returns:
            tuple: (boxes as xyxy, scores, class ids).
        """
        if conf_threshold is None:
            conf_threshold = self.conf_threshold
        if nms_threshold is None:
            nms_threshold = self.nms_threshold
        # The pre-filter must never be stricter than the requested confidence
        # threshold, otherwise thresholds below 0.3 would silently be ignored.
        boxes, scores, class_ids = self._parse_yolo_output(
            raw_outputs,
            prefilter_threshold=min(self._PREFILTER_THRESHOLD, conf_threshold)
        )
        if len(boxes) == 0:
            return np.array([]), np.array([]), np.array([])
        nms_processor = NMSProcessor(nms_threshold, use_gpu=True)
        return nms_processor.process_with_confidence_filter(
            boxes, scores, class_ids, conf_threshold
        )

    def _parse_yolo_output(
        self,
        outputs: List[np.ndarray],
        prefilter_threshold: float = 0.3
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Decode a [84, num_anchors] output into candidate boxes.

        Args:
            outputs: Model outputs; only ``outputs[0]`` is read. A leading
                batch axis of a 3-D array is stripped by taking item 0.
            prefilter_threshold: Candidates whose best class score is below
                this value are dropped to reduce the NMS workload.

        Returns:
            tuple: (float32 xyxy boxes [K, 4], float32 scores [K], int32 class
            ids [K]). Three empty arrays are returned when the input is
            missing, has an unexpected shape, or nothing passes the pre-filter.
        """
        if not outputs:
            return np.array([]), np.array([]), np.array([])
        output = outputs[0]
        if output.ndim == 3:
            output = output[0]
        if output.ndim != 2 or output.shape[0] != self._OUTPUT_ROWS:
            return np.array([]), np.array([]), np.array([])

        class_scores = output[4:, :]
        scores = class_scores.max(axis=0)
        class_ids = class_scores.argmax(axis=0)
        # `>=` mirrors the inclusive comparison of the final confidence filter.
        coarse_mask = scores >= prefilter_threshold
        if not np.any(coarse_mask):
            return np.array([]), np.array([]), np.array([])

        boxes_xywh = output[0:4, :].T[coarse_mask]
        half_w = boxes_xywh[:, 2] / 2
        half_h = boxes_xywh[:, 3] / 2
        # A per-call array (instead of a fixed-size shared scratch buffer)
        # copes with any number of candidates and keeps calls independent.
        boxes_xyxy = np.empty((len(boxes_xywh), 4), dtype=np.float32)
        boxes_xyxy[:, 0] = boxes_xywh[:, 0] - half_w
        boxes_xyxy[:, 1] = boxes_xywh[:, 1] - half_h
        boxes_xyxy[:, 2] = boxes_xywh[:, 0] + half_w
        boxes_xyxy[:, 3] = boxes_xywh[:, 1] + half_h
        return (
            boxes_xyxy,
            scores[coarse_mask].astype(np.float32),
            class_ids[coarse_mask].astype(np.int32)
        )

    def _normalize_batch_outputs(
        self,
        batch_outputs: List[np.ndarray],
        batch_size: int
    ) -> np.ndarray:
        """Bring the supported batch-output layouts into a single ndarray."""
        if len(batch_outputs) != 1:
            return np.array(batch_outputs)
        first_output = batch_outputs[0]
        if not isinstance(first_output, np.ndarray):
            return np.array([first_output])
        if first_output.ndim == 3 and first_output.shape[0] == batch_size:
            # Already (batch, rows, anchors).
            return first_output
        if first_output.ndim == 1:
            # Flat buffer: reshape to (batch, rows, anchors).
            per_image = first_output.shape[0] // batch_size
            num_anchors = per_image // self._OUTPUT_ROWS
            return first_output.reshape(batch_size, self._OUTPUT_ROWS, num_anchors)
        if first_output.ndim == 2:
            # Output of a single image: (rows, anchors).
            return first_output.reshape(1, first_output.shape[0], first_output.shape[1])
        return first_output

    def batch_process_detections(
        self,
        batch_outputs: List[np.ndarray],
        batch_size: int,
        conf_threshold: Optional[float] = None,
        nms_threshold: Optional[float] = None,
        per_item_conf_thresholds: Optional[List[float]] = None,
    ) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """Post-process a batch of ROI inferences.

        Args:
            batch_outputs: Model output, either ``[array]`` or one entry per item.
            batch_size: Number of items in the batch.
            conf_threshold: Confidence threshold (defaults to the instance value).
            nms_threshold: NMS IoU threshold (defaults to the instance value).
            per_item_conf_thresholds: One confidence threshold per batch item;
                must provide at least ``batch_size`` entries. When given,
                ``conf_threshold`` is ignored.

        Returns:
            One (boxes as xyxy, scores, class ids) tuple per batch item.
        """
        default_conf = conf_threshold if conf_threshold is not None else self.conf_threshold
        if nms_threshold is None:
            nms_threshold = self.nms_threshold
        outputs_array = self._normalize_batch_outputs(batch_outputs, batch_size)
        nms_processor = NMSProcessor(nms_threshold, use_gpu=True)

        results = []
        for batch_idx in range(batch_size):
            if outputs_array.ndim == 2:
                output = outputs_array
            elif outputs_array.ndim == 3:
                if outputs_array.shape[0] == batch_size:
                    output = outputs_array[batch_idx]
                else:
                    output = outputs_array[0]
            elif batch_idx < len(outputs_array):
                output = outputs_array[batch_idx]
            else:
                output = outputs_array[0]

            item_conf = (
                per_item_conf_thresholds[batch_idx]
                if per_item_conf_thresholds is not None
                else default_conf
            )
            # Same decoding as the single-image path; the pre-filter is capped
            # by this item's own threshold so low thresholds take effect.
            boxes, scores, class_ids = self._parse_yolo_output(
                [output],
                prefilter_threshold=min(self._PREFILTER_THRESHOLD, item_conf)
            )
            if len(boxes) == 0:
                results.append((np.array([]), np.array([]), np.array([])))
                continue
            results.append(
                nms_processor.process_with_confidence_filter(
                    boxes, scores, class_ids, item_conf
                )
            )
        return results

    def filter_by_roi(
        self,
        boxes: np.ndarray,
        scores: np.ndarray,
        class_ids: np.ndarray,
        roi: ROIInfo,
        check_mode: str = "center"
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Keep only the detections that lie inside ``roi``.

        Args:
            boxes: Detection boxes.
            scores: Confidences.
            class_ids: Class ids.
            roi: ROI configuration.
            check_mode: Reference-point mode passed to the ROI analyzer.

        Returns:
            tuple: (boxes, scores, class ids) restricted to the ROI.
        """
        # Force an integer dtype: an empty index array may arrive as float64,
        # which NumPy refuses to use for indexing.
        valid_indices = np.asarray(
            self._roi_analyzer.filter_detections_in_roi(boxes, roi, check_mode),
            dtype=np.intp
        )
        return (
            boxes[valid_indices],
            scores[valid_indices],
            class_ids[valid_indices]
        )

    def map_coordinates(
        self,
        boxes: np.ndarray,
        scale_info: Tuple[float, float, float, float],
        original_size: Tuple[int, int],
        roi_crop_info: Optional[Dict[str, Any]] = None
    ) -> List[List[float]]:
        """Map boxes from letterbox space back to the original frame.

        Args:
            boxes: Detection boxes.
            scale_info: Letterbox scale/padding information.
            original_size: Size of the un-letterboxed image.
            roi_crop_info: When truthy, boxes are additionally shifted from
                ROI space into frame space.

        Returns:
            List of [x1, y1, x2, y2] lists.
        """
        mapped_boxes = []
        for box in boxes:
            mapped = self._mapper.map_from_letterbox(
                box.tolist(), scale_info, original_size
            )
            if roi_crop_info:
                mapped = self._mapper.map_roi_to_frame(mapped, roi_crop_info)
            mapped_boxes.append(mapped)
        return mapped_boxes

    def check_alarm_condition(
        self,
        roi_id: str,
        has_target: bool,
        current_time: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """Forward one observation to the alarm state machine and return its result."""
        return self._alarm_state_machine.update(roi_id, has_target, current_time)

    def get_alarm_state(self, roi_id: str) -> AlarmState:
        """Return the alarm state of ``roi_id``."""
        return self._alarm_state_machine.get_state(roi_id)

    def reset_alarm_state(self, roi_id: str):
        """Reset the alarm state of ``roi_id``."""
        self._alarm_state_machine.reset(roi_id)

    def create_alert_info(
        self,
        roi_id: str,
        camera_id: str,
        detection_results: Dict[str, Any],
        message: str
    ) -> AlertInfo:
        """Build an ``AlertInfo`` record for a detection.

        Args:
            roi_id: ROI identifier.
            camera_id: Camera identifier.
            detection_results: Reads "alert_type" (default "detection"),
                "class_name", "confidence" and "bbox".
            message: Alert message.

        Returns:
            AlertInfo with a generated id, MEDIUM level and the current local
            time as an ISO-8601 timestamp.
        """
        return AlertInfo(
            alert_id=generate_unique_id("alert"),
            camera_id=camera_id,
            roi_id=roi_id,
            alert_type=detection_results.get("alert_type", "detection"),
            target_class=detection_results.get("class_name"),
            confidence=detection_results.get("confidence"),
            bbox=detection_results.get("bbox"),
            message=message,
            level=AlertLevel.MEDIUM,
            timestamp=datetime.now().isoformat(),
        )

    def get_statistics(self) -> Dict[str, Any]:
        """Return the configured thresholds and a summary of every alarm state."""
        return {
            "nms_threshold": self.nms_threshold,
            "conf_threshold": self.conf_threshold,
            "alarm_states": {
                roi_id: {
                    "detection_count": state.detection_count,
                    "is_alarming": state.is_alarming,
                    "alert_sent": state.alert_sent,
                }
                for roi_id, state in self._alarm_state_machine.get_all_states().items()
            }
        }