Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a891deba00 | |||
| bfe6a559d2 | |||
| f5077a25a8 | |||
| 9c73efe1eb | |||
| bf5ddb3e7a | |||
| 56f39f1be7 | |||
| 3266241064 | |||
| c6d8430867 |
29
.dockerignore
Normal file
29
.dockerignore
Normal file
@@ -0,0 +1,29 @@
|
||||
.git
|
||||
__pycache__
|
||||
*.pyc
|
||||
*.pyo
|
||||
.idea
|
||||
.vscode
|
||||
.env
|
||||
.env.*
|
||||
!.env.example
|
||||
|
||||
# 模型和数据通过卷挂载
|
||||
models/
|
||||
data/
|
||||
logs/
|
||||
|
||||
# 测试文件
|
||||
tests/
|
||||
test_*.py
|
||||
pytest.ini
|
||||
|
||||
# 文档
|
||||
docs/
|
||||
*.md
|
||||
!CLAUDE.md
|
||||
|
||||
# 临时文件
|
||||
*.engine
|
||||
*.onnx
|
||||
*.pt
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -15,8 +15,10 @@ build/
|
||||
logs/
|
||||
*.log
|
||||
|
||||
# 模型文件(忽略中间产物)
|
||||
# 模型文件(二进制大文件,不入库,防止 git stash --include-untracked 误删)
|
||||
models/*.onnx
|
||||
models/*.engine
|
||||
models/*.pt
|
||||
|
||||
# 环境配置
|
||||
.env
|
||||
|
||||
51
Dockerfile
Normal file
51
Dockerfile
Normal file
@@ -0,0 +1,51 @@
|
||||
# ============================================================
|
||||
# 基础镜像:NVIDIA TensorRT 23.08
|
||||
# 内含:CUDA 12.1.1 | cuDNN 8.9.3 | TensorRT 8.6.1.6 | Python 3.10
|
||||
# ============================================================
|
||||
FROM nvcr.io/nvidia/tensorrt:23.08-py3
|
||||
|
||||
LABEL maintainer="AI Edge Architecture Team"
|
||||
LABEL description="Edge AI Inference Service - YOLOv11n + TensorRT"
|
||||
|
||||
# 设置时区
|
||||
ENV TZ=Asia/Shanghai
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
||||
|
||||
# 系统依赖(视频解码、OpenCV 运行时)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libgl1-mesa-glx \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# 先复制依赖文件,利用 Docker 层缓存
|
||||
COPY requirements.txt .
|
||||
|
||||
# 安装 PyTorch(CUDA 12.1 版本)+ 其余依赖
|
||||
RUN pip install --no-cache-dir \
|
||||
torch==2.1.2 torchvision==0.16.2 \
|
||||
--index-url https://download.pytorch.org/whl/cu121 \
|
||||
&& pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# 复制项目代码
|
||||
COPY __init__.py .
|
||||
COPY main.py .
|
||||
COPY algorithms.py .
|
||||
COPY build_engine.py .
|
||||
COPY config/ ./config/
|
||||
COPY core/ ./core/
|
||||
COPY utils/ ./utils/
|
||||
|
||||
# 模型和数据通过卷挂载,不打入镜像
|
||||
# -v /path/to/models:/app/models
|
||||
# -v /path/to/data:/app/data
|
||||
|
||||
# 日志目录
|
||||
RUN mkdir -p /app/logs /app/data
|
||||
|
||||
EXPOSE 9001
|
||||
|
||||
CMD ["python", "main.py"]
|
||||
703
algorithms.py
703
algorithms.py
@@ -1283,12 +1283,549 @@ class VehicleCongestionAlgorithm(BaseAlgorithm):
|
||||
return state_info
|
||||
|
||||
|
||||
class NonMotorVehicleParkingAlgorithm(BaseAlgorithm):
|
||||
"""
|
||||
非机动车违停检测算法(状态机版本 v1.0)
|
||||
|
||||
状态机:
|
||||
IDLE → CONFIRMING_VEHICLE → PARKED_COUNTDOWN → ALARMED → CONFIRMING_CLEAR → IDLE
|
||||
|
||||
业务流程:
|
||||
1. 检测到非机动车进入禁停区 → 车辆确认期(confirm_vehicle_sec,默认10秒,ratio>=0.6)
|
||||
2. 确认有车 → 违停倒计时(parking_countdown_sec,默认180秒/3分钟)
|
||||
3. 倒计时结束仍有车 → 触发告警(ALARMED状态)
|
||||
4. 车辆离开 → 消失确认期(confirm_clear_sec,默认60秒,ratio<0.2)
|
||||
5. 确认车辆离开 → 发送resolve事件 → 回到空闲状态
|
||||
|
||||
使用滑动窗口(10秒)抗抖动,检测自行车和摩托车。
|
||||
"""
|
||||
|
||||
# 状态定义
|
||||
STATE_IDLE = "IDLE"
|
||||
STATE_CONFIRMING_VEHICLE = "CONFIRMING_VEHICLE"
|
||||
STATE_PARKED_COUNTDOWN = "PARKED_COUNTDOWN"
|
||||
STATE_ALARMED = "ALARMED"
|
||||
STATE_CONFIRMING_CLEAR = "CONFIRMING_CLEAR"
|
||||
|
||||
# 告警级别常量(默认值,可通过 params 覆盖)
|
||||
DEFAULT_ALARM_LEVEL = 2 # 普通
|
||||
|
||||
# 滑动窗口参数
|
||||
WINDOW_SIZE_SEC = 10
|
||||
|
||||
# 阈值常量(与 IllegalParkingAlgorithm 一致)
|
||||
RATIO_CONFIRMING_DROP = 0.3
|
||||
RATIO_CONFIRM_VEHICLE = 0.6
|
||||
RATIO_PARKED_LEAVE = 0.2
|
||||
RATIO_ALARMED_CLEAR = 0.15
|
||||
RATIO_CLEAR_RETURN = 0.5
|
||||
RATIO_CLEAR_CONFIRM = 0.2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
confirm_vehicle_sec: int = 10,
|
||||
parking_countdown_sec: int = 180,
|
||||
confirm_clear_sec: int = 60,
|
||||
cooldown_sec: int = 900,
|
||||
target_classes: Optional[List[str]] = None,
|
||||
alarm_level: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.confirm_vehicle_sec = confirm_vehicle_sec
|
||||
self.parking_countdown_sec = parking_countdown_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["bicycle", "motorcycle"]
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
self.state: str = self.STATE_IDLE
|
||||
self.state_start_time: Optional[datetime] = None
|
||||
|
||||
# 滑动窗口:存储 (timestamp, has_vehicle: bool)
|
||||
self._detection_window: deque = deque(maxlen=1000)
|
||||
|
||||
# 告警追踪
|
||||
self._parking_start_time: Optional[datetime] = None
|
||||
|
||||
# 冷却期管理
|
||||
self.alert_cooldowns: Dict[str, datetime] = {}
|
||||
|
||||
def _check_target_classes(self, detection: Dict) -> bool:
|
||||
"""检查检测目标是否属于非机动车类别"""
|
||||
det_class = detection.get("class", "")
|
||||
return det_class in self.target_classes
|
||||
|
||||
def _update_window(self, current_time: datetime, has_vehicle: bool):
|
||||
"""更新滑动窗口"""
|
||||
self._detection_window.append((current_time, has_vehicle))
|
||||
cutoff = current_time - timedelta(seconds=self.WINDOW_SIZE_SEC)
|
||||
while self._detection_window and self._detection_window[0][0] < cutoff:
|
||||
self._detection_window.popleft()
|
||||
|
||||
def _get_window_ratio(self) -> float:
|
||||
"""获取滑动窗口内的检测命中率"""
|
||||
if not self._detection_window:
|
||||
return 0.0
|
||||
hits = sum(1 for _, has in self._detection_window if has)
|
||||
return hits / len(self._detection_window)
|
||||
|
||||
def _scan_tracks(self, tracks: List[Dict], roi_id: str) -> Tuple[bool, int, List[float], float]:
|
||||
"""
|
||||
一次遍历 tracks,返回 (has_target, count, latest_bbox, max_confidence)。
|
||||
过滤 target_classes。
|
||||
"""
|
||||
has_target = False
|
||||
count = 0
|
||||
latest_bbox: List[float] = []
|
||||
max_confidence = 0.0
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
has_target = True
|
||||
count += 1
|
||||
if not latest_bbox:
|
||||
latest_bbox = det.get("bbox", [])
|
||||
conf = det.get("confidence", 0.0)
|
||||
if conf > max_confidence:
|
||||
max_confidence = conf
|
||||
return has_target, count, latest_bbox, max_confidence
|
||||
|
||||
def _get_latest_bbox(self, tracks: List[Dict], roi_id: str) -> List[float]:
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
return det.get("bbox", [])
|
||||
return []
|
||||
|
||||
def _get_max_confidence(self, tracks: List[Dict], roi_id: str) -> float:
|
||||
"""获取ROI内非机动车的最高置信度"""
|
||||
max_conf = 0.0
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
max_conf = max(max_conf, det.get("confidence", 0.0))
|
||||
return max_conf
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
camera_id: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
"""处理单帧检测结果"""
|
||||
current_time = current_time or datetime.now()
|
||||
alerts = []
|
||||
|
||||
# 一次遍历获取所有信息
|
||||
roi_has_vehicle, vehicle_count, scan_bbox, scan_confidence = self._scan_tracks(tracks, roi_id)
|
||||
|
||||
# 更新滑动窗口
|
||||
self._update_window(current_time, roi_has_vehicle)
|
||||
|
||||
# 计算一次比率,后续分支复用
|
||||
ratio = self._get_window_ratio()
|
||||
|
||||
# === 状态机处理 ===
|
||||
|
||||
if self.state == self.STATE_IDLE:
|
||||
if roi_has_vehicle:
|
||||
self.state = self.STATE_CONFIRMING_VEHICLE
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: IDLE → CONFIRMING_VEHICLE (非机动车)")
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_VEHICLE:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio < self.RATIO_CONFIRMING_DROP:
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_VEHICLE → IDLE (ratio={ratio:.2f}<{self.RATIO_CONFIRMING_DROP})")
|
||||
elif elapsed >= self.confirm_vehicle_sec and ratio >= self.RATIO_CONFIRM_VEHICLE:
|
||||
self._parking_start_time = self.state_start_time
|
||||
self.state = self.STATE_PARKED_COUNTDOWN
|
||||
self.state_start_time = current_time
|
||||
logger.info(f"ROI {roi_id}: CONFIRMING_VEHICLE → PARKED_COUNTDOWN (非机动车, ratio={ratio:.2f})")
|
||||
|
||||
elif self.state == self.STATE_PARKED_COUNTDOWN:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio < self.RATIO_PARKED_LEAVE:
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._parking_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: PARKED_COUNTDOWN → IDLE (非机动车离开, ratio={ratio:.2f})")
|
||||
elif elapsed >= self.parking_countdown_sec:
|
||||
cooldown_key = f"{camera_id}_{roi_id}"
|
||||
if cooldown_key not in self.alert_cooldowns or \
|
||||
(current_time - self.alert_cooldowns[cooldown_key]).total_seconds() > self.cooldown_sec:
|
||||
|
||||
alerts.append({
|
||||
"roi_id": roi_id,
|
||||
"camera_id": camera_id,
|
||||
"bbox": scan_bbox,
|
||||
"alert_type": "non_motor_vehicle_parking",
|
||||
"alarm_level": self._alarm_level,
|
||||
"confidence": scan_confidence,
|
||||
"message": f"检测到非机动车违停(已停留{int(elapsed / 60)}分钟)",
|
||||
"first_frame_time": self._parking_start_time.strftime('%Y-%m-%d %H:%M:%S') if self._parking_start_time else None,
|
||||
"duration_minutes": elapsed / 60,
|
||||
})
|
||||
|
||||
self.alert_cooldowns[cooldown_key] = current_time
|
||||
self.state = self.STATE_ALARMED
|
||||
logger.warning(f"ROI {roi_id}: PARKED_COUNTDOWN → ALARMED (非机动车违停告警触发)")
|
||||
else:
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._parking_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: PARKED_COUNTDOWN → IDLE (冷却期内)")
|
||||
|
||||
elif self.state == self.STATE_ALARMED:
|
||||
if ratio < self.RATIO_ALARMED_CLEAR:
|
||||
self.state = self.STATE_CONFIRMING_CLEAR
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: ALARMED → CONFIRMING_CLEAR (ratio={ratio:.2f}<{self.RATIO_ALARMED_CLEAR})")
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_CLEAR:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio >= self.RATIO_CLEAR_RETURN:
|
||||
self.state = self.STATE_ALARMED
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → ALARMED (非机动车仍在)")
|
||||
elif elapsed >= self.confirm_clear_sec and ratio < self.RATIO_CLEAR_CONFIRM:
|
||||
if self._last_alarm_id and self._parking_start_time:
|
||||
duration_ms = int((current_time - self._parking_start_time).total_seconds() * 1000)
|
||||
alerts.append({
|
||||
"alert_type": "alarm_resolve",
|
||||
"resolve_alarm_id": self._last_alarm_id,
|
||||
"duration_ms": duration_ms,
|
||||
"last_frame_time": current_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"resolve_type": "vehicle_left",
|
||||
})
|
||||
logger.info(f"ROI {roi_id}: 非机动车违停告警已解决(车辆离开)")
|
||||
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._parking_start_time = None
|
||||
self.alert_cooldowns.clear()
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → IDLE")
|
||||
|
||||
return alerts
|
||||
|
||||
def reset(self):
|
||||
"""重置算法状态"""
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._parking_start_time = None
|
||||
self._detection_window.clear()
|
||||
self.alert_cooldowns.clear()
|
||||
|
||||
def get_state(self, current_time: Optional[datetime] = None) -> Dict[str, Any]:
|
||||
"""获取当前状态"""
|
||||
current_time = current_time or datetime.now()
|
||||
window_ratio = self._get_window_ratio()
|
||||
state_info = {
|
||||
"state": self.state,
|
||||
"state_start_time": self.state_start_time.isoformat() if self.state_start_time else None,
|
||||
"window_ratio": window_ratio,
|
||||
}
|
||||
if self.state in (self.STATE_ALARMED, self.STATE_PARKED_COUNTDOWN) and self._parking_start_time:
|
||||
state_info["parking_duration_sec"] = (current_time - self._parking_start_time).total_seconds()
|
||||
state_info["alarm_id"] = self._last_alarm_id
|
||||
return state_info
|
||||
|
||||
|
||||
class GarbageDetectionAlgorithm(BaseAlgorithm):
|
||||
"""
|
||||
垃圾检测算法(状态机版本 v1.0)
|
||||
|
||||
状态机:
|
||||
IDLE → CONFIRMING_GARBAGE → ALARMED → CONFIRMING_CLEAR → IDLE
|
||||
|
||||
业务流程:
|
||||
1. 检测到垃圾 → 垃圾确认期(confirm_garbage_sec,默认60秒,ratio>=0.6)
|
||||
2. 确认有垃圾 → 触发告警(ALARMED 状态)
|
||||
3. 垃圾消失(ratio<0.15)→ 消失确认期(confirm_clear_sec,默认60秒)
|
||||
4. 消失确认期内持续 ratio<0.2 → 发送 resolve 事件 → 回到 IDLE
|
||||
|
||||
与 IllegalParking 的差异:无 PARKED_COUNTDOWN 阶段(垃圾无"临时停留"概念)。
|
||||
使用滑动窗口(10秒)抗抖动,只检测 garbage 类。
|
||||
"""
|
||||
|
||||
# 状态定义
|
||||
STATE_IDLE = "IDLE"
|
||||
STATE_CONFIRMING_GARBAGE = "CONFIRMING_GARBAGE"
|
||||
STATE_ALARMED = "ALARMED"
|
||||
STATE_CONFIRMING_CLEAR = "CONFIRMING_CLEAR"
|
||||
|
||||
# 告警级别常量(默认值,可通过 params 覆盖)
|
||||
DEFAULT_ALARM_LEVEL = 2 # 普通
|
||||
|
||||
# 滑动窗口参数
|
||||
WINDOW_SIZE_SEC = 10
|
||||
|
||||
# 阈值常量
|
||||
RATIO_CONFIRMING_DROP = 0.3 # 确认期内命中率低于此值则回到 IDLE
|
||||
RATIO_CONFIRM_GARBAGE = 0.6 # 确认有垃圾的命中率阈值
|
||||
RATIO_ALARMED_CLEAR = 0.15 # 已告警状态下进入消失确认的阈值
|
||||
RATIO_CLEAR_RETURN = 0.5 # 消失确认期间垃圾再次出现的阈值
|
||||
RATIO_CLEAR_CONFIRM = 0.2 # 消失确认完成的阈值
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
confirm_garbage_sec: int = 60,
|
||||
confirm_clear_sec: int = 60,
|
||||
cooldown_sec: int = 1800,
|
||||
target_classes: Optional[List[str]] = None,
|
||||
alarm_level: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.confirm_garbage_sec = confirm_garbage_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["garbage"]
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
self.state: str = self.STATE_IDLE
|
||||
self.state_start_time: Optional[datetime] = None
|
||||
|
||||
# 滑动窗口:存储 (timestamp, has_garbage: bool)
|
||||
self._detection_window: deque = deque(maxlen=1000)
|
||||
|
||||
# 告警追踪
|
||||
self._garbage_start_time: Optional[datetime] = None
|
||||
|
||||
# 冷却期管理
|
||||
self.alert_cooldowns: Dict[str, datetime] = {}
|
||||
|
||||
def _check_target_classes(self, detection: Dict) -> bool:
|
||||
"""检查检测目标是否属于垃圾类别"""
|
||||
det_class = detection.get("class", "")
|
||||
return det_class in self.target_classes
|
||||
|
||||
def _update_window(self, current_time: datetime, has_garbage: bool):
|
||||
"""更新滑动窗口"""
|
||||
self._detection_window.append((current_time, has_garbage))
|
||||
cutoff = current_time - timedelta(seconds=self.WINDOW_SIZE_SEC)
|
||||
while self._detection_window and self._detection_window[0][0] < cutoff:
|
||||
self._detection_window.popleft()
|
||||
|
||||
def _get_window_ratio(self) -> float:
|
||||
"""获取滑动窗口内的检测命中率"""
|
||||
if not self._detection_window:
|
||||
return 0.0
|
||||
hits = sum(1 for _, has in self._detection_window if has)
|
||||
return hits / len(self._detection_window)
|
||||
|
||||
def _scan_tracks(self, tracks: List[Dict], roi_id: str) -> Tuple[bool, int, List[float], float]:
|
||||
"""
|
||||
一次遍历 tracks,返回 (has_target, count, latest_bbox, max_confidence)。
|
||||
过滤 target_classes。
|
||||
"""
|
||||
has_target = False
|
||||
count = 0
|
||||
latest_bbox: List[float] = []
|
||||
max_confidence = 0.0
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
has_target = True
|
||||
count += 1
|
||||
if not latest_bbox:
|
||||
latest_bbox = det.get("bbox", [])
|
||||
conf = det.get("confidence", 0.0)
|
||||
if conf > max_confidence:
|
||||
max_confidence = conf
|
||||
return has_target, count, latest_bbox, max_confidence
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
camera_id: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
"""处理单帧检测结果"""
|
||||
current_time = current_time or datetime.now()
|
||||
alerts = []
|
||||
|
||||
# 一次遍历获取所有信息
|
||||
roi_has_garbage, garbage_count, scan_bbox, scan_confidence = self._scan_tracks(tracks, roi_id)
|
||||
|
||||
# 更新滑动窗口
|
||||
self._update_window(current_time, roi_has_garbage)
|
||||
|
||||
# 计算一次比率,后续分支复用
|
||||
ratio = self._get_window_ratio()
|
||||
|
||||
# === 状态机处理 ===
|
||||
|
||||
if self.state == self.STATE_IDLE:
|
||||
if roi_has_garbage:
|
||||
self.state = self.STATE_CONFIRMING_GARBAGE
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: IDLE → CONFIRMING_GARBAGE")
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_GARBAGE:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio < self.RATIO_CONFIRMING_DROP:
|
||||
# 命中率过低,可能只是闪现
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
logger.debug(
|
||||
f"ROI {roi_id}: CONFIRMING_GARBAGE → IDLE "
|
||||
f"(ratio={ratio:.2f}<{self.RATIO_CONFIRMING_DROP})"
|
||||
)
|
||||
elif elapsed >= self.confirm_garbage_sec and ratio >= self.RATIO_CONFIRM_GARBAGE:
|
||||
# 确认有垃圾持续存在,检查冷却期
|
||||
cooldown_key = f"{camera_id}_{roi_id}"
|
||||
if cooldown_key not in self.alert_cooldowns or \
|
||||
(current_time - self.alert_cooldowns[cooldown_key]).total_seconds() > self.cooldown_sec:
|
||||
|
||||
self._garbage_start_time = self.state_start_time
|
||||
alerts.append({
|
||||
"roi_id": roi_id,
|
||||
"camera_id": camera_id,
|
||||
"bbox": scan_bbox,
|
||||
"alert_type": "garbage",
|
||||
"alarm_level": self._alarm_level,
|
||||
"confidence": scan_confidence,
|
||||
"message": f"检测到垃圾(持续{int(elapsed)}秒,{garbage_count}处)",
|
||||
"first_frame_time": self._garbage_start_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"garbage_count": garbage_count,
|
||||
})
|
||||
|
||||
self.alert_cooldowns[cooldown_key] = current_time
|
||||
self.state = self.STATE_ALARMED
|
||||
logger.warning(f"ROI {roi_id}: CONFIRMING_GARBAGE → ALARMED (垃圾告警触发)")
|
||||
else:
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_GARBAGE → IDLE (冷却期内)")
|
||||
|
||||
elif self.state == self.STATE_ALARMED:
|
||||
if ratio < self.RATIO_ALARMED_CLEAR:
|
||||
self.state = self.STATE_CONFIRMING_CLEAR
|
||||
self.state_start_time = current_time
|
||||
logger.debug(
|
||||
f"ROI {roi_id}: ALARMED → CONFIRMING_CLEAR "
|
||||
f"(ratio={ratio:.2f}<{self.RATIO_ALARMED_CLEAR})"
|
||||
)
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_CLEAR:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio >= self.RATIO_CLEAR_RETURN:
|
||||
# 垃圾又出现(或清扫者挡住片刻),回到 ALARMED
|
||||
self.state = self.STATE_ALARMED
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → ALARMED (垃圾仍在)")
|
||||
elif elapsed >= self.confirm_clear_sec and ratio < self.RATIO_CLEAR_CONFIRM:
|
||||
# 确认垃圾已被清理
|
||||
if self._last_alarm_id and self._garbage_start_time:
|
||||
duration_ms = int((current_time - self._garbage_start_time).total_seconds() * 1000)
|
||||
alerts.append({
|
||||
"alert_type": "alarm_resolve",
|
||||
"resolve_alarm_id": self._last_alarm_id,
|
||||
"duration_ms": duration_ms,
|
||||
"last_frame_time": current_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"resolve_type": "garbage_removed",
|
||||
})
|
||||
logger.info(f"ROI {roi_id}: 垃圾告警已解决(垃圾被清理)")
|
||||
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._garbage_start_time = None
|
||||
self.alert_cooldowns.clear() # 清理后清空冷却,新垃圾可正常告警
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → IDLE")
|
||||
|
||||
return alerts
|
||||
|
||||
def reset(self):
|
||||
"""重置算法状态"""
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._garbage_start_time = None
|
||||
self._detection_window.clear()
|
||||
self.alert_cooldowns.clear()
|
||||
|
||||
def get_state(self, current_time: Optional[datetime] = None) -> Dict[str, Any]:
|
||||
"""获取当前状态"""
|
||||
current_time = current_time or datetime.now()
|
||||
window_ratio = self._get_window_ratio()
|
||||
state_info = {
|
||||
"state": self.state,
|
||||
"state_start_time": self.state_start_time.isoformat() if self.state_start_time else None,
|
||||
"window_ratio": window_ratio,
|
||||
}
|
||||
if self.state in (self.STATE_ALARMED,) and self._garbage_start_time:
|
||||
state_info["garbage_duration_sec"] = (current_time - self._garbage_start_time).total_seconds()
|
||||
state_info["alarm_id"] = self._last_alarm_id
|
||||
return state_info
|
||||
|
||||
|
||||
class AlgorithmManager:
|
||||
# 参数类型定义,用于三级合并后的类型强制转换
|
||||
_PARAM_TYPES = {
|
||||
"leave_post": {
|
||||
"confirm_on_duty_sec": int, "confirm_off_duty_sec": int,
|
||||
"confirm_leave_sec": int, "leave_countdown_sec": int, "cooldown_sec": int,
|
||||
},
|
||||
"intrusion": {
|
||||
"cooldown_seconds": int, "confirm_seconds": int,
|
||||
"confirm_intrusion_seconds": int, "confirm_clear_seconds": int,
|
||||
},
|
||||
"illegal_parking": {
|
||||
"confirm_vehicle_sec": int, "parking_countdown_sec": int,
|
||||
"confirm_clear_sec": int, "cooldown_sec": int,
|
||||
},
|
||||
"vehicle_congestion": {
|
||||
"count_threshold": int, "confirm_congestion_sec": int,
|
||||
"confirm_clear_sec": int, "cooldown_sec": int,
|
||||
},
|
||||
"non_motor_vehicle_parking": {
|
||||
"confirm_vehicle_sec": int, "parking_countdown_sec": int,
|
||||
"confirm_clear_sec": int, "cooldown_sec": int,
|
||||
},
|
||||
"garbage": {
|
||||
"confirm_garbage_sec": int, "confirm_clear_sec": int,
|
||||
"cooldown_sec": int,
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, working_hours: Optional[List[Dict]] = None):
|
||||
self.algorithms: Dict[str, Dict[str, Any]] = {}
|
||||
self.working_hours = working_hours or []
|
||||
self._update_lock = threading.Lock()
|
||||
self._registered_keys: set = set() # 已注册的 (roi_id, bind_id, algo_type) 缓存
|
||||
self._global_params: Dict[str, Dict] = {} # 全局参数 {algo_code: params_dict}
|
||||
|
||||
# Bug fix: 默认参数与算法构造函数一致
|
||||
self.default_params = {
|
||||
@@ -1318,12 +1855,60 @@ class AlgorithmManager:
|
||||
"cooldown_sec": 1800, # Bug fix: 与算法构造函数默认值一致(1800,非600)
|
||||
"target_classes": ["car", "truck", "bus", "motorcycle"],
|
||||
},
|
||||
"non_motor_vehicle_parking": {
|
||||
"confirm_vehicle_sec": 10,
|
||||
"parking_countdown_sec": 180,
|
||||
"confirm_clear_sec": 60,
|
||||
"cooldown_sec": 900,
|
||||
"target_classes": ["bicycle", "motorcycle"],
|
||||
},
|
||||
"garbage": {
|
||||
"confirm_garbage_sec": 60,
|
||||
"confirm_clear_sec": 60,
|
||||
"cooldown_sec": 1800,
|
||||
"target_classes": ["garbage"],
|
||||
},
|
||||
}
|
||||
|
||||
self._pubsub = None
|
||||
self._pubsub_thread = None
|
||||
self._running = False
|
||||
|
||||
def update_global_params(self, global_params_map: Dict[str, Dict]):
|
||||
"""更新全局参数
|
||||
|
||||
Args:
|
||||
global_params_map: {algo_code: params_dict} 格式的全局参数
|
||||
"""
|
||||
with self._update_lock:
|
||||
self._global_params = global_params_map or {}
|
||||
logger.info(f"全局参数已更新: {list(self._global_params.keys())}")
|
||||
|
||||
def _coerce_param_types(self, algorithm_type: str, params: dict) -> dict:
|
||||
"""强制转换参数类型,防止字符串型数字导致算法异常"""
|
||||
type_map = self._PARAM_TYPES.get(algorithm_type, {})
|
||||
for key, expected_type in type_map.items():
|
||||
if key in params and params[key] is not None:
|
||||
try:
|
||||
if not isinstance(params[key], expected_type):
|
||||
params[key] = expected_type(params[key])
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"参数类型转换失败: {algorithm_type}.{key}={params[key]!r}, 删除使用默认值")
|
||||
del params[key]
|
||||
return params
|
||||
|
||||
def get_min_alarm_duration(self, algorithm_type: str) -> Optional[int]:
|
||||
"""从全局参数获取最小告警持续时间(秒)
|
||||
|
||||
Args:
|
||||
algorithm_type: 算法类型(如 leave_post, intrusion)
|
||||
|
||||
Returns:
|
||||
最小告警持续时间秒数,未配置返回 None
|
||||
"""
|
||||
gp = self._global_params.get(algorithm_type, {})
|
||||
return gp.get("min_alarm_duration_sec")
|
||||
|
||||
def start_config_subscription(self):
|
||||
"""启动配置变更订阅"""
|
||||
try:
|
||||
@@ -1421,6 +2006,9 @@ class AlgorithmManager:
|
||||
else:
|
||||
params = {}
|
||||
|
||||
# 强制转换参数类型(防止字符串型数字)
|
||||
params = self._coerce_param_types(algo_code, params)
|
||||
|
||||
if roi_id not in self.algorithms:
|
||||
self.algorithms[roi_id] = {}
|
||||
|
||||
@@ -1537,6 +2125,66 @@ class AlgorithmManager:
|
||||
dissipation_ratio=algo_params["dissipation_ratio"],
|
||||
)
|
||||
logger.info(f"已从Redis加载拥堵算法: {key}")
|
||||
elif algo_code == "non_motor_vehicle_parking":
|
||||
configured_alarm_level = params.get("alarm_level")
|
||||
algo_params = {
|
||||
"confirm_vehicle_sec": params.get("confirm_vehicle_sec", 10),
|
||||
"parking_countdown_sec": params.get("parking_countdown_sec", 180),
|
||||
"confirm_clear_sec": params.get("confirm_clear_sec", 60),
|
||||
"cooldown_sec": params.get("cooldown_sec", 900),
|
||||
"target_classes": params.get("target_classes", ["bicycle", "motorcycle"]),
|
||||
}
|
||||
if key in self.algorithms.get(roi_id, {}) and "non_motor_vehicle_parking" in self.algorithms[roi_id].get(key, {}):
|
||||
algo = self.algorithms[roi_id][key]["non_motor_vehicle_parking"]
|
||||
algo.confirm_vehicle_sec = algo_params["confirm_vehicle_sec"]
|
||||
algo.parking_countdown_sec = algo_params["parking_countdown_sec"]
|
||||
algo.confirm_clear_sec = algo_params["confirm_clear_sec"]
|
||||
algo.cooldown_sec = algo_params["cooldown_sec"]
|
||||
algo.target_classes = algo_params["target_classes"]
|
||||
if configured_alarm_level is not None:
|
||||
algo._alarm_level = configured_alarm_level
|
||||
logger.info(f"已热更新非机动车违停算法参数: {key}")
|
||||
else:
|
||||
self.algorithms[roi_id][key] = {}
|
||||
self.algorithms[roi_id][key]["non_motor_vehicle_parking"] = NonMotorVehicleParkingAlgorithm(
|
||||
confirm_vehicle_sec=algo_params["confirm_vehicle_sec"],
|
||||
parking_countdown_sec=algo_params["parking_countdown_sec"],
|
||||
confirm_clear_sec=algo_params["confirm_clear_sec"],
|
||||
cooldown_sec=algo_params["cooldown_sec"],
|
||||
target_classes=algo_params["target_classes"],
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
logger.info(f"已从Redis加载非机动车违停算法: {key}")
|
||||
elif algo_code == "garbage":
|
||||
configured_alarm_level = params.get("alarm_level")
|
||||
algo_params = {
|
||||
"confirm_garbage_sec": params.get("confirm_garbage_sec", 60),
|
||||
"confirm_clear_sec": params.get("confirm_clear_sec", 60),
|
||||
"cooldown_sec": params.get("cooldown_sec", 1800),
|
||||
"target_classes": params.get("target_classes", ["garbage"]),
|
||||
}
|
||||
if key in self.algorithms.get(roi_id, {}) and "garbage" in self.algorithms[roi_id].get(key, {}):
|
||||
algo = self.algorithms[roi_id][key]["garbage"]
|
||||
algo.confirm_garbage_sec = algo_params["confirm_garbage_sec"]
|
||||
algo.confirm_clear_sec = algo_params["confirm_clear_sec"]
|
||||
algo.cooldown_sec = algo_params["cooldown_sec"]
|
||||
algo.target_classes = algo_params["target_classes"]
|
||||
if configured_alarm_level is not None:
|
||||
algo._alarm_level = configured_alarm_level
|
||||
logger.info(f"已热更新垃圾检测算法参数: {key}")
|
||||
else:
|
||||
if roi_id not in self.algorithms:
|
||||
self.algorithms[roi_id] = {}
|
||||
if key not in self.algorithms[roi_id]:
|
||||
self.algorithms[roi_id][key] = {}
|
||||
self.algorithms[roi_id][key]["garbage"] = GarbageDetectionAlgorithm(
|
||||
confirm_garbage_sec=algo_params["confirm_garbage_sec"],
|
||||
confirm_clear_sec=algo_params["confirm_clear_sec"],
|
||||
cooldown_sec=algo_params["cooldown_sec"],
|
||||
target_classes=algo_params["target_classes"],
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
logger.info(f"已从Redis加载垃圾检测算法: {key}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -1591,6 +2239,9 @@ class AlgorithmManager:
|
||||
params = json.loads(params_str) if isinstance(params_str, str) else params_str
|
||||
algo_code = bind_config.get("algo_code")
|
||||
|
||||
# 强制转换参数类型(防止字符串型数字)
|
||||
params = self._coerce_param_types(algo_code, params)
|
||||
|
||||
# 获取现有算法实例
|
||||
existing_algo = self.algorithms[roi_id][key].get(algo_code)
|
||||
|
||||
@@ -1661,6 +2312,31 @@ class AlgorithmManager:
|
||||
|
||||
logger.info(f"[{roi_id}_{bind_id}] 更新拥堵检测参数")
|
||||
|
||||
elif algo_code == "non_motor_vehicle_parking":
|
||||
existing_algo.confirm_vehicle_sec = params.get("confirm_vehicle_sec", 10)
|
||||
existing_algo.parking_countdown_sec = params.get("parking_countdown_sec", 180)
|
||||
existing_algo.confirm_clear_sec = params.get("confirm_clear_sec", 60)
|
||||
existing_algo.cooldown_sec = params.get("cooldown_sec", 900)
|
||||
if "target_classes" in params:
|
||||
existing_algo.target_classes = params["target_classes"]
|
||||
alarm_level = params.get("alarm_level")
|
||||
if alarm_level is not None:
|
||||
existing_algo._alarm_level = alarm_level
|
||||
|
||||
logger.info(f"[{roi_id}_{bind_id}] 更新非机动车违停检测参数")
|
||||
|
||||
elif algo_code == "garbage":
|
||||
existing_algo.confirm_garbage_sec = params.get("confirm_garbage_sec", 60)
|
||||
existing_algo.confirm_clear_sec = params.get("confirm_clear_sec", 60)
|
||||
existing_algo.cooldown_sec = params.get("cooldown_sec", 1800)
|
||||
if "target_classes" in params:
|
||||
existing_algo.target_classes = params["target_classes"]
|
||||
alarm_level = params.get("alarm_level")
|
||||
if alarm_level is not None:
|
||||
existing_algo._alarm_level = alarm_level
|
||||
|
||||
logger.info(f"[{roi_id}_{bind_id}] 更新垃圾检测参数")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -1757,9 +2433,17 @@ class AlgorithmManager:
|
||||
self.algorithms[roi_id][key] = {}
|
||||
|
||||
algo_params = self.default_params.get(algorithm_type, {}).copy()
|
||||
# 三级合并:默认参数 → 全局参数 → 绑定级参数
|
||||
with self._update_lock:
|
||||
global_p = self._global_params.get(algorithm_type, {}).copy()
|
||||
if global_p:
|
||||
algo_params.update(global_p)
|
||||
if params:
|
||||
algo_params.update(params)
|
||||
|
||||
# 强制转换参数类型(防止字符串型数字)
|
||||
algo_params = self._coerce_param_types(algorithm_type, algo_params)
|
||||
|
||||
# 从 params 中提取告警等级(前端配置下发)
|
||||
configured_alarm_level = algo_params.get("alarm_level")
|
||||
|
||||
@@ -1800,6 +2484,23 @@ class AlgorithmManager:
|
||||
alarm_level=configured_alarm_level,
|
||||
dissipation_ratio=algo_params.get("dissipation_ratio", 0.5),
|
||||
)
|
||||
elif algorithm_type == "non_motor_vehicle_parking":
|
||||
self.algorithms[roi_id][key]["non_motor_vehicle_parking"] = NonMotorVehicleParkingAlgorithm(
|
||||
confirm_vehicle_sec=algo_params.get("confirm_vehicle_sec", 10),
|
||||
parking_countdown_sec=algo_params.get("parking_countdown_sec", 180),
|
||||
confirm_clear_sec=algo_params.get("confirm_clear_sec", 60),
|
||||
cooldown_sec=algo_params.get("cooldown_sec", 900),
|
||||
target_classes=algo_params.get("target_classes", ["bicycle", "motorcycle"]),
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
elif algorithm_type == "garbage":
|
||||
self.algorithms[roi_id][key]["garbage"] = GarbageDetectionAlgorithm(
|
||||
confirm_garbage_sec=algo_params.get("confirm_garbage_sec", 60),
|
||||
confirm_clear_sec=algo_params.get("confirm_clear_sec", 60),
|
||||
cooldown_sec=algo_params.get("cooldown_sec", 1800),
|
||||
target_classes=algo_params.get("target_classes", ["garbage"]),
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
|
||||
self._registered_keys.add(cache_key)
|
||||
|
||||
@@ -1892,7 +2593,7 @@ class AlgorithmManager:
|
||||
"state": getattr(algo, "state", "WAITING"),
|
||||
"alarm_sent": getattr(algo, "alarm_sent", False),
|
||||
}
|
||||
elif algo_type in ("illegal_parking", "vehicle_congestion"):
|
||||
elif algo_type in ("illegal_parking", "vehicle_congestion", "non_motor_vehicle_parking", "garbage"):
|
||||
status[f"{algo_type}_{bind_id}"] = algo.get_state()
|
||||
else:
|
||||
status[f"{algo_type}_{bind_id}"] = {
|
||||
|
||||
@@ -22,6 +22,7 @@ class AlgorithmType(str, Enum):
|
||||
INTRUSION = "intrusion"
|
||||
ILLEGAL_PARKING = "illegal_parking"
|
||||
VEHICLE_CONGESTION = "vehicle_congestion"
|
||||
NON_MOTOR_VEHICLE_PARKING = "non_motor_vehicle_parking"
|
||||
CROWD_DETECTION = "crowd_detection"
|
||||
FACE_RECOGNITION = "face_recognition"
|
||||
|
||||
|
||||
@@ -259,6 +259,15 @@ class SQLiteManager:
|
||||
except Exception:
|
||||
pass # 列已存在,忽略
|
||||
|
||||
# 算法全局参数表
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS algo_global_params (
|
||||
algo_code TEXT PRIMARY KEY,
|
||||
params TEXT NOT NULL DEFAULT '{}',
|
||||
updated_at TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
self._init_default_algorithms()
|
||||
|
||||
def _init_default_algorithms(self):
|
||||
@@ -948,6 +957,39 @@ class SQLiteManager:
|
||||
logger.error(f"获取所有绑定ID失败: {e}")
|
||||
return []
|
||||
|
||||
def save_global_params(self, algo_code: str, params_dict: Dict[str, Any]) -> bool:
|
||||
"""保存算法全局参数(INSERT OR REPLACE)"""
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
now = datetime.now().isoformat()
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO algo_global_params (algo_code, params, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
""", (algo_code, json.dumps(params_dict, ensure_ascii=False), now))
|
||||
self._conn.commit()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"保存算法全局参数失败: {e}")
|
||||
return False
|
||||
|
||||
def get_all_global_params(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取所有算法全局参数,返回 {algo_code: params_dict}"""
|
||||
result: Dict[str, Dict[str, Any]] = {}
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("SELECT algo_code, params FROM algo_global_params")
|
||||
for row in cursor.fetchall():
|
||||
algo_code = row[0]
|
||||
params_str = row[1]
|
||||
try:
|
||||
result[algo_code] = json.loads(params_str) if params_str else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
result[algo_code] = {}
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"获取算法全局参数失败: {e}")
|
||||
return result
|
||||
|
||||
def log_config_update(
|
||||
self,
|
||||
config_type: str,
|
||||
|
||||
@@ -15,6 +15,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import socket
|
||||
|
||||
# 禁用系统代理(Clash 等代理工具会干扰 Redis TCP 长连接)
|
||||
@@ -53,6 +54,12 @@ def _build_keepalive_options():
|
||||
return opts
|
||||
|
||||
|
||||
def _camel_to_snake(name: str) -> str:
|
||||
"""将 camelCase 转换为 snake_case"""
|
||||
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
||||
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
||||
|
||||
|
||||
# ==================== Redis Key 常量 ====================
|
||||
|
||||
# 云端 Redis Keys
|
||||
@@ -643,6 +650,21 @@ class ConfigSyncManager:
|
||||
# 清理 SQLite 中不在本次推送列表中的旧数据
|
||||
self._cleanup_stale_records(incoming_camera_ids, incoming_roi_ids, incoming_bind_ids)
|
||||
|
||||
# 同步全局参数
|
||||
global_params = config_data.get("global_params") or config_data.get("globalParams") or {}
|
||||
if global_params and isinstance(global_params, dict):
|
||||
for algo_code, params_dict in global_params.items():
|
||||
if isinstance(params_dict, dict):
|
||||
# 防御性转换:camelCase → snake_case
|
||||
params_dict = {_camel_to_snake(k): v for k, v in params_dict.items()}
|
||||
self._db_manager.save_global_params(algo_code, params_dict)
|
||||
logger.info(f"全局参数同步完成: {list(global_params.keys())}")
|
||||
|
||||
# 通知全局参数更新回调
|
||||
self._notify_callbacks("global_params_update", {
|
||||
"global_params": global_params,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"配置同步到 SQLite 失败: {e}")
|
||||
|
||||
|
||||
46
docker-compose.yml
Normal file
46
docker-compose.yml
Normal file
@@ -0,0 +1,46 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
edge-inference:
|
||||
build: .
|
||||
image: edge-inference:latest
|
||||
container_name: edge-inference
|
||||
restart: always
|
||||
|
||||
# GPU 访问
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
# 环境变量
|
||||
env_file:
|
||||
- .env
|
||||
|
||||
# 卷挂载
|
||||
volumes:
|
||||
- ./models:/app/models # TensorRT 引擎文件
|
||||
- ./data:/app/data # SQLite + 截图缓存
|
||||
- ./logs:/app/logs # 运行日志
|
||||
- ./.env:/app/.env # 环境配置
|
||||
|
||||
# 网络(需要访问摄像头 RTSP + 云端 API + Redis)
|
||||
network_mode: host
|
||||
|
||||
# 健康检查
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import os; assert os.path.exists('/app/main.py')"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 30s
|
||||
|
||||
# 日志限制
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "5"
|
||||
194
docs/garbage_algorithm_backend_frontend_plan.md
Normal file
194
docs/garbage_algorithm_backend_frontend_plan.md
Normal file
@@ -0,0 +1,194 @@
|
||||
# 垃圾检测算法 - WVP 后端 / 前端改动方案(未实施,预留参考)
|
||||
|
||||
## 背景
|
||||
|
||||
Edge 端的 `GarbageDetectionAlgorithm` 已实现(commit xxx)。本文档列出 WVP 后端和前端需要做的配套改动,等后续需要在 ROI 编辑器创建垃圾检测绑定时再实施。
|
||||
|
||||
---
|
||||
|
||||
## 一、WVP 后端改动
|
||||
|
||||
### 1.1 算法注册(数据库)
|
||||
|
||||
**文件:** `wvp-platform/数据库/版本号/SQL 脚本`
|
||||
|
||||
**新增算法记录:**
|
||||
```sql
|
||||
INSERT INTO wvp_ai_algorithm (
|
||||
algo_code, algo_name, algo_description,
|
||||
is_active, param_schema, global_params
|
||||
) VALUES (
|
||||
'garbage',
|
||||
'垃圾检测',
|
||||
'检测监控区域内散落垃圾的持续存在,清洁后自动解除告警',
|
||||
1,
|
||||
'{"confirm_garbage_sec": {"type": "int", "default": 60, "min": 10, "max": 600, "label": "垃圾确认时间(秒)"},
|
||||
"confirm_clear_sec": {"type": "int", "default": 60, "min": 10, "max": 600, "label": "清理确认时间(秒)"},
|
||||
"cooldown_sec": {"type": "int", "default": 1800, "min": 300, "max": 7200, "label": "告警冷却时间(秒)"},
|
||||
"alarm_level": {"type": "int", "default": 2, "min": 0, "max": 3, "label": "告警等级"}
|
||||
}',
|
||||
'{}'
|
||||
);
|
||||
```
|
||||
|
||||
### 1.2 Java 算法服务
|
||||
|
||||
**文件:** `wvp-platform/src/main/java/com/genersoft/iot/vmp/aiot/service/impl/AiAlgorithmServiceImpl.java`
|
||||
|
||||
**改动:** 算法代码白名单(约 line 42-54),添加 `"garbage"`:
|
||||
|
||||
```java
|
||||
private static final Set<String> SUPPORTED_ALGO_CODES = Set.of(
|
||||
"leave_post", "intrusion", "illegal_parking",
|
||||
"vehicle_congestion", "non_motor_vehicle_parking",
|
||||
"garbage" // 新增
|
||||
);
|
||||
```
|
||||
|
||||
### 1.3 配置下发
|
||||
|
||||
不需要改动。现有 `AiRedisConfigServiceImpl` 和 `global_params` 机制已通用。
|
||||
|
||||
---
|
||||
|
||||
## 二、iot-device-management-service 改动
|
||||
|
||||
**文件:** `app/constants.py`
|
||||
|
||||
检查 `AlarmType` 枚举是否已有 `GARBAGE`:
|
||||
- 已有 → 无需改动
|
||||
- 未有 → 添加:
|
||||
|
||||
```python
|
||||
class AlarmType(str, Enum):
|
||||
LEAVE_POST = "leave_post"
|
||||
INTRUSION = "intrusion"
|
||||
ILLEGAL_PARKING = "illegal_parking"
|
||||
VEHICLE_CONGESTION = "vehicle_congestion"
|
||||
NON_MOTOR_VEHICLE_PARKING = "non_motor_vehicle_parking"
|
||||
GARBAGE = "garbage" # 新增
|
||||
|
||||
ALARM_TYPE_NAMES: Dict[str, str] = {
|
||||
...
|
||||
AlarmType.GARBAGE: "垃圾检测",
|
||||
}
|
||||
```
|
||||
|
||||
**文件:** `app/services/vlm_service.py`(VLM 复核提示词)
|
||||
|
||||
添加 garbage 的提示词模板:
|
||||
```python
|
||||
"garbage": """你是安防监控AI复核员。算法类型:垃圾检测,监控区域:{roi_name}。
|
||||
截图显示时间:{timestamp}。
|
||||
任务:判断图中是否真的存在散落的垃圾、包装袋、废弃物等需要清理的物品。
|
||||
分析要点:
|
||||
1. 是否存在明显的垃圾(垃圾袋、纸屑、瓶罐、食品包装等)
|
||||
2. 区分垃圾与正常物品(整齐放置的物品、装饰品不算垃圾)
|
||||
3. 垃圾是否在通道/地面等不该出现的位置
|
||||
4. 排除阴影、污渍、地砖花纹等误检
|
||||
仅输出JSON:{{"confirmed":true,"description":"..."}}""",
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 三、前端改动(iot-device-management-frontend)
|
||||
|
||||
### 3.1 告警列表类型筛选
|
||||
|
||||
**文件:** `apps/web-antd/src/views/aiot/alarm/list/data.ts`
|
||||
|
||||
```typescript
|
||||
export const ALERT_TYPE_OPTIONS = [
|
||||
{ label: '人员离岗', value: 'leave_post' },
|
||||
{ label: '周界入侵', value: 'intrusion' },
|
||||
{ label: '车辆违停', value: 'illegal_parking' },
|
||||
{ label: '车辆拥堵', value: 'vehicle_congestion' },
|
||||
{ label: '非机动车违停', value: 'non_motor_vehicle_parking' },
|
||||
{ label: '垃圾检测', value: 'garbage' }, // 新增
|
||||
];
|
||||
```
|
||||
|
||||
### 3.2 告警标签颜色
|
||||
|
||||
**文件:** `apps/web-antd/src/views/aiot/alarm/list/index.vue`
|
||||
|
||||
```typescript
|
||||
const colorMap: Record<string, string> = {
|
||||
leave_post: 'orange',
|
||||
intrusion: 'red',
|
||||
illegal_parking: 'blue',
|
||||
vehicle_congestion: 'geekblue',
|
||||
non_motor_vehicle_parking: 'green',
|
||||
garbage: 'gold', // 新增 — 建议金色表达清洁主题
|
||||
};
|
||||
```
|
||||
|
||||
### 3.3 ROI 编辑器参数配置
|
||||
|
||||
**文件:** `apps/web-antd/src/views/aiot/device/roi/components/AlgorithmParamEditor.vue`
|
||||
|
||||
参数字段已通用(读自 `algo.paramSchema`),无需改动。**但需要添加参数名中文映射:**
|
||||
|
||||
**文件:** `AlgorithmParamEditor.vue` 的 `paramNameMap` 和 `paramDescMap`
|
||||
|
||||
```typescript
|
||||
const paramNameMap: Record<string, string> = {
|
||||
// ... 已有
|
||||
confirm_garbage_sec: '垃圾确认时间(秒)',
|
||||
};
|
||||
|
||||
const paramDescMap: Record<string, string> = {
|
||||
// ... 已有
|
||||
confirm_garbage_sec: '持续检测到垃圾的时间,超过该时间触发告警(建议 60-120 秒)',
|
||||
};
|
||||
```
|
||||
|
||||
### 3.4 全局参数配置页
|
||||
|
||||
**文件:** `apps/web-antd/src/views/aiot/device/algorithm/index.vue`
|
||||
|
||||
`paramNameMap` 和 `paramDescMap` 同样需要添加 `confirm_garbage_sec` 条目(参见 3.3)。
|
||||
|
||||
---
|
||||
|
||||
## 四、验证顺序(将来实施时)
|
||||
|
||||
1. **后端数据库注册算法记录**
|
||||
2. **WVP 后端重启** — 白名单生效
|
||||
3. **Service 端** constants.py 添加(如需要)
|
||||
4. **前端重启** — 下拉选项和颜色生效
|
||||
5. **ROI 编辑器创建一个 garbage 绑定,参数用默认值**
|
||||
6. **前端触发配置推送** — 验证 Edge 端收到并注册算法
|
||||
7. **Edge 日志验证:** 应看到 `已从Redis加载垃圾检测算法: roi_xxx_bind_xxx`
|
||||
8. **模拟测试:** 放个垃圾在摄像头前,60 秒后应触发告警
|
||||
9. **清理测试:** 移除垃圾 30 秒后应收到 resolve 事件
|
||||
10. **企微卡片收到告警 + 创建工单全流程**
|
||||
|
||||
---
|
||||
|
||||
## 五、TensorRT 引擎部署(最后一步)
|
||||
|
||||
当确定用微调模型替换 COCO 预训练模型时:
|
||||
|
||||
1. **导出 engine:**
|
||||
```bash
|
||||
yolo export model=yolo11s_v1_20260417.pt format=engine imgsz=480 half=True device=0
|
||||
```
|
||||
|
||||
2. **替换 Edge 端模型:**
|
||||
```bash
|
||||
cp yolo11s_v1_20260417.engine /opt/edge/models/yolo11n.engine # 注意文件名
|
||||
```
|
||||
|
||||
3. **修改 `config/settings.py` 的 COCO_CLASS_NAMES:**
|
||||
```python
|
||||
COCO_CLASS_NAMES = ['garbage', 'person', 'car', 'bicycle', 'motorcycle']
|
||||
```
|
||||
|
||||
4. **修改 `core/postprocessor.py` 的输出解析:**
|
||||
- YOLO 输出从 `[84, 8400]`(4+80类)变为 `[9, 8400]`(4+5类)
|
||||
- 类别分数范围从 `output[4:84]` 改为 `output[4:9]`
|
||||
|
||||
5. **重启 Edge 服务**
|
||||
|
||||
这一步涉及模型 + 推理管线,需要单独在生产环境测试。
|
||||
60
main.py
60
main.py
@@ -18,7 +18,6 @@ for _key in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "all_prox
|
||||
|
||||
from config.settings import get_settings, Settings
|
||||
from core.config_sync import get_config_sync_manager, ConfigSyncManager
|
||||
from core.debug_http_server import start_debug_http_server
|
||||
from core.video_stream import MultiStreamManager, VideoFrame
|
||||
from core.preprocessor import ImagePreprocessor
|
||||
from core.tensorrt_engine import TensorRTEngine, EngineManager
|
||||
@@ -56,8 +55,6 @@ class EdgeInferenceService:
|
||||
self._screenshot_handler: Optional[ScreenshotHandler] = None
|
||||
self._algorithm_manager: Optional[AlgorithmManager] = None
|
||||
self._debug_reload_thread: Optional[threading.Thread] = None
|
||||
self._debug_http_server = None
|
||||
self._debug_http_thread: Optional[threading.Thread] = None
|
||||
self._heartbeat_thread: Optional[threading.Thread] = None
|
||||
self._scheduler_thread: Optional[threading.Thread] = None
|
||||
|
||||
@@ -132,6 +129,18 @@ class EdgeInferenceService:
|
||||
daemon=True
|
||||
).start()
|
||||
self._config_manager.register_callback("config_update", _on_config_update)
|
||||
|
||||
def _on_global_params_update(topic, data):
|
||||
if self._algorithm_manager:
|
||||
global_params = data.get("global_params", {})
|
||||
self._algorithm_manager.update_global_params(global_params)
|
||||
# 只清除受影响算法的注册缓存,避免无关算法状态丢失
|
||||
affected_algos = set(global_params.keys())
|
||||
keys_to_remove = [k for k in self._algorithm_manager._registered_keys if k[2] in affected_algos]
|
||||
for key in keys_to_remove:
|
||||
self._algorithm_manager._registered_keys.discard(key)
|
||||
self._logger.info(f"全局参数回调已触发,清除 {len(keys_to_remove)} 个受影响算法的注册缓存")
|
||||
self._config_manager.register_callback("global_params_update", _on_global_params_update)
|
||||
self._logger.info("配置管理器初始化成功")
|
||||
except Exception as e:
|
||||
self._logger.error(f"配置管理器初始化失败: {e}")
|
||||
@@ -198,6 +207,18 @@ class EdgeInferenceService:
|
||||
try:
|
||||
self._algorithm_manager = AlgorithmManager()
|
||||
self._algorithm_manager.start_config_subscription()
|
||||
|
||||
# 启动时从 SQLite 加载已有全局参数
|
||||
try:
|
||||
from config.database import get_sqlite_manager
|
||||
db = get_sqlite_manager()
|
||||
saved_global_params = db.get_all_global_params()
|
||||
if saved_global_params:
|
||||
self._algorithm_manager.update_global_params(saved_global_params)
|
||||
self._logger.info(f"从 SQLite 加载全局参数: {list(saved_global_params.keys())}")
|
||||
except Exception as e:
|
||||
self._logger.warning(f"从 SQLite 加载全局参数失败: {e}")
|
||||
|
||||
self._logger.info("算法管理器初始化成功")
|
||||
except Exception as e:
|
||||
self._logger.error(f"算法管理器初始化失败: {e}")
|
||||
@@ -281,32 +302,6 @@ class EdgeInferenceService:
|
||||
)
|
||||
self._debug_reload_thread.start()
|
||||
|
||||
def _start_debug_http_server(self):
|
||||
"""本地调试:启动 HTTP 同步接口"""
|
||||
if self._settings.config_sync_mode != "LOCAL":
|
||||
return
|
||||
if not getattr(self._settings, "debug", None) or not self._settings.debug.enabled:
|
||||
return
|
||||
if self._debug_http_server is not None:
|
||||
return
|
||||
|
||||
host = self._settings.debug.host
|
||||
port = self._settings.debug.port
|
||||
self._debug_http_server = start_debug_http_server(host, port)
|
||||
|
||||
def worker():
|
||||
try:
|
||||
self._debug_http_server.serve_forever()
|
||||
except Exception as e:
|
||||
self._logger.warning(f"[DEBUG] HTTP 服务器异常: {e}")
|
||||
|
||||
self._debug_http_thread = threading.Thread(
|
||||
target=worker,
|
||||
name="DebugHttpServer",
|
||||
daemon=True,
|
||||
)
|
||||
self._debug_http_thread.start()
|
||||
|
||||
def _start_heartbeat(self):
|
||||
"""启动心跳守护线程,每 30 秒向云端上报设备状态"""
|
||||
def worker():
|
||||
@@ -379,7 +374,6 @@ class EdgeInferenceService:
|
||||
self._init_algorithm_manager()
|
||||
self._init_screenshot_handler()
|
||||
self._start_debug_reload_watcher()
|
||||
self._start_debug_http_server()
|
||||
self._start_heartbeat()
|
||||
|
||||
self._performance_stats["start_time"] = datetime.now()
|
||||
@@ -1098,12 +1092,6 @@ class EdgeInferenceService:
|
||||
if self._reporter:
|
||||
self._reporter.close()
|
||||
|
||||
if self._debug_http_server:
|
||||
try:
|
||||
self._debug_http_server.shutdown()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._performance_stats["uptime_seconds"] = (
|
||||
(datetime.now() - self._performance_stats["start_time"]).total_seconds()
|
||||
)
|
||||
|
||||
@@ -1,64 +1,78 @@
|
||||
# Edge_Inference_Service 依赖清单
|
||||
# 安装命令: pip install -r requirements.txt
|
||||
# 备注:所有版本均选择最稳定版本,经过大量验证
|
||||
# 环境要求: Python 3.10 | CUDA 12.1 | cuDNN 8.9 | TensorRT 8.6.1
|
||||
# Docker 基础镜像: nvcr.io/nvidia/tensorrt:23.08-py3
|
||||
|
||||
# ============================================================
|
||||
# 核心依赖(必需)
|
||||
# GPU 推理依赖(TensorRT 8.6 + CUDA 12.1)
|
||||
# ============================================================
|
||||
|
||||
# 视频处理 - OpenCV 4.8.0,最稳定的4.x版本
|
||||
opencv-python==4.8.0.74
|
||||
# PyTorch - CUDA 12.1 下最稳定版本
|
||||
# 安装命令: pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121
|
||||
--extra-index-url https://download.pytorch.org/whl/cu121
|
||||
torch==2.1.2
|
||||
torchvision==0.16.2
|
||||
|
||||
# 数值计算 - NumPy 1.24.0,Python 3.8-3.11完美兼容
|
||||
numpy==1.24.0
|
||||
# TensorRT Python 绑定(NGC 镜像已预装,裸机需手动安装)
|
||||
tensorrt==8.6.1.6
|
||||
pycuda==2023.1.1
|
||||
|
||||
# YOLO11 目标检测框架
|
||||
ultralytics==8.3.5
|
||||
|
||||
# ONNX 模型转换与优化
|
||||
onnx==1.16.0
|
||||
onnxsim==0.4.36
|
||||
onnxruntime-gpu==1.17.1
|
||||
|
||||
# ============================================================
|
||||
# 核心依赖
|
||||
# ============================================================
|
||||
|
||||
# 视频处理
|
||||
opencv-python==4.8.0.76
|
||||
|
||||
# 数值计算(锁定 1.x,避开 NumPy 2.0 破坏性变更)
|
||||
numpy==1.26.4
|
||||
|
||||
# 图像处理
|
||||
Pillow==10.2.0
|
||||
|
||||
# ============================================================
|
||||
# 数据库依赖
|
||||
# ============================================================
|
||||
|
||||
# ORM框架 - SQLAlchemy 2.0.23,长期支持稳定版
|
||||
# ORM 框架
|
||||
sqlalchemy==2.0.23
|
||||
|
||||
# MySQL驱动 - PyMySQL 1.1.0,成熟稳定版本
|
||||
# MySQL 驱动
|
||||
pymysql==1.1.0
|
||||
|
||||
# ============================================================
|
||||
# 消息队列与缓存
|
||||
# ============================================================
|
||||
|
||||
# MQTT客户端 - Paho-MQTT 1.6.1,1.x最终稳定版
|
||||
# MQTT 客户端(1.x 最终稳定版)
|
||||
paho-mqtt==1.6.1
|
||||
|
||||
# Redis客户端 - Redis 4.6.0,4.x最终稳定版
|
||||
# Redis 客户端
|
||||
redis==4.6.0
|
||||
|
||||
# 腾讯云COS SDK - 用于截图上传
|
||||
# 腾讯云 COS SDK(截图上传)
|
||||
cos-python-sdk-v5>=1.9.30
|
||||
|
||||
# ============================================================
|
||||
# 工具库
|
||||
# ============================================================
|
||||
|
||||
# YAML解析 - PyYAML 6.0.1,安全稳定版
|
||||
pyyaml==6.0.1
|
||||
requests==2.31.0
|
||||
psutil==5.9.8
|
||||
python-dotenv==1.0.1
|
||||
|
||||
# ============================================================
|
||||
# 测试框架
|
||||
# 测试依赖
|
||||
# ============================================================
|
||||
|
||||
# 单元测试 - PyTest 7.4.4,7.x最终稳定版
|
||||
pytest==7.4.4
|
||||
|
||||
# 覆盖率报告 - PyTest-Cov 4.1.0,成熟稳定版
|
||||
pytest-cov==4.1.0
|
||||
|
||||
# ============================================================
|
||||
# 可选依赖(按需安装)
|
||||
# ============================================================
|
||||
|
||||
# GPU推理框架(需要CUDA 12.1环境)
|
||||
# tensorrt==8.6.1.6
|
||||
# pycuda==2023.1.1
|
||||
|
||||
# YOLOv8目标检测(按需安装)
|
||||
# ultralytics==8.0.228
|
||||
|
||||
314
test_garbage_algorithm.py
Normal file
314
test_garbage_algorithm.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
GarbageDetectionAlgorithm 单元测试
|
||||
|
||||
覆盖场景:
|
||||
1. 无垃圾时保持 IDLE
|
||||
2. 持续检测到垃圾 → 确认 → 告警
|
||||
3. 冷却期内不重复触发
|
||||
4. 清理后发 resolve → 回到 IDLE
|
||||
5. 清理确认期内垃圾再次出现 → 回到 ALARMED
|
||||
6. reset() 正确清理状态
|
||||
"""
|
||||
import sys
|
||||
import os
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s - %(message)s')
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from algorithms import GarbageDetectionAlgorithm
|
||||
|
||||
|
||||
# ===== 工具函数 =====
|
||||
|
||||
def make_tracks(roi_id: str, classes: list, confidences: list = None):
|
||||
"""生成模拟检测结果"""
|
||||
if confidences is None:
|
||||
confidences = [0.85] * len(classes)
|
||||
tracks = []
|
||||
for i, cls in enumerate(classes):
|
||||
tracks.append({
|
||||
"track_id": f"{roi_id}_{i}",
|
||||
"class": cls,
|
||||
"confidence": confidences[i],
|
||||
"bbox": [100 + i * 50, 100, 200 + i * 50, 300],
|
||||
"matched_rois": [{"roi_id": roi_id}],
|
||||
})
|
||||
return tracks
|
||||
|
||||
|
||||
def simulate(algo, roi_id, camera_id, get_tracks_fn, count, interval=1.0, start_time=None):
|
||||
"""连续模拟帧,返回所有 alerts 和最后时间戳"""
|
||||
t = start_time or datetime(2026, 4, 17, 10, 0, 0)
|
||||
all_alerts = []
|
||||
for i in range(count):
|
||||
tracks = get_tracks_fn(i)
|
||||
alerts = algo.process(roi_id, camera_id, tracks, t)
|
||||
if alerts:
|
||||
all_alerts.extend(alerts)
|
||||
t += timedelta(seconds=interval)
|
||||
return all_alerts, t
|
||||
|
||||
|
||||
# ===== 测试 1:无垃圾时保持 IDLE =====
|
||||
|
||||
def test_idle_when_no_garbage():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 1: 无垃圾帧始终保持 IDLE")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(confirm_garbage_sec=60)
|
||||
alerts, _ = simulate(
|
||||
algo, "roi_1", "cam_1",
|
||||
lambda i: make_tracks("roi_1", ["person"]), # 只有人,没有垃圾
|
||||
count=100,
|
||||
)
|
||||
|
||||
assert algo.state == "IDLE", f"Expected IDLE, got {algo.state}"
|
||||
assert len(alerts) == 0, f"Expected no alerts, got {len(alerts)}"
|
||||
print(f" 状态: {algo.state},alerts: {len(alerts)} [OK]")
|
||||
|
||||
|
||||
# ===== 测试 2:持续检测到垃圾 → 告警 =====
|
||||
|
||||
def test_garbage_triggers_alarm():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 2: 持续 65 秒检测到垃圾 → 告警")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(
|
||||
confirm_garbage_sec=60,
|
||||
cooldown_sec=1800,
|
||||
)
|
||||
|
||||
alerts, _ = simulate(
|
||||
algo, "roi_1", "cam_1",
|
||||
lambda i: make_tracks("roi_1", ["garbage"]),
|
||||
count=65, # 65 秒,超过 60 秒确认期
|
||||
)
|
||||
|
||||
# 应该在第 60-61 秒触发 1 个告警
|
||||
assert algo.state == "ALARMED", f"Expected ALARMED, got {algo.state}"
|
||||
assert len(alerts) == 1, f"Expected 1 alert, got {len(alerts)}"
|
||||
alert = alerts[0]
|
||||
assert alert["alert_type"] == "garbage"
|
||||
assert alert["alarm_level"] == 2
|
||||
assert alert["garbage_count"] == 1
|
||||
assert "检测到垃圾" in alert["message"]
|
||||
print(f" 状态: {algo.state},告警: {alert['message']} [OK]")
|
||||
|
||||
|
||||
# ===== 测试 3:冷却期内不重复触发 =====
|
||||
|
||||
def test_cooldown_prevents_duplicate():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 3: 告警后冷却期内持续有垃圾,不重复触发")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(
|
||||
confirm_garbage_sec=10, # 缩短便于测试
|
||||
confirm_clear_sec=10,
|
||||
cooldown_sec=300, # 5 分钟冷却
|
||||
)
|
||||
|
||||
# 持续 200 秒有垃圾(远超冷却时间但没超过 300 秒)
|
||||
alerts, _ = simulate(
|
||||
algo, "roi_1", "cam_1",
|
||||
lambda i: make_tracks("roi_1", ["garbage"]),
|
||||
count=200,
|
||||
)
|
||||
|
||||
assert len(alerts) == 1, f"Expected 1 alert (cooldown), got {len(alerts)}"
|
||||
assert algo.state == "ALARMED"
|
||||
print(f" 告警次数: {len(alerts)}(冷却期内不重复)[OK]")
|
||||
|
||||
|
||||
# ===== 测试 4:清理后发 resolve → IDLE =====
|
||||
|
||||
def test_resolve_after_cleaning():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 4: 告警后清理 → 发 resolve → IDLE")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(
|
||||
confirm_garbage_sec=10,
|
||||
confirm_clear_sec=10,
|
||||
cooldown_sec=300,
|
||||
)
|
||||
algo._last_alarm_id = "test_alarm_123" # 模拟 main.py 回填
|
||||
|
||||
t = datetime(2026, 4, 17, 10, 0, 0)
|
||||
all_alerts = []
|
||||
|
||||
# Phase 1: 15 秒有垃圾 → 触发告警
|
||||
for i in range(15):
|
||||
alerts = algo.process(
|
||||
"roi_1", "cam_1",
|
||||
make_tracks("roi_1", ["garbage"]),
|
||||
t + timedelta(seconds=i)
|
||||
)
|
||||
all_alerts.extend(alerts)
|
||||
assert algo.state == "ALARMED"
|
||||
|
||||
# Phase 2: 然后 30 秒无垃圾 → 发 resolve
|
||||
# 需要等滑动窗口(10s)清空 + confirm_clear_sec(10s) = 20+ 秒
|
||||
for i in range(15, 45):
|
||||
alerts = algo.process(
|
||||
"roi_1", "cam_1",
|
||||
make_tracks("roi_1", []), # 空
|
||||
t + timedelta(seconds=i)
|
||||
)
|
||||
all_alerts.extend(alerts)
|
||||
|
||||
assert algo.state == "IDLE", f"Expected IDLE, got {algo.state}"
|
||||
resolves = [a for a in all_alerts if a.get("alert_type") == "alarm_resolve"]
|
||||
assert len(resolves) == 1, f"Expected 1 resolve, got {len(resolves)}"
|
||||
resolve = resolves[0]
|
||||
assert resolve["resolve_alarm_id"] == "test_alarm_123"
|
||||
assert resolve["resolve_type"] == "garbage_removed"
|
||||
assert resolve["duration_ms"] > 0
|
||||
print(f" resolve: {resolve['resolve_type']}, 持续 {resolve['duration_ms']}ms [OK]")
|
||||
|
||||
|
||||
# ===== 测试 5:清理期内垃圾再出现 → 回到 ALARMED =====
|
||||
|
||||
def test_garbage_reappears_during_clearing():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 5: 清理确认期内垃圾再出现 → 回到 ALARMED")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(
|
||||
confirm_garbage_sec=10,
|
||||
confirm_clear_sec=20, # 较长的清理确认期
|
||||
cooldown_sec=300,
|
||||
)
|
||||
algo._last_alarm_id = "test_alarm_456"
|
||||
|
||||
t = datetime(2026, 4, 17, 10, 0, 0)
|
||||
|
||||
# Phase 1: 15 秒有垃圾 → 告警 → ALARMED
|
||||
for i in range(15):
|
||||
algo.process("roi_1", "cam_1", make_tracks("roi_1", ["garbage"]),
|
||||
t + timedelta(seconds=i))
|
||||
assert algo.state == "ALARMED"
|
||||
|
||||
# Phase 2: 5 秒无垃圾 → CONFIRMING_CLEAR
|
||||
for i in range(15, 25):
|
||||
algo.process("roi_1", "cam_1", make_tracks("roi_1", []),
|
||||
t + timedelta(seconds=i))
|
||||
assert algo.state == "CONFIRMING_CLEAR", f"got {algo.state}"
|
||||
|
||||
# Phase 3: 垃圾又出现 5 秒 → 回到 ALARMED
|
||||
for i in range(25, 40):
|
||||
algo.process("roi_1", "cam_1", make_tracks("roi_1", ["garbage"]),
|
||||
t + timedelta(seconds=i))
|
||||
|
||||
assert algo.state == "ALARMED", f"Expected ALARMED, got {algo.state}"
|
||||
print(f" 状态恢复: CONFIRMING_CLEAR → ALARMED [OK]")
|
||||
|
||||
|
||||
# ===== 测试 6:reset() 清理状态 =====
|
||||
|
||||
def test_reset_clears_state():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 6: reset() 正确清理所有状态")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(confirm_garbage_sec=5)
|
||||
algo._last_alarm_id = "test"
|
||||
|
||||
# 先让它进入某个状态
|
||||
t = datetime(2026, 4, 17, 10, 0, 0)
|
||||
for i in range(10):
|
||||
algo.process("roi_1", "cam_1", make_tracks("roi_1", ["garbage"]),
|
||||
t + timedelta(seconds=i))
|
||||
assert algo.state == "ALARMED"
|
||||
assert len(algo._detection_window) > 0
|
||||
assert len(algo.alert_cooldowns) > 0
|
||||
|
||||
# Reset
|
||||
algo.reset()
|
||||
|
||||
assert algo.state == "IDLE"
|
||||
assert algo.state_start_time is None
|
||||
assert algo._last_alarm_id is None
|
||||
assert algo._garbage_start_time is None
|
||||
assert len(algo._detection_window) == 0
|
||||
assert len(algo.alert_cooldowns) == 0
|
||||
print(" 所有状态已清空 [OK]")
|
||||
|
||||
|
||||
# ===== 测试 7:多个垃圾目标计数 =====
|
||||
|
||||
def test_multiple_garbage_count():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 7: ROI 内多个垃圾目标 → garbage_count 正确")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(confirm_garbage_sec=5)
|
||||
|
||||
alerts, _ = simulate(
|
||||
algo, "roi_1", "cam_1",
|
||||
lambda i: make_tracks("roi_1", ["garbage", "garbage", "garbage"]),
|
||||
count=10,
|
||||
)
|
||||
|
||||
assert len(alerts) == 1
|
||||
assert alerts[0]["garbage_count"] == 3
|
||||
print(f" garbage_count: {alerts[0]['garbage_count']} [OK]")
|
||||
|
||||
|
||||
# ===== 测试 8:非 target_class 不计入 =====
|
||||
|
||||
def test_non_target_class_ignored():
|
||||
print("\n" + "=" * 60)
|
||||
print("TEST 8: person/car 类不计入(只看 garbage)")
|
||||
print("=" * 60)
|
||||
|
||||
algo = GarbageDetectionAlgorithm(confirm_garbage_sec=10)
|
||||
|
||||
alerts, _ = simulate(
|
||||
algo, "roi_1", "cam_1",
|
||||
lambda i: make_tracks("roi_1", ["person", "car"]), # 都不是 garbage
|
||||
count=30,
|
||||
)
|
||||
|
||||
assert algo.state == "IDLE", f"Expected IDLE, got {algo.state}"
|
||||
assert len(alerts) == 0
|
||||
print(f" 状态: {algo.state},无告警 [OK]")
|
||||
|
||||
|
||||
# ===== 运行所有测试 =====
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [
|
||||
test_idle_when_no_garbage,
|
||||
test_garbage_triggers_alarm,
|
||||
test_cooldown_prevents_duplicate,
|
||||
test_resolve_after_cleaning,
|
||||
test_garbage_reappears_during_clearing,
|
||||
test_reset_clears_state,
|
||||
test_multiple_garbage_count,
|
||||
test_non_target_class_ignored,
|
||||
]
|
||||
passed = 0
|
||||
failed = 0
|
||||
for t in tests:
|
||||
try:
|
||||
t()
|
||||
passed += 1
|
||||
except AssertionError as e:
|
||||
print(f" FAIL: {e}")
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
failed += 1
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(f"结果: {passed} 通过, {failed} 失败")
|
||||
print("=" * 60)
|
||||
sys.exit(0 if failed == 0 else 1)
|
||||
Reference in New Issue
Block a user