Files
security-ai-edge/algorithms.py
16337 98595402c6 fix: 修复10个关键bug提升系统稳定性和性能
1. YOLO11输出解析错误: 移除不存在的objectness行,正确使用class_scores.max()
2. CPU NMS逻辑错误: keep_mask同时标记保留和抑制框导致NMS失效,改用独立suppressed集合
3. 坐标映射缺失: _build_tracks中scale_info未使用,添加revert_boxes还原到ROI裁剪空间
4. batch=1限制: 恢复真正的动态batch推理(1~8),BatchPreprocessor支持多图stack
5. 帧率控制缺失: _read_frame添加time.monotonic()间隔控制,按target_fps跳帧
6. 拉流推理耦合: 新增独立推理线程(InferenceWorker),生产者-消费者模式解耦
7. 攒批形同虚设: 添加50ms攒批窗口+max_batch阈值,替代>=1立即处理
8. LeavePost双重等待: LEAVING确认后直接触发告警,不再进入OFF_DUTY二次等待
9. register_algorithm每帧调用: 添加_registered_keys缓存,O(1)快速路径跳过
10. GPU context线程安全: TensorRT infer()内部加锁,防止多线程CUDA context竞争

附带修复:
- reset_algorithm中未定义algorithm_type变量(NameError)
- update_roi_params中循环变量key覆盖外层key
- AlertInfo缺少bind_id字段(TypeError)
- _logger.log_alert在标准logger上不存在(AttributeError)
- AlarmStateMachine死锁(Lock改为RLock)
- ROICropper.create_mask坐标解析错误
- 更新测试用例适配新API

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 16:47:26 +08:00

