""" 后处理与业务逻辑模块 实现NMS算法、坐标映射、ROI区域判定、告警状态机 """ import logging import threading import time from collections import deque from dataclasses import dataclass, field from datetime import datetime from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np from config.settings import get_settings from config.config_models import ROIInfo, ROIType, AlertInfo, AlertLevel from utils.logger import get_logger from utils.common import generate_unique_id logger = logging.getLogger(__name__) class NMSProcessor: """非极大值抑制处理器 实现高效的NMS算法去除冗余检测框 """ def __init__(self, nms_threshold: float = 0.45): """ 初始化NMS处理器 Args: nms_threshold: NMS阈值 """ self.nms_threshold = nms_threshold self._logger = get_logger("postprocessor") def process( self, boxes: np.ndarray, scores: np.ndarray, class_ids: Optional[np.ndarray] = None, max_output_size: int = 300 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ 执行NMS Args: boxes: 检测框数组 [N, 4] (x1, y1, x2, y2) scores: 置信度数组 [N] class_ids: 类别ID数组 [N] max_output_size: 最大输出数量 Returns: tuple: (保留的框索引, 保留的置信度, 保留的类别ID) """ if len(boxes) == 0: return np.array([], dtype=np.int32), np.array([]), np.array([]) x1 = boxes[:, 0] y1 = boxes[:, 1] x2 = boxes[:, 2] y2 = boxes[:, 3] areas = (x2 - x1 + 1) * (y2 - y1 + 1) order = scores.argsort()[::-1] keep_indices = [] while len(order) > 0: if len(keep_indices) >= max_output_size: break i = order[0] keep_indices.append(i) if len(order) == 1: break remaining = order[1:] xx1 = np.maximum(x1[i], x1[remaining]) yy1 = np.maximum(y1[i], y1[remaining]) xx2 = np.minimum(x2[i], x2[remaining]) yy2 = np.minimum(y2[i], y2[remaining]) w = np.maximum(0.0, xx2 - xx1 + 1) h = np.maximum(0.0, yy2 - yy1 + 1) inter = w * h ovr = inter / (areas[i] + areas[remaining] - inter) indices = np.where(ovr <= self.nms_threshold)[0] order = remaining[indices] keep_indices = np.array(keep_indices, dtype=np.int32) return ( keep_indices, scores[keep_indices], class_ids[keep_indices] if class_ids is not None else np.array([]) ) def process_with_confidence_filter( self, boxes: np.ndarray, scores: np.ndarray, class_ids: np.ndarray, conf_threshold: float = 0.5, max_output_size: int = 300 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ 执行置信度过滤和NMS Args: boxes: 检测框数组 scores: 置信度数组 class_ids: 类别ID数组 conf_threshold: 置信度阈值 max_output_size: 最大输出数量 Returns: tuple: (保留的框, 置信度, 类别ID) """ mask = scores >= conf_threshold boxes = boxes[mask] scores = scores[mask] class_ids = class_ids[mask] if len(boxes) == 0: return np.array([]).reshape(0, 4), np.array([]), np.array([]) keep_indices, filtered_scores, filtered_class_ids = self.process( boxes, scores, class_ids, max_output_size ) return boxes[keep_indices], filtered_scores, filtered_class_ids class CoordinateMapper: """坐标映射系统 实现多级坐标映射: 480x480 → ROI空间 → 原始帧坐标 """ def __init__(self): self._logger = get_logger("postprocessor") def map_from_letterbox( self, box: List[float], scale_info: Tuple[float, float, float, float], original_size: Tuple[int, int] ) -> List[float]: """ 将检测框从Letterbox空间映射回原始空间 Args: box: Letterbox空间中的坐标 [x1, y1, x2, y2] scale_info: 缩放信息 (scale, pad_x, pad_y, scale) original_size: 原始图像尺寸 (width, height) Returns: 原始空间中的坐标 """ scale, pad_x, pad_y, _ = scale_info x1 = (box[0] - pad_x) / scale y1 = (box[1] - pad_y) / scale x2 = (box[2] - pad_x) / scale y2 = (box[3] - pad_y) / scale orig_width, orig_height = original_size x1 = max(0, min(x1, orig_width - 1)) y1 = max(0, min(y1, orig_height - 1)) x2 = max(0, min(x2, orig_width)) y2 = max(0, min(y2, orig_height)) return [x1, y1, x2, y2] def map_roi_to_frame( self, box: List[float], roi_crop_info: Dict[str, Any] ) -> List[float]: """ 将ROI空间中的坐标映射回原始帧坐标 Args: box: ROI空间中的坐标 roi_crop_info: ROI裁剪信息 Returns: 原始帧中的坐标 """ x1, y1, x2, y2 = box offset_x = roi_crop_info.get("offset_x", 0) offset_y = roi_crop_info.get("offset_y", 0) if "scale" in roi_crop_info: scale = roi_crop_info["scale"] x1 = x1 * scale + offset_x y1 = y1 * scale + offset_y x2 = x2 * scale + offset_x y2 = y2 * scale + offset_y else: x1 += offset_x y1 += offset_y x2 += offset_x y2 += offset_y return [x1, y1, x2, y2] def get_box_center(self, box: List[float]) -> Tuple[float, float]: """获取检测框中心点""" return (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 def get_box_bottom_center(self, box: List[float]) -> Tuple[float, float]: """获取检测框底部中心点""" return (box[0] + box[2]) / 2, box[3] def get_box_area(self, box: List[float]) -> float: """计算检测框面积""" width = box[2] - box[0] height = box[3] - box[1] return width * height class ROIAnalyzer: """ROI区域分析器 实现ROI区域内目标检测与判定 """ def __init__(self): self._logger = get_logger("postprocessor") def is_point_in_roi( self, point: Tuple[float, float], roi: ROIInfo ) -> bool: """ 判断点是否在ROI区域内 Args: point: (x, y) 坐标 roi: ROI配置 Returns: 是否在区域内 """ return roi.is_point_inside(list(point)) def is_detection_in_roi( self, box: List[float], roi: ROIInfo, check_mode: str = "center" ) -> bool: """ 判断检测是否在ROI区域内 Args: box: 检测框坐标 roi: ROI配置 check_mode: 检查模式 (center, bottom, any) Returns: 是否在区域内 """ if check_mode == "center": point = self._get_box_center(box) elif check_mode == "bottom": point = self._get_box_bottom_center(box) else: point = self._get_box_center(box) return self.is_point_in_roi(point, roi) def _get_box_center(self, box: List[float]) -> Tuple[float, float]: """获取检测框中心点""" return (box[0] + box[2]) / 2, (box[1] + box[3]) / 2 def _get_box_bottom_center(self, box: List[float]) -> Tuple[float, float]: """获取检测框底部中心点""" return (box[0] + box[2]) / 2, box[3] def filter_detections_in_roi( self, boxes: np.ndarray, roi: ROIInfo, check_mode: str = "center" ) -> np.ndarray: """ 筛选在ROI区域内的检测 Args: boxes: 检测框数组 roi: ROI配置 check_mode: 检查模式 Returns: 符合条件的检测框索引 """ valid_indices = [] for i, box in enumerate(boxes): if self.is_detection_in_roi(box.tolist(), roi, check_mode): valid_indices.append(i) return np.array(valid_indices) def get_roi_crop_info( self, frame_width: int, frame_height: int, roi: ROIInfo ) -> Dict[str, Any]: """ 获取ROI裁剪信息 Args: frame_width: 帧宽度 frame_height: 帧高度 roi: ROI配置 Returns: 裁剪信息字典 """ if roi.roi_type == ROIType.RECTANGLE and len(roi.coordinates) >= 2: x1 = int(roi.coordinates[0][0]) y1 = int(roi.coordinates[0][1]) x2 = int(roi.coordinates[1][0]) y2 = int(roi.coordinates[1][1]) offset_x = min(x1, x2) offset_y = min(y1, y2) crop_width = abs(x2 - x1) crop_height = abs(y2 - y1) elif roi.roi_type == ROIType.POLYGON: xs = [c[0] for c in roi.coordinates] ys = [c[1] for c in roi.coordinates] offset_x = min(xs) offset_y = min(ys) crop_width = max(xs) - offset_x crop_height = max(ys) - offset_y else: offset_x = offset_y = 0 crop_width = frame_width crop_height = frame_height return { "offset_x": offset_x, "offset_y": offset_y, "crop_width": crop_width, "crop_height": crop_height, "frame_width": frame_width, "frame_height": frame_height, } @dataclass class AlarmState: """告警状态""" roi_id: str detection_count: int = 0 last_detection_time: Optional[datetime] = None is_alarming: bool = False alert_sent: bool = False frame_buffer: deque = field(default_factory=deque) class AlarmStateMachine: """告警状态机 实现连续帧检测触发机制,防止瞬时误报 """ def __init__(self, alert_threshold: int = 3, alert_cooldown: int = 300): """ 初始化告警状态机 Args: alert_threshold: 连续检测触发阈值 alert_cooldown: 告警冷却时间(秒) """ self.alert_threshold = alert_threshold self.alert_cooldown = alert_cooldown self._states: Dict[str, AlarmState] = {} self._lock = threading.Lock() self._logger = get_logger("postprocessor") def get_state(self, roi_id: str) -> AlarmState: """获取ROI告警状态""" with self._lock: if roi_id not in self._states: self._states[roi_id] = AlarmState(roi_id=roi_id) return self._states[roi_id] def update( self, roi_id: str, has_target: bool, current_time: Optional[datetime] = None ) -> Dict[str, Any]: """ 更新告警状态 Args: roi_id: ROI标识 has_target: 是否有目标 current_time: 当前时间 Returns: 状态更新结果 """ if current_time is None: current_time = datetime.now() with self._lock: state = self.get_state(roi_id) if state.is_alarming and state.alert_sent: cooldown_elapsed = ( current_time - state.last_detection_time ).total_seconds() if state.last_detection_time else float('inf') if cooldown_elapsed < self.alert_cooldown: return { "should_alert": False, "reason": "cooldown", "cooldown_remaining": self.alert_cooldown - cooldown_elapsed, "state": state, } else: state.alert_sent = False state.is_alarming = False state.detection_count = 0 if has_target: state.detection_count += 1 state.last_detection_time = current_time while len(state.frame_buffer) >= self.alert_threshold: state.frame_buffer.popleft() state.frame_buffer.append(current_time) if (state.detection_count >= self.alert_threshold and not state.is_alarming): state.is_alarming = True self._logger.log_alert( "detection_threshold_reached", roi_id=roi_id, camera_id="", confidence=None ) return { "should_alert": True, "reason": "threshold_reached", "detection_count": state.detection_count, "state": state, } else: if state.detection_count > 0: state.detection_count = max(0, state.detection_count - 1) if not has_target and state.is_alarming: state.is_alarming = False return { "should_alert": False, "reason": "waiting", "detection_count": state.detection_count, "state": state, } def reset(self, roi_id: str): """重置ROI告警状态""" with self._lock: if roi_id in self._states: self._states[roi_id] = AlarmState(roi_id=roi_id) self._logger.info(f"告警状态已重置: {roi_id}") def reset_all(self): """重置所有告警状态""" with self._lock: self._states.clear() self._logger.info("所有告警状态已重置") def get_all_states(self) -> Dict[str, AlarmState]: """获取所有状态""" with self._lock: return dict(self._states) class PostProcessor: """后处理主类 整合NMS、坐标映射、ROI分析、告警状态机 """ def __init__(self, config: Optional[Dict[str, Any]] = None): """ 初始化后处理器 Args: config: 配置参数字典 """ if config is None: settings = get_settings() config = { "nms_threshold": settings.inference.nms_threshold, "conf_threshold": settings.inference.conf_threshold, "alert_threshold": 3, "alert_cooldown": 300, } self.nms_threshold = config.get("nms_threshold", 0.45) self.conf_threshold = config.get("conf_threshold", 0.5) self._nms = NMSProcessor(self.nms_threshold) self._mapper = CoordinateMapper() self._roi_analyzer = ROIAnalyzer() self._alarm_state_machine = AlarmStateMachine( alert_threshold=config.get("alert_threshold", 3), alert_cooldown=config.get("alert_cooldown", 300) ) self._logger = get_logger("postprocessor") self._logger.info( f"后处理器初始化完成: NMS阈值={self.nms_threshold}, " f"置信度阈值={self.conf_threshold}" ) def process_detections( self, raw_outputs: List[np.ndarray], conf_threshold: Optional[float] = None, nms_threshold: Optional[float] = None ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ 处理模型原始输出 Args: raw_outputs: 模型原始输出 conf_threshold: 置信度阈值 nms_threshold: NMS阈值 Returns: tuple: (检测框, 置信度, 类别ID) """ if conf_threshold is None: conf_threshold = self.conf_threshold if nms_threshold is None: nms_threshold = self.nms_threshold boxes, scores, class_ids = self._parse_yolo_output(raw_outputs) if len(boxes) == 0: return np.array([]), np.array([]), np.array([]) nms_processor = NMSProcessor(nms_threshold) keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter( boxes, scores, class_ids, conf_threshold ) return keep_boxes, keep_scores, keep_classes def _parse_yolo_output( self, outputs: List[np.ndarray] ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ 解析YOLO模型输出 Args: outputs: 模型输出列表 Returns: tuple: (检测框, 置信度, 类别ID) """ if not outputs: return np.array([]), np.array([]), np.array([]) output = outputs[0] if output.ndim == 3: output = output[0] if output.ndim != 2: return np.array([]), np.array([]), np.array([]) if output.shape[0] != 84: return np.array([]), np.array([]), np.array([]) num_boxes = output.shape[1] boxes_xywh = output[0:4, :].T obj_conf = output[4, :] cls_scores = output[5:, :] cls_ids = np.argmax(cls_scores, axis=0) cls_conf = cls_scores[cls_ids, np.arange(num_boxes)] scores = obj_conf * cls_conf valid_mask = scores > self._conf_threshold if not np.any(valid_mask): return np.array([]), np.array([]), np.array([]) boxes = boxes_xywh[valid_mask] scores = scores[valid_mask] class_ids = cls_ids[valid_mask] boxes_xyxy = np.zeros_like(boxes) boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2] / 2 boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3] / 2 boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2] / 2 boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3] / 2 return ( boxes_xyxy.astype(np.float32), scores.astype(np.float32), class_ids.astype(np.int32) ) def filter_by_roi( self, boxes: np.ndarray, scores: np.ndarray, class_ids: np.ndarray, roi: ROIInfo, check_mode: str = "center" ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ 按ROI筛选检测结果 Args: boxes: 检测框 scores: 置信度 class_ids: 类别ID roi: ROI配置 check_mode: 检查模式 Returns: 筛选后的结果 """ valid_indices = self._roi_analyzer.filter_detections_in_roi( boxes, roi, check_mode ) return ( boxes[valid_indices], scores[valid_indices], class_ids[valid_indices] ) def map_coordinates( self, boxes: np.ndarray, scale_info: Tuple[float, float, float, float], original_size: Tuple[int, int], roi_crop_info: Optional[Dict[str, Any]] = None ) -> List[List[float]]: """ 映射检测框坐标到原始帧 Args: boxes: 检测框数组 scale_info: Letterbox缩放信息 original_size: 原始图像尺寸 roi_crop_info: ROI裁剪信息 Returns: 映射后的坐标列表 """ mapped_boxes = [] for box in boxes: if roi_crop_info: mapped = self._mapper.map_from_letterbox( box.tolist(), scale_info, original_size ) mapped = self._mapper.map_roi_to_frame(mapped, roi_crop_info) else: mapped = self._mapper.map_from_letterbox( box.tolist(), scale_info, original_size ) mapped_boxes.append(mapped) return mapped_boxes def check_alarm_condition( self, roi_id: str, has_target: bool, current_time: Optional[datetime] = None ) -> Dict[str, Any]: """ 检查告警条件 Args: roi_id: ROI标识 has_target: 是否有目标 current_time: 当前时间 Returns: 告警检查结果 """ return self._alarm_state_machine.update(roi_id, has_target, current_time) def get_alarm_state(self, roi_id: str) -> AlarmState: """获取告警状态""" return self._alarm_state_machine.get_state(roi_id) def reset_alarm_state(self, roi_id: str): """重置告警状态""" self._alarm_state_machine.reset(roi_id) def create_alert_info( self, roi_id: str, camera_id: str, detection_results: Dict[str, Any], message: str ) -> AlertInfo: """ 创建告警信息 Args: roi_id: ROI标识 camera_id: 摄像头标识 detection_results: 检测结果 message: 告警消息 Returns: AlertInfo实例 """ return AlertInfo( alert_id=generate_unique_id("alert"), camera_id=camera_id, roi_id=roi_id, alert_type=detection_results.get("alert_type", "detection"), target_class=detection_results.get("class_name"), confidence=detection_results.get("confidence"), bbox=detection_results.get("bbox"), message=message, level=AlertLevel.MEDIUM, timestamp=datetime.now().isoformat(), ) def get_statistics(self) -> Dict[str, Any]: """获取统计信息""" return { "nms_threshold": self.nms_threshold, "conf_threshold": self.conf_threshold, "alarm_states": { roi_id: { "detection_count": state.detection_count, "is_alarming": state.is_alarming, "alert_sent": state.alert_sent, } for roi_id, state in self._alarm_state_machine.get_all_states().items() } }