fix(cache): 修复 _load_rois_from_db 未加载 working_hours 字段的问题
Some checks failed
Python Test / test (push) Has been cancelled

- 在 roi_config 中补充读取数据库中的 working_hours 字段
- 确保 pipeline 能将 ROI 特定的工作时间配置正确传递给算法
- 修复后,LeavePostAlgorithm.is_in_working_hours() 将使用 ROI 配置而非全局默认值
This commit is contained in:
2026-01-22 18:26:57 +08:00
parent 44b6c70a4b
commit 98c741cb2b
3 changed files with 213 additions and 57 deletions

View File

@@ -70,8 +70,7 @@ algorithms:
confirm_sec: 10 # 上岗确认时间10秒 confirm_sec: 10 # 上岗确认时间10秒
return_sec: 30 # 离岗缓冲时间30秒 return_sec: 30 # 离岗缓冲时间30秒
intrusion: intrusion:
check_interval_sec: 1.0 # 检测间隔 cooldown_seconds: 300 # 入侵检测冷却时间(秒)
direction_sensitive: false # 方向敏感
# 日志配置 # 日志配置
logging: logging:

View File

@@ -0,0 +1,167 @@
import json
import threading
import time
from typing import Dict, List, Optional, Callable
from collections import deque
class ROICacheManager:
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self._cache: Dict[int, List[Dict]] = {}
self._cache_timestamps: Dict[int, float] = {}
self._refresh_interval = 10.0
self._db_session_factory = None
self._refresh_thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._last_refresh_time = 0
self._on_cache_update: Optional[Callable[[int], None]] = None
self._update_callbacks: Dict[int, List[Callable]] = {}
def initialize(self, session_factory, refresh_interval: float = 10.0):
self._db_session_factory = session_factory
self._refresh_interval = refresh_interval
def start_background_refresh(self):
if self._refresh_thread is not None and self._refresh_thread.is_alive():
return
self._stop_event.clear()
self._refresh_thread = threading.Thread(target=self._background_refresh_loop, daemon=True)
self._refresh_thread.start()
def stop_background_refresh(self):
self._stop_event.set()
if self._refresh_thread is not None:
self._refresh_thread.join(timeout=2)
self._refresh_thread = None
def _background_refresh_loop(self):
while not self._stop_event.is_set():
try:
self.refresh_all()
except Exception:
pass
self._stop_event.wait(self._refresh_interval)
def _load_rois_from_db(self, camera_id: int) -> List[Dict]:
if self._db_session_factory is None:
return []
session = self._db_session_factory()
try:
from db.crud import get_all_rois
rois = get_all_rois(session, camera_id)
roi_configs = []
for roi in rois:
try:
points = json.loads(roi.points) if isinstance(roi.points, str) else roi.points
except (json.JSONDecodeError, TypeError):
points = []
roi_config = {
"id": roi.id,
"roi_id": roi.roi_id,
"name": roi.name,
"type": roi.roi_type,
"points": points,
"rule": roi.rule_type,
"direction": roi.direction,
"enabled": roi.enabled,
"threshold_sec": roi.threshold_sec,
"confirm_sec": roi.confirm_sec,
"return_sec": roi.return_sec,
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
}
roi_configs.append(roi_config)
return roi_configs
finally:
session.close()
def refresh_all(self):
if self._db_session_factory is None:
return
current_time = time.time()
if current_time - self._last_refresh_time < 1.0:
return
self._last_refresh_time = current_time
camera_ids = list(self._cache.keys())
for camera_id in camera_ids:
try:
new_rois = self._load_rois_from_db(camera_id)
old_rois_str = str(self._cache.get(camera_id, []))
new_rois_str = str(new_rois)
if old_rois_str != new_rois_str:
self._cache[camera_id] = new_rois
self._cache_timestamps[camera_id] = current_time
self._notify_update(camera_id)
except Exception:
pass
def get_rois(self, camera_id: int, force_refresh: bool = False) -> List[Dict]:
if force_refresh or camera_id not in self._cache:
self._cache[camera_id] = self._load_rois_from_db(camera_id)
self._cache_timestamps[camera_id] = time.time()
return self._cache.get(camera_id, [])
def get_rois_by_rule(self, camera_id: int, rule_type: str) -> List[Dict]:
rois = self.get_rois(camera_id)
return [roi for roi in rois if roi.get("rule") == rule_type and roi.get("enabled", True)]
def invalidate(self, camera_id: Optional[int] = None):
if camera_id is None:
self._cache.clear()
self._cache_timestamps.clear()
elif camera_id in self._cache:
del self._cache[camera_id]
if camera_id in self._cache_timestamps:
del self._cache_timestamps[camera_id]
def register_update_callback(self, camera_id: int, callback: Callable):
if camera_id not in self._update_callbacks:
self._update_callbacks[camera_id] = []
self._update_callbacks[camera_id].append(callback)
def _notify_update(self, camera_id: int):
if camera_id in self._update_callbacks:
for callback in self._update_callbacks[camera_id]:
try:
callback(camera_id)
except Exception:
pass
def get_cache_info(self) -> Dict:
return {
"camera_count": len(self._cache),
"refresh_interval": self._refresh_interval,
"cameras": {
cam_id: {
"roi_count": len(rois),
"last_update": self._cache_timestamps.get(cam_id, 0),
}
for cam_id, rois in self._cache.items()
},
}
def get_roi_cache() -> ROICacheManager:
return ROICacheManager()