738 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
import sys
import threading
import time
from collections import deque
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
import redis
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Prepend the parent of this file's directory to the import search path.
# NOTE(review): presumably this lets the lazily imported `config` / `core`
# packages used further down resolve — confirm against the project layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class LeavePostAlgorithm:
    """Detect that a guarded post (one ROI) has been left unattended.

    ``process`` drives a small state machine:

    * ``WAITING``  -> ``ON_DUTY`` as soon as a target is seen in the ROI.
    * ``ON_DUTY``  -> ``LEAVING`` on the first frame without a target.
    * ``LEAVING``  -> ``ON_DUTY`` if the target reappears, otherwise
      -> ``OFF_DUTY`` once ``confirm_leave_sec`` has elapsed; an alert is
      emitted on that transition unless the cooldown is still running.
    * ``OFF_DUTY`` -> ``ON_DUTY`` when the target returns; while the post
      stays empty a follow-up alert is emitted each time the cooldown expires.
    * ``NON_WORK_TIME`` is held while the timestamp is outside working hours.

    NOTE(review): ``confirm_on_duty_sec`` only influences the length of the
    detection-history window; the WAITING -> ON_DUTY transition itself is
    immediate. ``alarm_sent`` and ``last_person_time`` are only ever reset,
    never set, by this class.
    """

    STATE_WAITING = "WAITING"
    STATE_ON_DUTY = "ON_DUTY"
    STATE_LEAVING = "LEAVING"
    STATE_OFF_DUTY = "OFF_DUTY"
    STATE_NON_WORK_TIME = "NON_WORK_TIME"

    def __init__(
        self,
        confirm_on_duty_sec: int = 10,
        confirm_leave_sec: int = 10,
        cooldown_sec: int = 300,
        working_hours: Optional[List[Dict]] = None,
        target_class: Optional[str] = "person",
    ):
        """Store the thresholds and initialise the state machine to WAITING.

        Args:
            confirm_on_duty_sec: On-duty confirmation window in seconds.
            confirm_leave_sec: Seconds the ROI must stay empty before the
                absence is confirmed.
            cooldown_sec: Minimum seconds between two alerts for the same
                camera/ROI pair.
            working_hours: List of ``{"start": "HH:MM", "end": "HH:MM"}``
                periods (a JSON string of such a list is also accepted).
                Empty / ``None`` means "always working time".
            target_class: Detection class to look for; falsy matches any class.
        """
        self.confirm_on_duty_sec = confirm_on_duty_sec
        self.confirm_leave_sec = confirm_leave_sec
        self.cooldown_sec = cooldown_sec
        self.working_hours = working_hours or []
        self.target_class = target_class
        # Timestamp of the last alert per "<camera_id>_<roi_id>" key.
        self.alert_cooldowns: Dict[str, datetime] = {}
        self.state: str = self.STATE_WAITING
        self.state_start_time: Optional[datetime] = None
        # (timestamp, target_present) samples, pruned to the confirm window.
        self.detection_history: deque = deque()
        self.alarm_sent: bool = False
        self.last_person_time: Optional[datetime] = None

    def _is_in_working_hours(self, dt: Optional[datetime] = None) -> bool:
        """Return True when ``dt`` (default: now) falls in a working period.

        No configured periods, or a JSON string that cannot be parsed, is
        treated as "always working". A period whose start is later than its
        end is interpreted as crossing midnight (e.g. 22:00-06:00).
        """
        if not self.working_hours:
            return True
        import json

        working_hours = self.working_hours
        if isinstance(working_hours, str):
            try:
                working_hours = json.loads(working_hours)
            except ValueError:
                # Malformed configuration must not disable monitoring.
                return True
        if not working_hours:
            return True
        dt = dt or datetime.now()
        current_minutes = dt.hour * 60 + dt.minute
        for period in working_hours:
            is_mapping = isinstance(period, dict)
            start_minutes = self._parse_time_to_minutes(period["start"] if is_mapping else period)
            end_minutes = self._parse_time_to_minutes(period["end"] if is_mapping else period)
            if start_minutes <= end_minutes:
                if start_minutes <= current_minutes < end_minutes:
                    return True
            elif current_minutes >= start_minutes or current_minutes < end_minutes:
                # Period wraps past midnight.
                return True
        return False

    def _parse_time_to_minutes(self, time_str: str) -> int:
        """Convert ``"HH:MM"`` to minutes since midnight.

        Integers are returned unchanged; anything unparsable yields 0.
        """
        if isinstance(time_str, int):
            return time_str
        try:
            parts = time_str.split(":")
            return int(parts[0]) * 60 + int(parts[1])
        except (AttributeError, IndexError, TypeError, ValueError):
            return 0

    def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
        """Return True when ``detection["matched_rois"]`` contains ``roi_id``."""
        matched_rois = detection.get("matched_rois", [])
        return any(roi.get("roi_id") == roi_id for roi in matched_rois)

    def _check_target_class(self, detection: Dict, target_class: Optional[str]) -> bool:
        """Return True when no class filter is set or the class matches."""
        if not target_class:
            return True
        return detection.get("class") == target_class

    def _prune_history(self, current_time: datetime) -> None:
        """Drop history samples older than the longest confirmation window."""
        window_sec = max(self.confirm_on_duty_sec, self.confirm_leave_sec)
        history = self.detection_history
        while history and (current_time - history[0][0]).total_seconds() > window_sec:
            history.popleft()

    def _record_detection(self, current_time: datetime, has_person: bool) -> None:
        """Append one sample and keep the history bounded to the window."""
        self.detection_history.append((current_time, has_person))
        self._prune_history(current_time)

    def _get_detection_window(self, current_time: datetime) -> List[bool]:
        """Return the presence flags of all samples still inside the window."""
        self._prune_history(current_time)
        return [has_person for _, has_person in self.detection_history]

    def _get_latest_bbox(self, tracks: List[Dict], roi_id: str) -> List[float]:
        """Return the bbox of the first detection inside the ROI, else ``[]``."""
        for det in tracks:
            if self._check_detection_in_roi(det, roi_id):
                return det.get("bbox", [])
        return []

    def _roi_has_target(self, tracks: List[Dict], roi_id: str) -> bool:
        """Return True when any track is in the ROI and of the target class."""
        return any(
            self._check_detection_in_roi(det, roi_id)
            and self._check_target_class(det, self.target_class)
            for det in tracks
        )

    def _cooldown_elapsed(self, cooldown_key: str, current_time: datetime) -> Optional[bool]:
        """Return None if no alert was sent yet, else whether cooldown expired."""
        last_alert = self.alert_cooldowns.get(cooldown_key)
        if last_alert is None:
            return None
        return (current_time - last_alert).total_seconds() > self.cooldown_sec

    def _build_alert(
        self,
        roi_id: str,
        camera_id: str,
        tracks: List[Dict],
        elapsed_sec: float,
        label: str,
    ) -> Dict:
        """Assemble one leave-post alert payload."""
        elapsed_minutes = int(elapsed_sec / 60)
        return {
            "track_id": roi_id,
            "camera_id": camera_id,
            "bbox": self._get_latest_bbox(tracks, roi_id),
            "duration_minutes": elapsed_minutes,
            "alert_type": "leave_post",
            "message": f"{label} {elapsed_minutes} 分钟",
        }

    def process(
        self,
        roi_id: str,
        camera_id: str,
        tracks: List[Dict],
        current_time: Optional[datetime] = None,
    ) -> List[Dict]:
        """Advance the state machine by one frame and return any alerts.

        Args:
            roi_id: ROI being evaluated.
            camera_id: Camera the frame came from.
            tracks: Detections, each a dict with ``matched_rois``, ``class``
                and ``bbox`` entries.
            current_time: Frame timestamp; defaults to ``datetime.now()``.
                All timing, including the alert cooldown, is measured on this
                clock so replayed or injected timestamps behave consistently.

        Returns:
            A list with zero or one alert dicts.
        """
        current_time = current_time or datetime.now()
        if not self._is_in_working_hours(current_time):
            self.state = self.STATE_NON_WORK_TIME
            self.detection_history.clear()
            self.alarm_sent = False
            return []
        if self.state == self.STATE_NON_WORK_TIME:
            # Working time just started: begin from a clean slate.
            self.state = self.STATE_WAITING
            self.detection_history.clear()
            self.alarm_sent = False

        alerts: List[Dict] = []
        roi_has_person = self._roi_has_target(tracks, roi_id)
        cooldown_key = f"{camera_id}_{roi_id}"

        if self.state == self.STATE_WAITING:
            if roi_has_person:
                self.state = self.STATE_ON_DUTY
                self.state_start_time = current_time
                self.detection_history.clear()
                self.detection_history.append((current_time, True))
        elif self.state == self.STATE_ON_DUTY:
            self._record_detection(current_time, roi_has_person)
            if not roi_has_person:
                self.state = self.STATE_LEAVING
                self.state_start_time = current_time
        elif self.state == self.STATE_LEAVING:
            self._record_detection(current_time, roi_has_person)
            elapsed = (current_time - self.state_start_time).total_seconds()
            if roi_has_person:
                self.state = self.STATE_ON_DUTY
                self.state_start_time = current_time
            elif elapsed >= self.confirm_leave_sec:
                # Absence confirmed: alert right away instead of waiting a
                # second time in OFF_DUTY.
                self.state = self.STATE_OFF_DUTY
                self.state_start_time = current_time
                if self._cooldown_elapsed(cooldown_key, current_time) in (None, True):
                    alerts.append(self._build_alert(roi_id, camera_id, tracks, elapsed, "离岗"))
                    self.alert_cooldowns[cooldown_key] = current_time
        elif self.state == self.STATE_OFF_DUTY:
            # Wait for the person to return, or re-alert after the cooldown.
            if roi_has_person:
                self.state = self.STATE_ON_DUTY
                self.state_start_time = current_time
            else:
                elapsed = (current_time - self.state_start_time).total_seconds()
                if self._cooldown_elapsed(cooldown_key, current_time) is True:
                    alerts.append(self._build_alert(roi_id, camera_id, tracks, elapsed, "持续离岗"))
                    self.alert_cooldowns[cooldown_key] = current_time
        return alerts

    def reset(self):
        """Return to WAITING and forget history and cooldown timestamps."""
        self.state = self.STATE_WAITING
        self.state_start_time = None
        self.detection_history.clear()
        self.alarm_sent = False
        self.last_person_time = None
        self.alert_cooldowns.clear()

    def get_state(self, roi_id: str) -> Dict[str, Any]:
        """Return a snapshot of the state fields (``roi_id`` is not used)."""
        return {
            "state": self.state,
            "alarm_sent": self.alarm_sent,
            "last_person_time": self.last_person_time,
        }
