feat: 重构数据库schema分离空间与业务配置 - 新增algorithm_registry和roi_algo_bind表 - roi_configs简化为纯空间配置 - 新增AlgorithmInfo/ROIAlgoBind数据模型
This commit is contained in:
313
algorithms.py
313
algorithms.py
@@ -290,6 +290,84 @@ class IntrusionAlgorithm:
|
||||
self.detection_start.clear()
|
||||
|
||||
|
||||
class CrowdDetectionAlgorithm:
|
||||
"""人群聚集检测算法"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_count: int = 10,
|
||||
cooldown_seconds: int = 300,
|
||||
target_class: Optional[str] = "person",
|
||||
):
|
||||
self.max_count = max_count
|
||||
self.cooldown_seconds = cooldown_seconds
|
||||
self.target_class = target_class
|
||||
|
||||
self.last_alert_time: Dict[str, datetime] = {}
|
||||
self.alert_triggered: Dict[str, bool] = {}
|
||||
|
||||
def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
for roi in matched_rois:
|
||||
if roi.get("roi_id") == roi_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_target_class(self, detection: Dict, target_class: Optional[str]) -> bool:
|
||||
if not target_class:
|
||||
return True
|
||||
return detection.get("class") == target_class
|
||||
|
||||
def _get_bboxes(self, tracks: List[Dict], roi_id: str) -> List[List[float]]:
|
||||
bboxes = []
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_class(det, self.target_class):
|
||||
bboxes.append(det.get("bbox", []))
|
||||
return bboxes
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
camera_id: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
current_time = current_time or datetime.now()
|
||||
key = f"{camera_id}_{roi_id}"
|
||||
|
||||
person_count = 0
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_class(det, self.target_class):
|
||||
person_count += 1
|
||||
|
||||
if person_count <= self.max_count:
|
||||
self.alert_triggered[key] = False
|
||||
return []
|
||||
|
||||
if self.alert_triggered.get(key, False):
|
||||
elapsed_since_alert = (current_time - self.last_alert_time.get(key, datetime.min)).total_seconds()
|
||||
if elapsed_since_alert < self.cooldown_seconds:
|
||||
return []
|
||||
self.alert_triggered[key] = False
|
||||
|
||||
bboxes = self._get_bboxes(tracks, roi_id)
|
||||
self.last_alert_time[key] = current_time
|
||||
self.alert_triggered[key] = True
|
||||
|
||||
return [{
|
||||
"roi_id": roi_id,
|
||||
"camera_id": camera_id,
|
||||
"bbox": bboxes[0] if bboxes else [],
|
||||
"alert_type": "crowd_detection",
|
||||
"message": f"检测到人群聚集,当前人数: {person_count}",
|
||||
"count": person_count,
|
||||
}]
|
||||
|
||||
def reset(self):
|
||||
self.last_alert_time.clear()
|
||||
self.alert_triggered.clear()
|
||||
|
||||
|
||||
class AlgorithmManager:
|
||||
def __init__(self, working_hours: Optional[List[Dict]] = None):
|
||||
self.algorithms: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -308,6 +386,11 @@ class AlgorithmManager:
|
||||
"confirm_seconds": 5,
|
||||
"target_class": None,
|
||||
},
|
||||
"crowd_detection": {
|
||||
"max_count": 10,
|
||||
"cooldown_seconds": 300,
|
||||
"target_class": "person",
|
||||
},
|
||||
}
|
||||
|
||||
self._pubsub = None
|
||||
@@ -358,6 +441,13 @@ class AlgorithmManager:
|
||||
self.reload_algorithm(roi_id)
|
||||
else:
|
||||
self.reload_all_algorithms()
|
||||
elif data.get("type") == "bind":
|
||||
bind_ids = data.get("ids", [])
|
||||
if bind_ids:
|
||||
for bind_id in bind_ids:
|
||||
self.reload_bind_algorithm(bind_id)
|
||||
else:
|
||||
self.reload_all_algorithms()
|
||||
except Exception as e:
|
||||
logger.error(f"处理配置更新消息失败: {e}")
|
||||
except Exception as e:
|
||||
@@ -372,59 +462,95 @@ class AlgorithmManager:
|
||||
self._pubsub_thread.join(timeout=5)
|
||||
logger.info("配置订阅已停止")
|
||||
|
||||
def load_from_redis(self, roi_id: str) -> bool:
|
||||
"""从Redis加载单个ROI的算法配置"""
|
||||
def load_bind_from_redis(self, bind_id: str) -> bool:
|
||||
"""从Redis加载单个绑定配置的算法"""
|
||||
try:
|
||||
from core.config_sync import get_config_sync_manager
|
||||
config_manager = get_config_sync_manager()
|
||||
roi_config = config_manager.get_roi_from_redis(roi_id)
|
||||
bind_config = config_manager.get_algo_bind_from_redis(bind_id)
|
||||
|
||||
if not roi_config:
|
||||
if not bind_config:
|
||||
return False
|
||||
|
||||
with self._update_lock:
|
||||
algorithm_type = roi_config.get("algorithm_type", "leave_post")
|
||||
roi_id = bind_config.get("roi_id")
|
||||
algo_code = bind_config.get("algo_code", "leave_post")
|
||||
params = bind_config.get("params", {})
|
||||
|
||||
if algorithm_type == "leave_post":
|
||||
params = {
|
||||
"working_hours": roi_config.get("working_hours"),
|
||||
"confirm_on_duty_sec": roi_config.get("confirm_on_duty_sec", 10),
|
||||
"confirm_leave_sec": roi_config.get("confirm_leave_sec", 10),
|
||||
"cooldown_sec": roi_config.get("cooldown_sec", 300),
|
||||
"target_class": roi_config.get("target_class", "person"),
|
||||
if roi_id not in self.algorithms:
|
||||
self.algorithms[roi_id] = {}
|
||||
|
||||
key = f"{roi_id}_{bind_id}"
|
||||
|
||||
if algo_code == "leave_post":
|
||||
algo_params = {
|
||||
"confirm_on_duty_sec": params.get("confirm_on_duty_sec", 10),
|
||||
"confirm_leave_sec": params.get("confirm_leave_sec", 10),
|
||||
"cooldown_sec": params.get("cooldown_sec", 300),
|
||||
"working_hours": params.get("working_hours", []),
|
||||
"target_class": params.get("target_class", bind_config.get("target_class", "person")),
|
||||
}
|
||||
if roi_id in self.algorithms and "leave_post" in self.algorithms[roi_id]:
|
||||
algo = self.algorithms[roi_id]["leave_post"]
|
||||
algo.confirm_on_duty_sec = params["confirm_on_duty_sec"]
|
||||
algo.confirm_leave_sec = params["confirm_leave_sec"]
|
||||
algo.cooldown_sec = params["cooldown_sec"]
|
||||
algo.target_class = params["target_class"]
|
||||
if params["working_hours"]:
|
||||
algo.working_hours = params["working_hours"]
|
||||
logger.info(f"已热更新算法参数: {roi_id}")
|
||||
if key in self.algorithms.get(roi_id, {}) and "leave_post" in self.algorithms[roi_id].get(key, {}):
|
||||
algo = self.algorithms[roi_id][key]["leave_post"]
|
||||
algo.confirm_on_duty_sec = algo_params["confirm_on_duty_sec"]
|
||||
algo.confirm_leave_sec = algo_params["confirm_leave_sec"]
|
||||
algo.cooldown_sec = algo_params["cooldown_sec"]
|
||||
algo.target_class = algo_params["target_class"]
|
||||
if algo_params["working_hours"]:
|
||||
algo.working_hours = algo_params["working_hours"]
|
||||
logger.info(f"已热更新算法参数: {key}")
|
||||
else:
|
||||
self.register_algorithm(roi_id, "leave_post", params)
|
||||
logger.info(f"已从Redis加载算法: {roi_id}")
|
||||
else:
|
||||
params = {}
|
||||
if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]:
|
||||
pass
|
||||
else:
|
||||
self.register_algorithm(roi_id, algorithm_type, params)
|
||||
logger.info(f"已从Redis加载算法: {roi_id}")
|
||||
self.algorithms[roi_id][key] = {}
|
||||
self.algorithms[roi_id][key]["leave_post"] = LeavePostAlgorithm(
|
||||
confirm_on_duty_sec=algo_params["confirm_on_duty_sec"],
|
||||
confirm_leave_sec=algo_params["confirm_leave_sec"],
|
||||
cooldown_sec=algo_params["cooldown_sec"],
|
||||
working_hours=algo_params["working_hours"],
|
||||
target_class=algo_params["target_class"],
|
||||
)
|
||||
logger.info(f"已从Redis加载算法: {key}")
|
||||
elif algo_code == "intrusion":
|
||||
algo_params = {
|
||||
"cooldown_seconds": params.get("cooldown_seconds", 120),
|
||||
"confirm_seconds": params.get("confirm_seconds", 5),
|
||||
"target_class": params.get("target_class", bind_config.get("target_class")),
|
||||
}
|
||||
self.algorithms[roi_id][key] = {}
|
||||
self.algorithms[roi_id][key]["intrusion"] = IntrusionAlgorithm(
|
||||
cooldown_seconds=algo_params["cooldown_seconds"],
|
||||
confirm_seconds=algo_params["confirm_seconds"],
|
||||
target_class=algo_params["target_class"],
|
||||
)
|
||||
logger.info(f"已从Redis加载算法: {key}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"从Redis加载算法配置失败: {e}")
|
||||
return False
|
||||
|
||||
def reload_bind_algorithm(self, bind_id: str) -> bool:
|
||||
"""重新加载单个绑定的算法配置"""
|
||||
return self.load_bind_from_redis(bind_id)
|
||||
|
||||
def reload_algorithm(self, roi_id: str) -> bool:
|
||||
"""重新加载单个ROI的算法配置"""
|
||||
if roi_id not in self.algorithms:
|
||||
return self.load_from_redis(roi_id)
|
||||
|
||||
self.reset_algorithm(roi_id)
|
||||
return self.load_from_redis(roi_id)
|
||||
"""重新加载单个ROI的所有算法绑定配置"""
|
||||
try:
|
||||
from core.config_sync import get_config_sync_manager
|
||||
config_manager = get_config_sync_manager()
|
||||
bindings = config_manager.get_bindings_from_redis(roi_id)
|
||||
|
||||
if not bindings:
|
||||
return False
|
||||
|
||||
for bind in bindings:
|
||||
bind_id = bind.get("bind_id")
|
||||
self.reset_algorithm(roi_id, bind_id)
|
||||
self.load_bind_from_redis(bind_id)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"重新加载ROI算法配置失败: {e}")
|
||||
return False
|
||||
|
||||
def reload_all_algorithms(self) -> int:
|
||||
"""重新加载所有算法配置"""
|
||||
@@ -432,11 +558,13 @@ class AlgorithmManager:
|
||||
try:
|
||||
from core.config_sync import get_config_sync_manager
|
||||
config_manager = get_config_sync_manager()
|
||||
roi_configs = config_manager.get_all_roi_configs()
|
||||
bindings = config_manager.get_bindings_from_redis("")
|
||||
|
||||
for roi_config in roi_configs:
|
||||
roi_id = roi_config.get("roi_id")
|
||||
if self.reload_algorithm(roi_id):
|
||||
for bind in bindings:
|
||||
bind_id = bind.get("bind_id")
|
||||
roi_id = bind.get("roi_id")
|
||||
self.reset_algorithm(roi_id, bind_id)
|
||||
if self.load_bind_from_redis(bind_id):
|
||||
count += 1
|
||||
|
||||
logger.info(f"已重新加载 {count} 个算法配置")
|
||||
@@ -448,22 +576,29 @@ class AlgorithmManager:
|
||||
def register_algorithm(
|
||||
self,
|
||||
roi_id: str,
|
||||
bind_id: str,
|
||||
algorithm_type: str,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]:
|
||||
"""注册算法(支持绑定ID)"""
|
||||
key = f"{roi_id}_{bind_id}"
|
||||
|
||||
if key in self.algorithms and algorithm_type in self.algorithms[key]:
|
||||
return
|
||||
|
||||
|
||||
if roi_id not in self.algorithms:
|
||||
self.algorithms[roi_id] = {}
|
||||
|
||||
|
||||
if key not in self.algorithms[roi_id]:
|
||||
self.algorithms[roi_id][key] = {}
|
||||
|
||||
algo_params = self.default_params.get(algorithm_type, {}).copy()
|
||||
if params:
|
||||
algo_params.update(params)
|
||||
|
||||
if algorithm_type == "leave_post":
|
||||
roi_working_hours = algo_params.get("working_hours") or self.working_hours
|
||||
self.algorithms[roi_id]["leave_post"] = LeavePostAlgorithm(
|
||||
self.algorithms[roi_id][key]["leave_post"] = LeavePostAlgorithm(
|
||||
confirm_on_duty_sec=algo_params.get("confirm_on_duty_sec", 10),
|
||||
confirm_leave_sec=algo_params.get("confirm_leave_sec", 10),
|
||||
cooldown_sec=algo_params.get("cooldown_sec", 300),
|
||||
@@ -471,21 +606,31 @@ class AlgorithmManager:
|
||||
target_class=algo_params.get("target_class", "person"),
|
||||
)
|
||||
elif algorithm_type == "intrusion":
|
||||
self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm(
|
||||
self.algorithms[roi_id][key]["intrusion"] = IntrusionAlgorithm(
|
||||
cooldown_seconds=algo_params.get("cooldown_seconds", 120),
|
||||
confirm_seconds=algo_params.get("confirm_seconds", 5),
|
||||
target_class=algo_params.get("target_class"),
|
||||
)
|
||||
|
||||
elif algorithm_type == "crowd_detection":
|
||||
from algorithms import CrowdDetectionAlgorithm
|
||||
self.algorithms[roi_id][key]["crowd_detection"] = CrowdDetectionAlgorithm(
|
||||
max_count=algo_params.get("max_count", 10),
|
||||
cooldown_seconds=algo_params.get("cooldown_seconds", 300),
|
||||
target_class=algo_params.get("target_class", "person"),
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
bind_id: str,
|
||||
camera_id: str,
|
||||
algorithm_type: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
algo = self.algorithms.get(roi_id, {}).get(algorithm_type)
|
||||
"""处理检测结果(支持绑定ID)"""
|
||||
key = f"{roi_id}_{bind_id}"
|
||||
algo = self.algorithms.get(roi_id, {}).get(key, {}).get(algorithm_type)
|
||||
if algo is None:
|
||||
return []
|
||||
return algo.process(roi_id, camera_id, tracks, current_time)
|
||||
@@ -493,47 +638,69 @@ class AlgorithmManager:
|
||||
def update_roi_params(
|
||||
self,
|
||||
roi_id: str,
|
||||
bind_id: str,
|
||||
algorithm_type: str,
|
||||
params: Dict[str, Any],
|
||||
):
|
||||
if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]:
|
||||
algo = self.algorithms[roi_id][algorithm_type]
|
||||
"""更新算法参数(支持绑定ID)"""
|
||||
key = f"{roi_id}_{bind_id}"
|
||||
if roi_id in self.algorithms and key in self.algorithms[roi_id] and algorithm_type in self.algorithms[roi_id][key]:
|
||||
algo = self.algorithms[roi_id][key][algorithm_type]
|
||||
for key, value in params.items():
|
||||
if hasattr(algo, key):
|
||||
setattr(algo, key, value)
|
||||
|
||||
def reset_algorithm(self, roi_id: str, algorithm_type: Optional[str] = None):
|
||||
|
||||
def reset_algorithm(self, roi_id: str, bind_id: Optional[str] = None):
|
||||
"""重置算法状态(支持绑定ID)"""
|
||||
if roi_id not in self.algorithms:
|
||||
return
|
||||
|
||||
if algorithm_type:
|
||||
if algorithm_type in self.algorithms[roi_id]:
|
||||
self.algorithms[roi_id][algorithm_type].reset()
|
||||
|
||||
if bind_id:
|
||||
key = f"{roi_id}_{bind_id}"
|
||||
if key in self.algorithms[roi_id]:
|
||||
if algorithm_type in self.algorithms[roi_id][key]:
|
||||
self.algorithms[roi_id][key][algorithm_type].reset()
|
||||
else:
|
||||
for algo in self.algorithms[roi_id].values():
|
||||
algo.reset()
|
||||
|
||||
for key in self.algorithms[roi_id]:
|
||||
for algo in self.algorithms[roi_id][key].values():
|
||||
algo.reset()
|
||||
|
||||
def reset_all(self):
|
||||
"""重置所有算法"""
|
||||
for roi_algorithms in self.algorithms.values():
|
||||
for algo in roi_algorithms.values():
|
||||
algo.reset()
|
||||
|
||||
for bind_algorithms in roi_algorithms.values():
|
||||
for algo in bind_algorithms.values():
|
||||
algo.reset()
|
||||
|
||||
def remove_roi(self, roi_id: str):
|
||||
"""移除ROI的所有算法"""
|
||||
if roi_id in self.algorithms:
|
||||
self.reset_algorithm(roi_id)
|
||||
for key in list(self.algorithms[roi_id].keys()):
|
||||
self.reset_algorithm(roi_id, key.split("_")[-1] if "_" in key else None)
|
||||
del self.algorithms[roi_id]
|
||||
|
||||
|
||||
def remove_bind(self, roi_id: str, bind_id: str):
|
||||
"""移除绑定的算法"""
|
||||
key = f"{roi_id}_{bind_id}"
|
||||
if roi_id in self.algorithms and key in self.algorithms[roi_id]:
|
||||
for algo in self.algorithms[roi_id][key].values():
|
||||
algo.reset()
|
||||
del self.algorithms[roi_id][key]
|
||||
|
||||
def get_status(self, roi_id: str) -> Dict[str, Any]:
|
||||
"""获取算法状态"""
|
||||
status = {}
|
||||
if roi_id in self.algorithms:
|
||||
for algo_type, algo in self.algorithms[roi_id].items():
|
||||
if algo_type == "leave_post":
|
||||
status[algo_type] = {
|
||||
"state": getattr(algo, "state", "WAITING"),
|
||||
"alarm_sent": getattr(algo, "alarm_sent", False),
|
||||
}
|
||||
else:
|
||||
status[algo_type] = {
|
||||
"detection_count": len(getattr(algo, "detection_start", {})),
|
||||
}
|
||||
for key, bind_algorithms in self.algorithms[roi_id].items():
|
||||
bind_id = key.split("_", 1)[-1] if "_" in key else ""
|
||||
for algo_type, algo in bind_algorithms.items():
|
||||
if algo_type == "leave_post":
|
||||
status[f"{algo_type}_{bind_id}"] = {
|
||||
"state": getattr(algo, "state", "WAITING"),
|
||||
"alarm_sent": getattr(algo, "alarm_sent", False),
|
||||
}
|
||||
else:
|
||||
status[f"{algo_type}_{bind_id}"] = {
|
||||
"detection_count": len(getattr(algo, "detection_start", {})),
|
||||
}
|
||||
return status
|
||||
|
||||
Reference in New Issue
Block a user