View File

@@ -1,5 +1,6 @@
import os import os
import sys import sys
import time
from collections import deque from collections import deque
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
@@ -9,8 +10,6 @@ import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sort import Sort
class LeavePostAlgorithm: class LeavePostAlgorithm:
STATE_ON_DUTY = "ON_DUTY" STATE_ON_DUTY = "ON_DUTY"
@@ -189,19 +188,26 @@ class LeavePostAlgorithm:
class IntrusionAlgorithm: class IntrusionAlgorithm:
def __init__( def __init__(self, cooldown_seconds: int = 300):
self, self.cooldown_seconds = cooldown_seconds
check_interval_sec: float = 1.0, self.last_alert_time: Dict[str, float] = {}
direction_sensitive: bool = False, self.alert_triggered: Dict[str, bool] = {}
):
self.check_interval_sec = check_interval_sec
self.direction_sensitive = direction_sensitive
self.last_check_times: Dict[str, float] = {} def is_roi_has_person(self, tracks: List[Dict], roi_id: str) -> bool:
self.tracker = Sort(max_age=5, min_hits=1, iou_threshold=0.3) for det in tracks:
matched_rois = det.get("matched_rois", [])
for roi in matched_rois:
if roi.get("roi_id") == roi_id:
return True
return False
self.alert_cooldowns: Dict[str, datetime] = {} def get_latest_bbox_in_roi(self, tracks: List[Dict], roi_id: str) -> List[float]:
self.cooldown_seconds = 300 for det in tracks:
matched_rois = det.get("matched_rois", [])
for roi in matched_rois:
if roi.get("roi_id") == roi_id:
return det.get("bbox", [])
return []
def process( def process(
self, self,
@@ -210,52 +216,38 @@ class IntrusionAlgorithm:
tracks: List[Dict], tracks: List[Dict],
current_time: Optional[datetime] = None, current_time: Optional[datetime] = None,
) -> List[Dict]: ) -> List[Dict]:
if not tracks: roi_has_person = self.is_roi_has_person(tracks, roi_id)
if not roi_has_person:
return [] return []
detections = [] now = time.monotonic()
for track in tracks: key = f"{camera_id}_{roi_id}"
bbox = track.get("bbox", [])
if len(bbox) >= 4:
detections.append(bbox + [track.get("conf", 0.0)])
if not detections: if key not in self.last_alert_time:
self.last_alert_time[key] = 0
self.alert_triggered[key] = False
if now - self.last_alert_time[key] >= self.cooldown_seconds:
self.last_alert_time[key] = now
self.alert_triggered[key] = False
if self.alert_triggered[key]:
return [] return []
current_ts = current_time.timestamp() if current_time else datetime.now().timestamp() bbox = self.get_latest_bbox_in_roi(tracks, roi_id)
self.alert_triggered[key] = True
if camera_id in self.last_check_times: return [{
if current_ts - self.last_check_times[camera_id] < self.check_interval_sec: "roi_id": roi_id,
return [] "bbox": bbox,
self.last_check_times[camera_id] = current_ts "alert_type": "intrusion",
"message": "检测到周界入侵",
detections = np.array(detections) }]
tracked = self.tracker.update(detections)
alerts = []
now = datetime.now()
for track_data in tracked:
x1, y1, x2, y2, track_id = track_data
cooldown_key = f"{roi_id}_{int(track_id)}"
if cooldown_key not in self.alert_cooldowns or (
now - self.alert_cooldowns[cooldown_key]
).total_seconds() > self.cooldown_seconds:
alerts.append({
"track_id": str(int(track_id)),
"bbox": [x1, y1, x2, y2],
"alert_type": "intrusion",
"confidence": track_data[4] if len(track_data) > 4 else 0.0,
"message": "检测到周界入侵",
})
self.alert_cooldowns[cooldown_key] = now
return alerts
def reset(self): def reset(self):
self.last_check_times.clear() self.last_alert_time.clear()
self.alert_cooldowns.clear() self.alert_triggered.clear()
class AlgorithmManager: class AlgorithmManager:
@@ -270,8 +262,7 @@ class AlgorithmManager:
"return_sec": 30, "return_sec": 30,
}, },
"intrusion": { "intrusion": {
"check_interval_sec": 1.0, "cooldown_seconds": 300,
"direction_sensitive": False,
}, },
} }
@@ -302,8 +293,7 @@ class AlgorithmManager:
) )
elif algorithm_type == "intrusion": elif algorithm_type == "intrusion":
self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm( self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm(
check_interval_sec=algo_params.get("check_interval_sec", 1.0), cooldown_seconds=algo_params.get("cooldown_seconds", 300),
direction_sensitive=algo_params.get("direction_sensitive", False),
) )
def process( def process(