class IntrusionAlgorithm:
    """Raise an alert when a target stays inside an ROI long enough.

    State is kept per ``"<camera_id>_<roi_id>"`` key. A target must be seen
    continuously for ``confirm_seconds`` before an alert fires; while the
    target remains present, further alerts are suppressed for
    ``cooldown_seconds``. Any frame without a target clears both the pending
    confirmation and the "already alerted" flag for that key.
    """

    def __init__(
        self,
        cooldown_seconds: int = 120,
        confirm_seconds: int = 5,
        target_class: Optional[str] = None,
    ):
        """Store thresholds and create the empty per-key state tables.

        Args:
            cooldown_seconds: Minimum seconds between alerts while the target
                stays in the ROI.
            confirm_seconds: Seconds of continuous presence required first.
            target_class: Detection class to react to; falsy matches any class.
        """
        self.cooldown_seconds = cooldown_seconds
        self.confirm_seconds = confirm_seconds
        self.target_class = target_class
        self.last_alert_time: Dict[str, datetime] = {}
        self.alert_triggered: Dict[str, bool] = {}
        # Start of the current uninterrupted presence; None once alerted.
        self.detection_start: Dict[str, Optional[datetime]] = {}

    def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
        """Return True when ``detection["matched_rois"]`` contains ``roi_id``."""
        matched_rois = detection.get("matched_rois", [])
        return any(roi.get("roi_id") == roi_id for roi in matched_rois)

    def _check_target_class(self, detection: Dict, target_class: Optional[str]) -> bool:
        """Return True when no class filter is set or the class matches."""
        if not target_class:
            return True
        return detection.get("class") == target_class

    def _is_target_in_roi(self, detection: Dict, roi_id: str) -> bool:
        """Return True when the detection is in the ROI and of the target class."""
        return self._check_detection_in_roi(detection, roi_id) and self._check_target_class(
            detection, self.target_class
        )

    def _get_latest_bbox(self, tracks: List[Dict], roi_id: str) -> List[float]:
        """Return the bbox of the first *target* detection in the ROI.

        The class filter matters: without it the alert could carry the box of
        an unrelated object that merely shares the ROI with the intruder.
        """
        for det in tracks:
            if self._is_target_in_roi(det, roi_id):
                return det.get("bbox", [])
        return []

    def process(
        self,
        roi_id: str,
        camera_id: str,
        tracks: List[Dict],
        current_time: Optional[datetime] = None,
    ) -> List[Dict]:
        """Evaluate one frame and return zero or one intrusion alerts.

        Args:
            roi_id: ROI being evaluated.
            camera_id: Camera the frame came from.
            tracks: Detections, each a dict with ``matched_rois``, ``class``
                and ``bbox`` entries.
            current_time: Frame timestamp; defaults to ``datetime.now()``.
        """
        current_time = current_time or datetime.now()
        key = f"{camera_id}_{roi_id}"

        if not any(self._is_target_in_roi(det, roi_id) for det in tracks):
            # Presence interrupted: restart confirmation on the next sighting.
            self.detection_start.pop(key, None)
            self.alert_triggered[key] = False
            return []

        if self.alert_triggered.get(key, False):
            last_alert = self.last_alert_time.get(key, datetime.min)
            if (current_time - last_alert).total_seconds() < self.cooldown_seconds:
                return []
            self.alert_triggered[key] = False

        if self.detection_start.get(key) is None:
            self.detection_start[key] = current_time
        elapsed = (current_time - self.detection_start[key]).total_seconds()
        if elapsed < self.confirm_seconds:
            return []

        bbox = self._get_latest_bbox(tracks, roi_id)
        self.last_alert_time[key] = current_time
        self.alert_triggered[key] = True
        self.detection_start[key] = None
        return [{
            "roi_id": roi_id,
            "camera_id": camera_id,
            "bbox": bbox,
            "alert_type": "intrusion",
            "message": "检测到周界入侵",
        }]

    def reset(self):
        """Forget all per-key timing and alert state."""
        self.last_alert_time.clear()
        self.alert_triggered.clear()
        self.detection_start.clear()
