diff --git a/config.yaml b/config.yaml index 1c40cfc..c25ba71 100644 --- a/config.yaml +++ b/config.yaml @@ -70,8 +70,7 @@ algorithms: confirm_sec: 10 # 上岗确认时间(10秒) return_sec: 30 # 离岗缓冲时间(30秒) intrusion: - check_interval_sec: 1.0 # 检测间隔 - direction_sensitive: false # 方向敏感 + cooldown_seconds: 300 # 入侵检测冷却时间(秒) # 日志配置 logging: diff --git a/inference/roi/cache_manager.py b/inference/roi/cache_manager.py new file mode 100644 index 0000000..2db2358 --- /dev/null +++ b/inference/roi/cache_manager.py @@ -0,0 +1,167 @@ +import json +import threading +import time +from typing import Dict, List, Optional, Callable +from collections import deque + + +class ROICacheManager: + _instance = None + _lock = threading.Lock() + + def __new__(cls): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + self._initialized = True + + self._cache: Dict[int, List[Dict]] = {} + self._cache_timestamps: Dict[int, float] = {} + self._refresh_interval = 10.0 + self._db_session_factory = None + self._refresh_thread: Optional[threading.Thread] = None + self._stop_event = threading.Event() + self._last_refresh_time = 0 + self._on_cache_update: Optional[Callable[[int], None]] = None + self._update_callbacks: Dict[int, List[Callable]] = {} + + def initialize(self, session_factory, refresh_interval: float = 10.0): + self._db_session_factory = session_factory + self._refresh_interval = refresh_interval + + def start_background_refresh(self): + if self._refresh_thread is not None and self._refresh_thread.is_alive(): + return + + self._stop_event.clear() + self._refresh_thread = threading.Thread(target=self._background_refresh_loop, daemon=True) + self._refresh_thread.start() + + def stop_background_refresh(self): + self._stop_event.set() + if self._refresh_thread is not None: + self._refresh_thread.join(timeout=2) + self._refresh_thread = None + + def _background_refresh_loop(self): + while not self._stop_event.is_set(): + try: + self.refresh_all() + except Exception: + pass + self._stop_event.wait(self._refresh_interval) + + def _load_rois_from_db(self, camera_id: int) -> List[Dict]: + if self._db_session_factory is None: + return [] + + session = self._db_session_factory() + try: + from db.crud import get_all_rois + rois = get_all_rois(session, camera_id) + roi_configs = [] + for roi in rois: + try: + points = json.loads(roi.points) if isinstance(roi.points, str) else roi.points + except (json.JSONDecodeError, TypeError): + points = [] + + roi_config = { + "id": roi.id, + "roi_id": roi.roi_id, + "name": roi.name, + "type": roi.roi_type, + "points": points, + "rule": roi.rule_type, + "direction": roi.direction, + "enabled": roi.enabled, + "threshold_sec": roi.threshold_sec, + "confirm_sec": roi.confirm_sec, + "return_sec": roi.return_sec, + "working_hours": json.loads(roi.working_hours) if roi.working_hours else None, + } + roi_configs.append(roi_config) + return roi_configs + finally: + session.close() + + def refresh_all(self): + if self._db_session_factory is None: + return + + current_time = time.time() + if current_time - self._last_refresh_time < 1.0: + return + + self._last_refresh_time = current_time + camera_ids = list(self._cache.keys()) + + for camera_id in camera_ids: + try: + new_rois = self._load_rois_from_db(camera_id) + old_rois_str = str(self._cache.get(camera_id, [])) + new_rois_str = str(new_rois) + + if old_rois_str != new_rois_str: + self._cache[camera_id] = new_rois + self._cache_timestamps[camera_id] = current_time + self._notify_update(camera_id) + except Exception: + pass + + def get_rois(self, camera_id: int, force_refresh: bool = False) -> List[Dict]: + if force_refresh or camera_id not in self._cache: + self._cache[camera_id] = self._load_rois_from_db(camera_id) + self._cache_timestamps[camera_id] = time.time() + + return self._cache.get(camera_id, []) + + def get_rois_by_rule(self, camera_id: int, rule_type: str) -> List[Dict]: + rois = self.get_rois(camera_id) + return [roi for roi in rois if roi.get("rule") == rule_type and roi.get("enabled", True)] + + def invalidate(self, camera_id: Optional[int] = None): + if camera_id is None: + self._cache.clear() + self._cache_timestamps.clear() + elif camera_id in self._cache: + del self._cache[camera_id] + if camera_id in self._cache_timestamps: + del self._cache_timestamps[camera_id] + + def register_update_callback(self, camera_id: int, callback: Callable): + if camera_id not in self._update_callbacks: + self._update_callbacks[camera_id] = [] + self._update_callbacks[camera_id].append(callback) + + def _notify_update(self, camera_id: int): + if camera_id in self._update_callbacks: + for callback in self._update_callbacks[camera_id]: + try: + callback(camera_id) + except Exception: + pass + + def get_cache_info(self) -> Dict: + return { + "camera_count": len(self._cache), + "refresh_interval": self._refresh_interval, + "cameras": { + cam_id: { + "roi_count": len(rois), + "last_update": self._cache_timestamps.get(cam_id, 0), + } + for cam_id, rois in self._cache.items() + }, + } + + +def get_roi_cache() -> ROICacheManager: + return ROICacheManager() diff --git a/inference/rules/algorithms.py b/inference/rules/algorithms.py index 46b574d..ca668e8 100644 --- a/inference/rules/algorithms.py +++ b/inference/rules/algorithms.py @@ -1,5 +1,6 @@ import os import sys +import time from collections import deque from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Tuple @@ -9,8 +10,6 @@ import numpy as np sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from sort import Sort - class LeavePostAlgorithm: STATE_ON_DUTY = "ON_DUTY" @@ -189,19 +188,26 @@ class LeavePostAlgorithm: class IntrusionAlgorithm: - def __init__( - self, - check_interval_sec: float = 1.0, - direction_sensitive: bool = False, - ): - self.check_interval_sec = check_interval_sec - self.direction_sensitive = direction_sensitive + def __init__(self, cooldown_seconds: int = 300): + self.cooldown_seconds = cooldown_seconds + self.last_alert_time: Dict[str, float] = {} + self.alert_triggered: Dict[str, bool] = {} - self.last_check_times: Dict[str, float] = {} - self.tracker = Sort(max_age=5, min_hits=1, iou_threshold=0.3) + def is_roi_has_person(self, tracks: List[Dict], roi_id: str) -> bool: + for det in tracks: + matched_rois = det.get("matched_rois", []) + for roi in matched_rois: + if roi.get("roi_id") == roi_id: + return True + return False - self.alert_cooldowns: Dict[str, datetime] = {} - self.cooldown_seconds = 300 + def get_latest_bbox_in_roi(self, tracks: List[Dict], roi_id: str) -> List[float]: + for det in tracks: + matched_rois = det.get("matched_rois", []) + for roi in matched_rois: + if roi.get("roi_id") == roi_id: + return det.get("bbox", []) + return [] def process( self, @@ -210,52 +216,38 @@ class IntrusionAlgorithm: tracks: List[Dict], current_time: Optional[datetime] = None, ) -> List[Dict]: - if not tracks: + roi_has_person = self.is_roi_has_person(tracks, roi_id) + + if not roi_has_person: return [] - detections = [] - for track in tracks: - bbox = track.get("bbox", []) - if len(bbox) >= 4: - detections.append(bbox + [track.get("conf", 0.0)]) + now = time.monotonic() + key = f"{camera_id}_{roi_id}" - if not detections: + if key not in self.last_alert_time: + self.last_alert_time[key] = 0 + self.alert_triggered[key] = False + + if now - self.last_alert_time[key] >= self.cooldown_seconds: + self.last_alert_time[key] = now + self.alert_triggered[key] = False + + if self.alert_triggered[key]: return [] - current_ts = current_time.timestamp() if current_time else datetime.now().timestamp() + bbox = self.get_latest_bbox_in_roi(tracks, roi_id) + self.alert_triggered[key] = True - if camera_id in self.last_check_times: - if current_ts - self.last_check_times[camera_id] < self.check_interval_sec: - return [] - self.last_check_times[camera_id] = current_ts - - detections = np.array(detections) - tracked = self.tracker.update(detections) - - alerts = [] - now = datetime.now() - - for track_data in tracked: - x1, y1, x2, y2, track_id = track_data - cooldown_key = f"{roi_id}_{int(track_id)}" - - if cooldown_key not in self.alert_cooldowns or ( - now - self.alert_cooldowns[cooldown_key] - ).total_seconds() > self.cooldown_seconds: - alerts.append({ - "track_id": str(int(track_id)), - "bbox": [x1, y1, x2, y2], - "alert_type": "intrusion", - "confidence": track_data[4] if len(track_data) > 4 else 0.0, - "message": "检测到周界入侵", - }) - self.alert_cooldowns[cooldown_key] = now - - return alerts + return [{ + "roi_id": roi_id, + "bbox": bbox, + "alert_type": "intrusion", + "message": "检测到周界入侵", + }] def reset(self): - self.last_check_times.clear() - self.alert_cooldowns.clear() + self.last_alert_time.clear() + self.alert_triggered.clear() class AlgorithmManager: @@ -270,8 +262,7 @@ class AlgorithmManager: "return_sec": 30, }, "intrusion": { - "check_interval_sec": 1.0, - "direction_sensitive": False, + "cooldown_seconds": 300, }, } @@ -302,8 +293,7 @@ class AlgorithmManager: ) elif algorithm_type == "intrusion": self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm( - check_interval_sec=algo_params.get("check_interval_sec", 1.0), - direction_sensitive=algo_params.get("direction_sensitive", False), + cooldown_seconds=algo_params.get("cooldown_seconds", 300), ) def process(