class CrowdDetectionAlgorithm:
    """Crowd-gathering detector: alerts when an ROI holds too many targets.

    State is kept per ``"<camera_id>_<roi_id>"`` key. An alert fires as soon
    as the number of target detections in the ROI exceeds ``max_count``;
    while the count stays above the limit, repeat alerts are held back for
    ``cooldown_seconds``. Dropping to or below the limit clears the
    "already alerted" flag.
    """

    def __init__(
        self,
        max_count: int = 10,
        cooldown_seconds: int = 300,
        target_class: Optional[str] = "person",
    ):
        """Store the thresholds and create the empty per-key state tables."""
        self.max_count = max_count
        self.cooldown_seconds = cooldown_seconds
        self.target_class = target_class
        self.last_alert_time: Dict[str, datetime] = {}
        self.alert_triggered: Dict[str, bool] = {}

    def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
        """Return True when ``detection["matched_rois"]`` contains ``roi_id``."""
        return any(
            entry.get("roi_id") == roi_id
            for entry in detection.get("matched_rois", [])
        )

    def _check_target_class(self, detection: Dict, target_class: Optional[str]) -> bool:
        """Return True when no class filter is set or the class matches."""
        return (not target_class) or detection.get("class") == target_class

    def _get_bboxes(self, tracks: List[Dict], roi_id: str) -> List[List[float]]:
        """Collect the bbox of every target detection inside the ROI."""
        return [
            track.get("bbox", [])
            for track in tracks
            if self._check_detection_in_roi(track, roi_id)
            and self._check_target_class(track, self.target_class)
        ]

    def process(
        self,
        roi_id: str,
        camera_id: str,
        tracks: List[Dict],
        current_time: Optional[datetime] = None,
    ) -> List[Dict]:
        """Evaluate one frame and return zero or one crowd alerts.

        ``current_time`` defaults to ``datetime.now()`` when omitted.
        """
        now = current_time or datetime.now()
        state_key = f"{camera_id}_{roi_id}"

        # One bbox per matching detection, so the length is the head count.
        matched_boxes = self._get_bboxes(tracks, roi_id)
        person_count = len(matched_boxes)

        if person_count <= self.max_count:
            self.alert_triggered[state_key] = False
            return []

        if self.alert_triggered.get(state_key, False):
            previous = self.last_alert_time.get(state_key, datetime.min)
            if (now - previous).total_seconds() < self.cooldown_seconds:
                return []

        self.last_alert_time[state_key] = now
        self.alert_triggered[state_key] = True
        first_box = matched_boxes[0] if matched_boxes else []
        alert = {
            "roi_id": roi_id,
            "camera_id": camera_id,
            "bbox": first_box,
            "alert_type": "crowd_detection",
            "message": f"检测到人群聚集,当前人数: {person_count}",
            "count": person_count,
        }
        return [alert]

    def reset(self):
        """Forget all per-key alert state."""
        for table in (self.last_alert_time, self.alert_triggered):
            table.clear()
class AlgorithmManager:
    """Registry of algorithm instances, organised per ROI and per binding.

    ``self.algorithms`` is laid out as
    ``{roi_id: {"<roi_id>_<bind_id>": {algorithm_type: instance}}}``.
    ``_registered_keys`` caches ``(roi_id, bind_id, algorithm_type)`` tuples so
    ``register_algorithm`` can return immediately for known registrations.

    NOTE(review): only ``load_bind_from_redis`` takes ``_update_lock``; the
    other mutating methods modify ``self.algorithms`` without it.
    """

    def __init__(self, working_hours: Optional[List[Dict]] = None):
        """Create an empty registry.

        Args:
            working_hours: Fallback working-hour periods handed to leave-post
                algorithms that do not bring their own.
        """
        self.algorithms: Dict[str, Dict[str, Any]] = {}
        self.working_hours = working_hours or []
        self._update_lock = threading.Lock()
        # Cache of registered (roi_id, bind_id, algo_type) tuples.
        self._registered_keys: set = set()
        self.default_params = {
            "leave_post": {
                "confirm_on_duty_sec": 10,
                "confirm_leave_sec": 10,
                "cooldown_sec": 300,
                "target_class": "person",
            },
            "intrusion": {
                "cooldown_seconds": 120,
                "confirm_seconds": 5,
                "target_class": None,
            },
            "crowd_detection": {
                "max_count": 10,
                "cooldown_seconds": 300,
                "target_class": "person",
            },
        }
        self._pubsub = None
        self._pubsub_thread = None
        self._running = False

    @staticmethod
    def _bind_id_from_key(roi_id: str, key: str) -> str:
        """Recover ``bind_id`` from a ``"<roi_id>_<bind_id>"`` storage key.

        Stripping the known ROI prefix stays correct when either id contains
        underscores, which splitting on ``"_"`` does not.
        """
        prefix = f"{roi_id}_"
        return key[len(prefix):] if key.startswith(prefix) else ""

    def _drop_registration(self, roi_id: str, bind_id: str) -> None:
        """Remove every cached registration of one ROI/binding pair."""
        self._registered_keys = {
            k for k in self._registered_keys
            if not (k[0] == roi_id and k[1] == bind_id)
        }

    def start_config_subscription(self):
        """Subscribe to the ``config_update`` Redis channel in a daemon thread.

        Connection settings come from ``config.settings.get_settings()``.
        Failures are logged and swallowed.
        """
        try:
            from config.settings import get_settings

            settings = get_settings()
            redis_client = redis.Redis(
                host=settings.redis.host,
                port=settings.redis.port,
                db=settings.redis.db,
                password=settings.redis.password,
                decode_responses=True,
            )
            self._pubsub = redis_client.pubsub()
            self._pubsub.subscribe("config_update")
            self._running = True
            self._pubsub_thread = threading.Thread(
                target=self._config_update_worker,
                name="ConfigUpdateSub",
                daemon=True,
            )
            self._pubsub_thread.start()
            logger.info("已启动配置变更订阅")
        except Exception as e:
            logger.error(f"启动配置订阅失败: {e}")

    def _config_update_worker(self):
        """Worker loop: apply each config-update message as it arrives.

        Messages are JSON objects with a ``type`` of ``"roi"`` or ``"bind"``
        and an optional ``ids`` list; an empty list reloads everything.
        """
        import json

        try:
            for message in self._pubsub.listen():
                if not self._running:
                    break
                if message["type"] != "message":
                    continue
                try:
                    data = json.loads(message["data"])
                    update_type = data.get("type")
                    if update_type == "roi":
                        reload_one = self.reload_algorithm
                    elif update_type == "bind":
                        reload_one = self.reload_bind_algorithm
                    else:
                        continue
                    ids = data.get("ids", [])
                    if ids:
                        for item_id in ids:
                            reload_one(item_id)
                    else:
                        self.reload_all_algorithms()
                except Exception as e:
                    # One bad message must not kill the subscription thread.
                    logger.error(f"处理配置更新消息失败: {e}")
        except Exception as e:
            logger.error(f"配置订阅线程异常: {e}")

    def stop_config_subscription(self):
        """Stop the subscription thread and close the pub/sub connection."""
        self._running = False
        if self._pubsub:
            self._pubsub.close()
        if self._pubsub_thread and self._pubsub_thread.is_alive():
            self._pubsub_thread.join(timeout=5)
        logger.info("配置订阅已停止")

    def load_bind_from_redis(self, bind_id: str) -> bool:
        """Load (or hot-update) the algorithm of one binding from Redis.

        An existing leave-post instance has its parameters updated in place
        so its state survives; intrusion and crowd-detection instances are
        recreated.

        Returns:
            True when an algorithm was loaded or updated; False when the
            binding is missing, its ``algo_code`` is not supported, or an
            error occurred.
        """
        try:
            from core.config_sync import get_config_sync_manager

            config_manager = get_config_sync_manager()
            bind_config = config_manager.get_algo_bind_from_redis(bind_id)
            if not bind_config:
                return False
            with self._update_lock:
                roi_id = bind_config.get("roi_id")
                algo_code = bind_config.get("algo_code", "leave_post")
                params = bind_config.get("params", {})
                roi_algorithms = self.algorithms.setdefault(roi_id, {})
                key = f"{roi_id}_{bind_id}"
                if algo_code == "leave_post":
                    algo_params = {
                        "confirm_on_duty_sec": params.get("confirm_on_duty_sec", 10),
                        "confirm_leave_sec": params.get("confirm_leave_sec", 10),
                        "cooldown_sec": params.get("cooldown_sec", 300),
                        "working_hours": params.get("working_hours", []),
                        "target_class": params.get("target_class", bind_config.get("target_class", "person")),
                    }
                    existing = roi_algorithms.get(key, {}).get("leave_post")
                    if existing is not None:
                        existing.confirm_on_duty_sec = algo_params["confirm_on_duty_sec"]
                        existing.confirm_leave_sec = algo_params["confirm_leave_sec"]
                        existing.cooldown_sec = algo_params["cooldown_sec"]
                        existing.target_class = algo_params["target_class"]
                        # An empty list keeps the previously configured hours.
                        if algo_params["working_hours"]:
                            existing.working_hours = algo_params["working_hours"]
                        logger.info(f"已热更新算法参数: {key}")
                    else:
                        roi_algorithms[key] = {
                            "leave_post": LeavePostAlgorithm(
                                confirm_on_duty_sec=algo_params["confirm_on_duty_sec"],
                                confirm_leave_sec=algo_params["confirm_leave_sec"],
                                cooldown_sec=algo_params["cooldown_sec"],
                                working_hours=algo_params["working_hours"],
                                target_class=algo_params["target_class"],
                            )
                        }
                        logger.info(f"已从Redis加载算法: {key}")
                elif algo_code == "intrusion":
                    roi_algorithms[key] = {
                        "intrusion": IntrusionAlgorithm(
                            cooldown_seconds=params.get("cooldown_seconds", 120),
                            confirm_seconds=params.get("confirm_seconds", 5),
                            target_class=params.get("target_class", bind_config.get("target_class")),
                        )
                    }
                    logger.info(f"已从Redis加载算法: {key}")
                elif algo_code == "crowd_detection":
                    # Mirrors register_algorithm, which also supports this type.
                    roi_algorithms[key] = {
                        "crowd_detection": CrowdDetectionAlgorithm(
                            max_count=params.get("max_count", 10),
                            cooldown_seconds=params.get("cooldown_seconds", 300),
                            target_class=params.get("target_class", bind_config.get("target_class", "person")),
                        )
                    }
                    logger.info(f"已从Redis加载算法: {key}")
                else:
                    # Nothing was loaded, so do not report success.
                    logger.warning(f"不支持的算法类型: {algo_code}")
                    return False
            return True
        except Exception as e:
            logger.error(f"从Redis加载算法配置失败: {e}")
            return False

    def reload_bind_algorithm(self, bind_id: str) -> bool:
        """Reload the algorithm configuration of a single binding."""
        return self.load_bind_from_redis(bind_id)

    def reload_algorithm(self, roi_id: str) -> bool:
        """Reset and reload every binding of one ROI.

        Returns:
            False when the ROI has no bindings or an error occurred, True
            otherwise (individual load failures are not reflected).
        """
        try:
            from core.config_sync import get_config_sync_manager

            config_manager = get_config_sync_manager()
            bindings = config_manager.get_bindings_from_redis(roi_id)
            if not bindings:
                return False
            for bind in bindings:
                bind_id = bind.get("bind_id")
                self.reset_algorithm(roi_id, bind_id)
                self.load_bind_from_redis(bind_id)
            return True
        except Exception as e:
            logger.error(f"重新加载ROI算法配置失败: {e}")
            return False

    def reload_all_algorithms(self) -> int:
        """Reset and reload every binding; return how many loaded successfully."""
        count = 0
        try:
            from core.config_sync import get_config_sync_manager

            config_manager = get_config_sync_manager()
            bindings = config_manager.get_bindings_from_redis("")
            for bind in bindings:
                bind_id = bind.get("bind_id")
                roi_id = bind.get("roi_id")
                self.reset_algorithm(roi_id, bind_id)
                if self.load_bind_from_redis(bind_id):
                    count += 1
            logger.info(f"已重新加载 {count} 个算法配置")
            return count
        except Exception as e:
            logger.error(f"重新加载所有算法配置失败: {e}")
            return count

    def register_algorithm(
        self,
        roi_id: str,
        bind_id: str,
        algorithm_type: str,
        params: Optional[Dict[str, Any]] = None,
    ):
        """Create the algorithm instance for a binding unless already cached.

        Safe to call on every frame: a cache hit returns immediately.

        Args:
            roi_id: ROI the binding belongs to.
            bind_id: Binding identifier.
            algorithm_type: ``"leave_post"``, ``"intrusion"`` or
                ``"crowd_detection"``; other values create no instance.
            params: Overrides merged over ``default_params[algorithm_type]``.
        """
        cache_key = (roi_id, bind_id, algorithm_type)
        # Fast path: already registered.
        if cache_key in self._registered_keys:
            return
        key = f"{roi_id}_{bind_id}"
        bind_algorithms = self.algorithms.setdefault(roi_id, {}).setdefault(key, {})
        algo_params = self.default_params.get(algorithm_type, {}).copy()
        if params:
            algo_params.update(params)
        if algorithm_type == "leave_post":
            roi_working_hours = algo_params.get("working_hours") or self.working_hours
            bind_algorithms["leave_post"] = LeavePostAlgorithm(
                confirm_on_duty_sec=algo_params.get("confirm_on_duty_sec", 10),
                confirm_leave_sec=algo_params.get("confirm_leave_sec", 10),
                cooldown_sec=algo_params.get("cooldown_sec", 300),
                working_hours=roi_working_hours,
                target_class=algo_params.get("target_class", "person"),
            )
        elif algorithm_type == "intrusion":
            bind_algorithms["intrusion"] = IntrusionAlgorithm(
                cooldown_seconds=algo_params.get("cooldown_seconds", 120),
                confirm_seconds=algo_params.get("confirm_seconds", 5),
                target_class=algo_params.get("target_class"),
            )
        elif algorithm_type == "crowd_detection":
            # The class lives in this module; no self-import is needed.
            bind_algorithms["crowd_detection"] = CrowdDetectionAlgorithm(
                max_count=algo_params.get("max_count", 10),
                cooldown_seconds=algo_params.get("cooldown_seconds", 300),
                target_class=algo_params.get("target_class", "person"),
            )
        self._registered_keys.add(cache_key)

    def process(
        self,
        roi_id: str,
        bind_id: str,
        camera_id: str,
        algorithm_type: str,
        tracks: List[Dict],
        current_time: Optional[datetime] = None,
    ) -> List[Dict]:
        """Run one binding's algorithm on a frame; ``[]`` if not registered."""
        key = f"{roi_id}_{bind_id}"
        algo = self.algorithms.get(roi_id, {}).get(key, {}).get(algorithm_type)
        if algo is None:
            return []
        return algo.process(roi_id, camera_id, tracks, current_time)

    def update_roi_params(
        self,
        roi_id: str,
        bind_id: str,
        algorithm_type: str,
        params: Dict[str, Any],
    ):
        """Set matching attributes on a registered algorithm instance.

        Parameters the instance has no attribute for are ignored, as is the
        whole call when the algorithm is not registered.
        """
        key = f"{roi_id}_{bind_id}"
        algo = self.algorithms.get(roi_id, {}).get(key, {}).get(algorithm_type)
        if algo is None:
            return
        for param_key, value in params.items():
            if hasattr(algo, param_key):
                setattr(algo, param_key, value)

    def reset_algorithm(self, roi_id: str, bind_id: Optional[str] = None):
        """Reset algorithm state for one binding, or for the whole ROI.

        The matching registration-cache entries are dropped as well, so the
        next ``register_algorithm`` call rebuilds the instances.
        """
        if roi_id not in self.algorithms:
            return
        if bind_id:
            key = f"{roi_id}_{bind_id}"
            for algo in self.algorithms[roi_id].get(key, {}).values():
                algo.reset()
            self._drop_registration(roi_id, bind_id)
        else:
            for bind_algorithms in self.algorithms[roi_id].values():
                for algo in bind_algorithms.values():
                    algo.reset()
            # Drop every cached registration of this ROI.
            self._registered_keys = {
                k for k in self._registered_keys if k[0] != roi_id
            }

    def reset_all(self):
        """Reset the state of every registered algorithm."""
        for roi_algorithms in self.algorithms.values():
            for bind_algorithms in roi_algorithms.values():
                for algo in bind_algorithms.values():
                    algo.reset()

    def remove_roi(self, roi_id: str):
        """Reset and remove every algorithm of an ROI, including cache entries."""
        if roi_id in self.algorithms:
            # Resetting at ROI level avoids re-deriving bind ids from the
            # storage keys, which is ambiguous when ids contain underscores.
            self.reset_algorithm(roi_id)
            del self.algorithms[roi_id]

    def remove_bind(self, roi_id: str, bind_id: str):
        """Reset and remove the algorithms of one binding.

        The registration cache is cleared too; otherwise the fast path in
        ``register_algorithm`` would prevent the binding from ever being
        registered again.
        """
        key = f"{roi_id}_{bind_id}"
        if roi_id in self.algorithms and key in self.algorithms[roi_id]:
            for algo in self.algorithms[roi_id][key].values():
                algo.reset()
            del self.algorithms[roi_id][key]
        self._drop_registration(roi_id, bind_id)

    def get_status(self, roi_id: str) -> Dict[str, Any]:
        """Return a status summary keyed by ``"<algo_type>_<bind_id>"``.

        Leave-post entries report ``state`` and ``alarm_sent``; every other
        type reports ``detection_count``, the size of its ``detection_start``
        table (0 when the instance has none).
        """
        status = {}
        for key, bind_algorithms in self.algorithms.get(roi_id, {}).items():
            bind_id = self._bind_id_from_key(roi_id, key)
            for algo_type, algo in bind_algorithms.items():
                if algo_type == "leave_post":
                    status[f"{algo_type}_{bind_id}"] = {
                        "state": getattr(algo, "state", "WAITING"),
                        "alarm_sent": getattr(algo, "alarm_sent", False),
                    }
                else:
                    status[f"{algo_type}_{bind_id}"] = {
                        "detection_count": len(getattr(algo, "detection_start", {})),
                    }
        return status