Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cd9a90a7ac | |||
| 64b464feb8 |
23
.env
Normal file
23
.env
Normal file
@@ -0,0 +1,23 @@
|
||||
# Local debug overrides
|
||||
DEBUG_SERVER_HOST=0.0.0.0
|
||||
|
||||
# Alarm upload - cloud API URL (WVP backend)
|
||||
CLOUD_API_URL=http://124.221.55.225:18080
|
||||
|
||||
# ===== 边缘设备 ID(必须与 WVP 数据库 ROI 表的 device_id 一致) =====
|
||||
EDGE_DEVICE_ID=edge
|
||||
|
||||
# ===== 配置同步模式(REDIS=监听云端Stream, LOCAL=仅本地SQLite) =====
|
||||
CONFIG_SYNC_MODE=REDIS
|
||||
|
||||
# ===== 云端 Redis(截图处理器 + 配置同步,db=1) =====
|
||||
CLOUD_REDIS_HOST=sh-crs-6upea3zn.sql.tencentcdb.com
|
||||
CLOUD_REDIS_PORT=24637
|
||||
CLOUD_REDIS_DB=1
|
||||
CLOUD_REDIS_PASSWORD=HkVZkVnn1
|
||||
|
||||
# ===== 腾讯云 COS(截图上传) =====
|
||||
COS_REGION=ap-shanghai
|
||||
COS_BUCKET=xhwkzx-1-1389966313
|
||||
COS_SECRET_ID=AKIDVxPiqmVhYv7FCwVqytdAVddQ2TJySt9I
|
||||
COS_SECRET_KEY=1rVyEI8mMVWs21xfBUjy4BE6DA4z7KWb
|
||||
25
.env.example
25
.env.example
@@ -1,25 +0,0 @@
|
||||
# ===== 云端 API =====
|
||||
CLOUD_API_URL=http://your-server:8000
|
||||
WVP_API_URL=http://your-server:18080
|
||||
|
||||
# ===== 边缘设备 ID(必须与 WVP 数据库 ROI 表的 device_id 一致) =====
|
||||
EDGE_DEVICE_ID=edge
|
||||
|
||||
# ===== 配置同步模式(REDIS=监听云端Stream, LOCAL=仅本地SQLite) =====
|
||||
CONFIG_SYNC_MODE=REDIS
|
||||
|
||||
# ===== 云端 Redis(截图处理器 + 配置同步) =====
|
||||
CLOUD_REDIS_HOST=your-redis-host
|
||||
CLOUD_REDIS_PORT=6379
|
||||
CLOUD_REDIS_DB=1
|
||||
CLOUD_REDIS_PASSWORD=your-redis-password
|
||||
|
||||
# ===== 腾讯云 COS(截图上传) =====
|
||||
COS_REGION=ap-shanghai
|
||||
COS_BUCKET=your-bucket-name
|
||||
COS_SECRET_ID=your-cos-secret-id
|
||||
COS_SECRET_KEY=your-cos-secret-key
|
||||
|
||||
# ===== 按算法独立置信度阈值 =====
|
||||
ALGO_CONF_LEAVE_POST=0.4
|
||||
ALGO_CONF_INTRUSION=0.6
|
||||
366
CLAUDE.md
366
CLAUDE.md
@@ -1,366 +0,0 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## 项目概述
|
||||
|
||||
边缘 AI 推理服务,部署在客户现场边缘端。通过 TensorRT GPU 加速实时分析视频流,检测离岗、入侵等异常行为,上报告警到云端。支持配置热更新、截图上传、告警去重等功能。
|
||||
|
||||
**部署位置:** 边缘端(客户现场)
|
||||
**运行方式:** 裸机部署或 Docker(需要 GPU 支持)
|
||||
**主要功能:**
|
||||
- RTSP 视频流接入(从 WVP 平台拉流)
|
||||
- TensorRT GPU 推理(批量处理,8 帧/批次)
|
||||
- 多算法支持(leave_post 离岗、intrusion 入侵)
|
||||
- 告警去重(ROI 级 + 摄像头级冷却)
|
||||
- 配置热更新(Redis Stream 订阅)
|
||||
- 截图上传(腾讯云 COS)
|
||||
- 告警上报(HTTP POST 到云端)
|
||||
|
||||
## 常用命令
|
||||
|
||||
### 本地开发
|
||||
|
||||
```bash
|
||||
# 安装依赖(Python 3.8+)
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 配置环境
|
||||
cp .env.example .env
|
||||
# 编辑 .env:
|
||||
# DEVICE_ID=edge_device_001
|
||||
# CLOUD_REDIS_HOST=腾讯云Redis地址
|
||||
# CLOUD_REDIS_PORT=6379
|
||||
# CLOUD_REDIS_PASSWORD=密码
|
||||
# LOCAL_REDIS_HOST=localhost
|
||||
# LOCAL_REDIS_PORT=6379
|
||||
# COS_SECRET_ID=腾讯云COS密钥ID
|
||||
# COS_SECRET_KEY=腾讯云COS密钥KEY
|
||||
# COS_BUCKET=your-bucket
|
||||
# COS_REGION=ap-beijing
|
||||
# CLOUD_API_URL=http://云端IP:8000
|
||||
|
||||
# 运行推理服务
|
||||
python main.py
|
||||
```
|
||||
|
||||
### 测试
|
||||
|
||||
```bash
|
||||
# 运行完整工作流测试
|
||||
python test_leave_post_full_workflow.py
|
||||
|
||||
# 运行无持续时长测试
|
||||
python test_leave_post_no_duration.py
|
||||
|
||||
# 运行单元测试
|
||||
pytest tests/
|
||||
pytest -v tests/test_config_sync.py
|
||||
```
|
||||
|
||||
### 工具脚本
|
||||
|
||||
```bash
|
||||
# 诊断缺失摄像头配置
|
||||
python diagnose_missing_cameras.py
|
||||
|
||||
# 清理旧的 ROI 配置
|
||||
python cleanup_old_rois.py
|
||||
|
||||
# 恢复摄像头配置
|
||||
python restore_cameras.py
|
||||
```
|
||||
|
||||
### Docker 部署(生产环境)
|
||||
|
||||
```bash
|
||||
# 构建镜像(需要 CUDA 12.1 + TensorRT 8.6 基础镜像)
|
||||
docker build -t edge-inference:latest .
|
||||
|
||||
# 运行容器(需要 GPU 支持)
|
||||
docker run -d \
|
||||
--name edge-inference \
|
||||
--gpus all \
|
||||
--restart=always \
|
||||
-v /path/to/models:/app/models \
|
||||
-v /path/to/.env:/app/.env \
|
||||
-v /path/to/data:/app/data \
|
||||
edge-inference:latest
|
||||
|
||||
# 查看日志
|
||||
docker logs -f edge-inference
|
||||
|
||||
# 重启容器
|
||||
docker restart edge-inference
|
||||
|
||||
# 进入容器调试
|
||||
docker exec -it edge-inference /bin/bash
|
||||
|
||||
# 检查 GPU 使用情况
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
## 架构概览
|
||||
|
||||
### 核心模块(core/)
|
||||
|
||||
- **config_sync.py** — 配置同步管理器
|
||||
- 订阅云端 Redis Stream: `device_config_stream`
|
||||
- 拉取配置:`GET device:{device_id}:config`
|
||||
- 版本控制、自动回滚、离线可用
|
||||
- 热更新视频流(启停摄像头)
|
||||
|
||||
- **video_stream.py** — 多流管理器
|
||||
- RTSP 拉流、解码、帧缓存
|
||||
- 多路并发处理
|
||||
- 流状态监控、自动重连
|
||||
|
||||
- **preprocessor.py** — 图像预处理
|
||||
- Resize、Normalize、NCHW 转换
|
||||
- 批量预处理(8 帧/批次)
|
||||
|
||||
- **tensorrt_engine.py** — TensorRT 推理引擎
|
||||
- Engine 缓存管理
|
||||
- 批量推理(提升吞吐量)
|
||||
- GPU 内存优化
|
||||
|
||||
- **postprocessor.py** — 后处理器
|
||||
- NMS(非极大值抑制)
|
||||
- ROI 区域过滤(point-in-polygon)
|
||||
- 算法分发(根据 algorithm_code)
|
||||
|
||||
- **result_reporter.py** — 结果上报器
|
||||
- 生成 alarm_id:`edge_{device_id}_{timestamp}_{uuid6}`
|
||||
- LPUSH 到本地 Redis: `local:alarm:pending`
|
||||
- 零阻塞(立即返回)
|
||||
|
||||
- **alarm_upload_worker.py** — 告警上传 Worker
|
||||
- 独立线程,BRPOP 消费队列
|
||||
- 上传截图到腾讯云 COS
|
||||
- HTTP POST 到云端:`/api/ai/alert/edge/report`
|
||||
- 失败重试(3 次)→ 死信队列
|
||||
|
||||
- **screenshot_handler.py** — 截图处理器
|
||||
- XREADGROUP 订阅云端 Redis Stream: `edge_snap_request`
|
||||
- 从视频流获取最新帧
|
||||
- 上传 COS,HTTP 回调 WVP
|
||||
|
||||
### 算法模块(algorithms.py)
|
||||
|
||||
**已实现算法:**
|
||||
- **LeavePostAlgorithm** — 离岗检测
|
||||
- 检测 ROI 内是否有人
|
||||
- 持续无人触发告警
|
||||
- 人员回归发送 resolve 通知
|
||||
|
||||
- **IntrusionAlgorithm** — 入侵检测
|
||||
- 检测 ROI 内是否有人入侵
|
||||
- 立即触发告警
|
||||
|
||||
**算法接口:**
|
||||
```python
|
||||
class AlgorithmBase:
|
||||
def process(self, detections, roi_info, camera_id, bind_info):
|
||||
"""
|
||||
处理检测结果
|
||||
Args:
|
||||
detections: 检测框列表 [{class_id, confidence, bbox}]
|
||||
roi_info: ROI 配置 {roi_id, polygon, ...}
|
||||
camera_id: 摄像头 ID
|
||||
bind_info: 算法绑定配置 {threshold, cooldown, ...}
|
||||
Returns:
|
||||
告警信息或 None
|
||||
"""
|
||||
pass
|
||||
```
|
||||
|
||||
### 数据流
|
||||
|
||||
```
|
||||
配置下发:
|
||||
WVP → XADD device_config_stream → Edge XREADGROUP
|
||||
→ 拉取 Redis config → 版本校验 → 热更新视频流
|
||||
|
||||
视频推理:
|
||||
RTSP 拉流 → 解码 → 预处理 → TensorRT 推理
|
||||
→ NMS → ROI 过滤 → 算法处理 → 告警去重
|
||||
→ LPUSH local:alarm:pending
|
||||
|
||||
告警上报:
|
||||
BRPOP 队列 → 上传 COS 截图 → HTTP POST 云端
|
||||
→ 失败重试 → 死信队列
|
||||
|
||||
截图请求:
|
||||
WVP → XADD edge_snap_request → Edge XREADGROUP
|
||||
→ 获取帧 → 上传 COS → HTTP 回调 WVP
|
||||
```
|
||||
|
||||
## Redis Key 设计
|
||||
|
||||
### 云端 Redis
|
||||
- `device:{device_id}:config` — 设备最新配置 JSON
|
||||
- `device:{device_id}:version` — 配置版本号
|
||||
- `device_config_stream` — 配置变更 Stream
|
||||
- `edge_snap_request` — 截图请求 Stream
|
||||
|
||||
### 本地 Redis
|
||||
- `local:device:config:current` — 当前生效配置
|
||||
- `local:device:config:backup` — 上次成功配置(回滚用)
|
||||
- `local:device:config:version` — 当前版本号
|
||||
- `local:alarm:pending` — 待上报告警队列
|
||||
- `local:alarm:retry` — 重试队列
|
||||
- `local:alarm:dead` — 死信队列
|
||||
|
||||
## 配置文件
|
||||
|
||||
### .env 环境变量(关键配置)
|
||||
|
||||
```bash
|
||||
# 设备标识
|
||||
DEVICE_ID=edge_device_001
|
||||
|
||||
# 云端 Redis(配置同步)
|
||||
CLOUD_REDIS_HOST=腾讯云Redis地址
|
||||
CLOUD_REDIS_PORT=6379
|
||||
CLOUD_REDIS_PASSWORD=密码
|
||||
|
||||
# 本地 Redis(告警队列、配置缓存)
|
||||
LOCAL_REDIS_HOST=localhost
|
||||
LOCAL_REDIS_PORT=6379
|
||||
|
||||
# 腾讯云 COS(截图上传)
|
||||
COS_SECRET_ID=your_secret_id
|
||||
COS_SECRET_KEY=your_secret_key
|
||||
COS_BUCKET=your-bucket-1234567890
|
||||
COS_REGION=ap-beijing
|
||||
|
||||
# 云端 API(告警上报)
|
||||
CLOUD_API_URL=http://云端IP:8000
|
||||
```
|
||||
|
||||
### config/ 目录(YAML 配置)
|
||||
|
||||
- `settings.py` — 配置加载器(读取 .env)
|
||||
- `database.py` — SQLite 管理器(本地配置持久化)
|
||||
- `config_models.py` — 配置数据模型
|
||||
|
||||
## 告警上报数据格式
|
||||
|
||||
### 告警触发(POST /api/ai/alert/edge/report)
|
||||
|
||||
```json
|
||||
{
|
||||
"alarm_id": "edge_device001_20260305120000_a1b2c3",
|
||||
"alarm_type": "leave_post",
|
||||
"device_id": "camera_001",
|
||||
"scene_id": "roi_001",
|
||||
"event_time": "2026-03-05T12:00:00Z",
|
||||
"alarm_level": 2,
|
||||
"snapshot_url": "https://cos.ap-beijing.myqcloud.com/...",
|
||||
"confidence_score": 0.92,
|
||||
"algorithm_code": "YOLO",
|
||||
"ext_data": {
|
||||
"bind_id": "bind_123",
|
||||
"bbox": [100, 100, 300, 400],
|
||||
"first_frame_time": "2026-03-05T12:00:00Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 告警结束(POST /api/ai/alert/edge/resolve)
|
||||
|
||||
```json
|
||||
{
|
||||
"alarm_id": "edge_device001_20260305120000_a1b2c3",
|
||||
"duration_ms": 120000,
|
||||
"last_frame_time": "2026-03-05T12:02:00Z",
|
||||
"resolve_type": "PERSON_RETURN"
|
||||
}
|
||||
```
|
||||
|
||||
## 开发工作流
|
||||
|
||||
### 添加新算法
|
||||
1. 在 `algorithms.py` 创建新的算法类,继承 `AlgorithmBase`
|
||||
2. 实现 `process()` 方法
|
||||
3. 在 `AlgorithmManager` 注册算法
|
||||
4. 更新 WVP 后端的算法配置表
|
||||
|
||||
### 修改推理流程
|
||||
1. 修改 `main.py` 中的 `EdgeInferenceService`
|
||||
2. 调整批量大小:`self._max_batch_size`
|
||||
3. 调整冷却时间:`self._camera_cooldown_seconds`
|
||||
|
||||
### 调整告警去重策略
|
||||
- ROI 级冷却:每个 ROI 独立冷却
|
||||
- 摄像头级冷却:同摄像头同类型告警冷却
|
||||
- 修改:`main.py` 中的 `_camera_alert_cooldown` 逻辑
|
||||
|
||||
### 优化性能
|
||||
1. **批量推理**:调整 `_max_batch_size`(默认 8)
|
||||
2. **GPU 内存**:减少视频流并发数
|
||||
3. **告警队列**:监控 Redis 队列长度
|
||||
4. **TensorRT 引擎**:确保引擎缓存命中
|
||||
|
||||
## 常见问题
|
||||
|
||||
### TensorRT 引擎加载慢
|
||||
首次运行会构建引擎(5-10 分钟),之后会缓存。
|
||||
检查 `models/` 目录下是否有 `.engine` 文件。
|
||||
|
||||
### 告警上报失败
|
||||
检查云端 API 是否可达:
|
||||
```bash
|
||||
curl http://云端IP:8000/health
|
||||
```
|
||||
|
||||
检查 COS 配置:
|
||||
```bash
|
||||
# 查看 .env 中的 COS 配置
|
||||
cat .env | grep COS
|
||||
```
|
||||
|
||||
### 配置不更新
|
||||
检查云端 Redis 连接:
|
||||
```bash
|
||||
redis-cli -h 云端Redis地址 -p 6379 -a 密码 ping
|
||||
```
|
||||
|
||||
检查配置版本:
|
||||
```bash
|
||||
redis-cli GET device:edge_device_001:version
|
||||
redis-cli GET local:device:config:version
|
||||
```
|
||||
|
||||
### 视频流断开
|
||||
检查 RTSP 地址是否可访问:
|
||||
```bash
|
||||
ffprobe rtsp://云端IP:10002/...
|
||||
```
|
||||
|
||||
检查 WVP 流媒体服务:
|
||||
```bash
|
||||
docker logs vsp-zlmedia
|
||||
```
|
||||
|
||||
### GPU 内存不足
|
||||
降低批量大小或减少并发流数量。
|
||||
检查 GPU 使用情况:
|
||||
```bash
|
||||
nvidia-smi
|
||||
```
|
||||
|
||||
## Git 提交规范
|
||||
|
||||
在修改代码后,使用中文提交信息:
|
||||
|
||||
```bash
|
||||
git add .
|
||||
git commit -m "功能:添加XXX功能
|
||||
|
||||
详细说明...
|
||||
|
||||
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>"
|
||||
```
|
||||
|
||||
**不要立即 push**,等待用户指示再推送到远程。
|
||||
612
algorithms.py
612
algorithms.py
@@ -43,9 +43,6 @@ class LeavePostAlgorithm:
|
||||
STATE_ALARMED = "ALARMED" # 已告警(等待回岗)
|
||||
STATE_NON_WORK_TIME = "NON_WORK_TIME" # 非工作时间
|
||||
|
||||
# 告警级别常量(默认值,可通过 params 覆盖)
|
||||
DEFAULT_ALARM_LEVEL = 2 # 普通
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
confirm_on_duty_sec: int = 10, # 上岗确认窗口(持续检测到人的时长)
|
||||
@@ -55,7 +52,6 @@ class LeavePostAlgorithm:
|
||||
cooldown_sec: int = 600, # 告警冷却期(两次告警的最小间隔)
|
||||
working_hours: Optional[List[Dict]] = None,
|
||||
target_class: Optional[str] = "person",
|
||||
alarm_level: Optional[int] = None,
|
||||
# 兼容旧参数名(向后兼容)
|
||||
confirm_leave_sec: Optional[int] = None,
|
||||
):
|
||||
@@ -69,7 +65,6 @@ class LeavePostAlgorithm:
|
||||
# 工作时间和目标类别
|
||||
self.working_hours = working_hours or []
|
||||
self.target_class = target_class
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
self.state: str = self.STATE_INIT
|
||||
@@ -246,7 +241,7 @@ class LeavePostAlgorithm:
|
||||
self.state_start_time = current_time
|
||||
self.detection_window.clear()
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_ON_DUTY → INIT (人消失)")
|
||||
elif elapsed >= self.confirm_on_duty_sec and detection_ratio >= 0.6:
|
||||
elif elapsed >= self.confirm_on_duty_sec and detection_ratio >= 0.7:
|
||||
# 上岗确认成功(命中率>=70%)
|
||||
self.state = self.STATE_ON_DUTY
|
||||
self.state_start_time = current_time
|
||||
@@ -255,8 +250,8 @@ class LeavePostAlgorithm:
|
||||
|
||||
elif self.state == self.STATE_ON_DUTY:
|
||||
# 在岗状态:监控是否离岗
|
||||
if detection_ratio < 0.2:
|
||||
# 滑动窗口内 80% 以上帧无人,进入离岗确认
|
||||
if detection_ratio == 0:
|
||||
# 滑动窗口内完全没有人,进入离岗确认
|
||||
self.state = self.STATE_CONFIRMING_OFF_DUTY
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: ON_DUTY → CONFIRMING_OFF_DUTY")
|
||||
@@ -265,12 +260,12 @@ class LeavePostAlgorithm:
|
||||
# 离岗确认中:需要持续未检测到人
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if detection_ratio >= 0.5:
|
||||
# 窗口内检测率恢复到 50% 以上,人确实回来了
|
||||
if roi_has_person:
|
||||
# 人回来了,回到ON_DUTY
|
||||
self.state = self.STATE_ON_DUTY
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_OFF_DUTY → ON_DUTY (人回来了, ratio={detection_ratio:.2f})")
|
||||
elif elapsed >= self.confirm_off_duty_sec and detection_ratio < 0.2:
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_OFF_DUTY → ON_DUTY (人回来了)")
|
||||
elif elapsed >= self.confirm_off_duty_sec and detection_ratio == 0:
|
||||
# 离岗确认成功,进入倒计时
|
||||
self.state = self.STATE_OFF_DUTY_COUNTDOWN
|
||||
self.state_start_time = current_time
|
||||
@@ -300,7 +295,6 @@ class LeavePostAlgorithm:
|
||||
"camera_id": camera_id,
|
||||
"bbox": bbox,
|
||||
"alert_type": "leave_post",
|
||||
"alarm_level": self._alarm_level,
|
||||
"message": "人员离岗告警",
|
||||
"first_frame_time": self._leave_start_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
})
|
||||
@@ -398,8 +392,8 @@ class IntrusionAlgorithm:
|
||||
STATE_ALARMED = "ALARMED" # 已告警(等待入侵消失)
|
||||
STATE_CONFIRMING_CLEAR = "CONFIRMING_CLEAR" # 入侵消失确认中
|
||||
|
||||
# 告警级别常量(默认值,可通过 params 覆盖)
|
||||
DEFAULT_ALARM_LEVEL = 1 # 重要
|
||||
# 告警级别常量
|
||||
ALARM_LEVEL_INTRUSION = 3
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -408,7 +402,6 @@ class IntrusionAlgorithm:
|
||||
confirm_intrusion_seconds: Optional[int] = None, # 入侵确认时间(默认5秒)
|
||||
confirm_clear_seconds: Optional[int] = None, # 消失确认时间(默认180秒)
|
||||
target_class: Optional[str] = None,
|
||||
alarm_level: Optional[int] = None,
|
||||
):
|
||||
self.cooldown_seconds = cooldown_seconds
|
||||
|
||||
@@ -420,7 +413,6 @@ class IntrusionAlgorithm:
|
||||
self.confirm_seconds = confirm_seconds
|
||||
|
||||
self.target_class = target_class
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
self.state: str = self.STATE_IDLE
|
||||
@@ -524,7 +516,7 @@ class IntrusionAlgorithm:
|
||||
"camera_id": camera_id,
|
||||
"bbox": bbox,
|
||||
"alert_type": "intrusion",
|
||||
"alarm_level": self._alarm_level,
|
||||
"alarm_level": self.ALARM_LEVEL_INTRUSION,
|
||||
"message": "检测到周界入侵",
|
||||
"first_frame_time": self._intrusion_start_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
})
|
||||
@@ -713,501 +705,6 @@ class IntrusionAlgorithm:
|
||||
# self.alert_triggered.clear()
|
||||
|
||||
|
||||
class IllegalParkingAlgorithm:
|
||||
"""
|
||||
车辆违停检测算法(状态机版本 v1.0)
|
||||
|
||||
状态机:
|
||||
IDLE → CONFIRMING_VEHICLE → PARKED_COUNTDOWN → ALARMED → CONFIRMING_CLEAR → IDLE
|
||||
|
||||
业务流程:
|
||||
1. 检测到车辆进入禁停区 → 车辆确认期(confirm_vehicle_sec,默认15秒,ratio≥0.6)
|
||||
2. 确认有车 → 违停倒计时(parking_countdown_sec,默认300秒/5分钟)
|
||||
3. 倒计时结束仍有车 → 触发告警(ALARMED状态)
|
||||
4. 车辆离开 → 消失确认期(confirm_clear_sec,默认30秒,ratio<0.2)
|
||||
5. 确认车辆离开 → 发送resolve事件 → 回到空闲状态
|
||||
|
||||
使用滑动窗口(10秒)抗抖动,支持多类车辆检测。
|
||||
"""
|
||||
|
||||
# 状态定义
|
||||
STATE_IDLE = "IDLE"
|
||||
STATE_CONFIRMING_VEHICLE = "CONFIRMING_VEHICLE"
|
||||
STATE_PARKED_COUNTDOWN = "PARKED_COUNTDOWN"
|
||||
STATE_ALARMED = "ALARMED"
|
||||
STATE_CONFIRMING_CLEAR = "CONFIRMING_CLEAR"
|
||||
|
||||
# 告警级别常量(默认值,可通过 params 覆盖)
|
||||
DEFAULT_ALARM_LEVEL = 1 # 重要
|
||||
|
||||
# 滑动窗口参数
|
||||
WINDOW_SIZE_SEC = 10
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
confirm_vehicle_sec: int = 15,
|
||||
parking_countdown_sec: int = 300,
|
||||
confirm_clear_sec: int = 120,
|
||||
cooldown_sec: int = 1800,
|
||||
target_classes: Optional[List[str]] = None,
|
||||
alarm_level: Optional[int] = None,
|
||||
):
|
||||
self.confirm_vehicle_sec = confirm_vehicle_sec
|
||||
self.parking_countdown_sec = parking_countdown_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"]
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
self.state: str = self.STATE_IDLE
|
||||
self.state_start_time: Optional[datetime] = None
|
||||
|
||||
# 滑动窗口:存储 (timestamp, has_vehicle: bool)
|
||||
self._detection_window: deque = deque()
|
||||
|
||||
# 告警追踪
|
||||
self._last_alarm_id: Optional[str] = None
|
||||
self._parking_start_time: Optional[datetime] = None
|
||||
|
||||
# 冷却期管理
|
||||
self.alert_cooldowns: Dict[str, datetime] = {}
|
||||
|
||||
def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
for roi in matched_rois:
|
||||
if roi.get("roi_id") == roi_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_target_classes(self, detection: Dict) -> bool:
|
||||
"""检查检测目标是否属于车辆类别"""
|
||||
det_class = detection.get("class", "")
|
||||
return det_class in self.target_classes
|
||||
|
||||
def _update_window(self, current_time: datetime, has_vehicle: bool):
|
||||
"""更新滑动窗口"""
|
||||
self._detection_window.append((current_time, has_vehicle))
|
||||
cutoff = current_time - timedelta(seconds=self.WINDOW_SIZE_SEC)
|
||||
while self._detection_window and self._detection_window[0][0] < cutoff:
|
||||
self._detection_window.popleft()
|
||||
|
||||
def _get_window_ratio(self) -> float:
|
||||
"""获取滑动窗口内的检测命中率"""
|
||||
if not self._detection_window:
|
||||
return 0.0
|
||||
hits = sum(1 for _, has in self._detection_window if has)
|
||||
return hits / len(self._detection_window)
|
||||
|
||||
def _get_latest_bbox(self, tracks: List[Dict], roi_id: str) -> List[float]:
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
return det.get("bbox", [])
|
||||
return []
|
||||
|
||||
def _get_max_confidence(self, tracks: List[Dict], roi_id: str) -> float:
|
||||
"""获取ROI内车辆的最高置信度"""
|
||||
max_conf = 0.0
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
max_conf = max(max_conf, det.get("confidence", 0.0))
|
||||
return max_conf
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
camera_id: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
"""处理单帧检测结果"""
|
||||
current_time = current_time or datetime.now()
|
||||
alerts = []
|
||||
|
||||
# 检查ROI内是否有车辆
|
||||
roi_has_vehicle = any(
|
||||
self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det)
|
||||
for det in tracks
|
||||
)
|
||||
|
||||
# 更新滑动窗口
|
||||
self._update_window(current_time, roi_has_vehicle)
|
||||
ratio = self._get_window_ratio()
|
||||
|
||||
# === 状态机处理 ===
|
||||
|
||||
if self.state == self.STATE_IDLE:
|
||||
if roi_has_vehicle:
|
||||
self.state = self.STATE_CONFIRMING_VEHICLE
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: IDLE → CONFIRMING_VEHICLE")
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_VEHICLE:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio < 0.3:
|
||||
# 命中率过低,车辆可能只是路过
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_VEHICLE → IDLE (ratio={ratio:.2f}<0.3)")
|
||||
elif elapsed >= self.confirm_vehicle_sec and ratio >= 0.6:
|
||||
# 确认有车辆停留,进入倒计时
|
||||
self._parking_start_time = self.state_start_time
|
||||
self.state = self.STATE_PARKED_COUNTDOWN
|
||||
self.state_start_time = current_time
|
||||
logger.info(f"ROI {roi_id}: CONFIRMING_VEHICLE → PARKED_COUNTDOWN (ratio={ratio:.2f})")
|
||||
|
||||
elif self.state == self.STATE_PARKED_COUNTDOWN:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio < 0.2:
|
||||
# 车辆已离开
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._parking_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: PARKED_COUNTDOWN → IDLE (车辆离开, ratio={ratio:.2f})")
|
||||
elif elapsed >= self.parking_countdown_sec:
|
||||
# 倒计时结束,检查冷却期
|
||||
cooldown_key = f"{camera_id}_{roi_id}"
|
||||
if cooldown_key not in self.alert_cooldowns or \
|
||||
(current_time - self.alert_cooldowns[cooldown_key]).total_seconds() > self.cooldown_sec:
|
||||
|
||||
bbox = self._get_latest_bbox(tracks, roi_id)
|
||||
confidence = self._get_max_confidence(tracks, roi_id)
|
||||
|
||||
alerts.append({
|
||||
"roi_id": roi_id,
|
||||
"camera_id": camera_id,
|
||||
"bbox": bbox,
|
||||
"alert_type": "illegal_parking",
|
||||
"alarm_level": self._alarm_level,
|
||||
"confidence": confidence,
|
||||
"message": f"检测到车辆违停(已停留{int(elapsed / 60)}分钟)",
|
||||
"first_frame_time": self._parking_start_time.strftime('%Y-%m-%d %H:%M:%S') if self._parking_start_time else None,
|
||||
"duration_minutes": elapsed / 60,
|
||||
})
|
||||
|
||||
self.alert_cooldowns[cooldown_key] = current_time
|
||||
self.state = self.STATE_ALARMED
|
||||
logger.warning(f"ROI {roi_id}: PARKED_COUNTDOWN → ALARMED (违停告警触发)")
|
||||
else:
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._parking_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: PARKED_COUNTDOWN → IDLE (冷却期内)")
|
||||
|
||||
elif self.state == self.STATE_ALARMED:
|
||||
if ratio < 0.15:
|
||||
self.state = self.STATE_CONFIRMING_CLEAR
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: ALARMED → CONFIRMING_CLEAR (ratio={ratio:.2f}<0.15)")
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_CLEAR:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_IDLE
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if ratio >= 0.5:
|
||||
# 车辆又出现,回到ALARMED
|
||||
self.state = self.STATE_ALARMED
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → ALARMED (车辆仍在)")
|
||||
elif elapsed >= self.confirm_clear_sec and ratio < 0.2:
|
||||
# 确认车辆已离开
|
||||
if self._last_alarm_id and self._parking_start_time:
|
||||
duration_ms = int((current_time - self._parking_start_time).total_seconds() * 1000)
|
||||
alerts.append({
|
||||
"alert_type": "alarm_resolve",
|
||||
"resolve_alarm_id": self._last_alarm_id,
|
||||
"duration_ms": duration_ms,
|
||||
"last_frame_time": current_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"resolve_type": "vehicle_left",
|
||||
})
|
||||
logger.info(f"ROI {roi_id}: 违停告警已解决(车辆离开)")
|
||||
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._parking_start_time = None
|
||||
self.alert_cooldowns.clear() # 车辆离开后清冷却,新车违停可正常告警
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → IDLE")
|
||||
|
||||
return alerts
|
||||
|
||||
def set_last_alarm_id(self, alarm_id: str):
|
||||
"""由 main.py 在告警生成后回填 alarm_id"""
|
||||
self._last_alarm_id = alarm_id
|
||||
|
||||
def reset(self):
|
||||
"""重置算法状态"""
|
||||
self.state = self.STATE_IDLE
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._parking_start_time = None
|
||||
self._detection_window.clear()
|
||||
self.alert_cooldowns.clear()
|
||||
|
||||
def get_state(self, current_time: Optional[datetime] = None) -> Dict[str, Any]:
|
||||
"""获取当前状态"""
|
||||
current_time = current_time or datetime.now()
|
||||
state_info = {
|
||||
"state": self.state,
|
||||
"state_start_time": self.state_start_time.isoformat() if self.state_start_time else None,
|
||||
"window_ratio": self._get_window_ratio(),
|
||||
}
|
||||
if self.state in (self.STATE_ALARMED, self.STATE_PARKED_COUNTDOWN) and self._parking_start_time:
|
||||
state_info["parking_duration_sec"] = (current_time - self._parking_start_time).total_seconds()
|
||||
state_info["alarm_id"] = self._last_alarm_id
|
||||
return state_info
|
||||
|
||||
|
||||
class VehicleCongestionAlgorithm:
|
||||
"""
|
||||
车辆拥堵检测算法(状态机版本 v1.0)
|
||||
|
||||
状态机:
|
||||
NORMAL → CONFIRMING_CONGESTION → CONGESTED → CONFIRMING_CLEAR → NORMAL
|
||||
|
||||
业务流程:
|
||||
1. 检测到车辆数量 ≥ count_threshold → 拥堵确认期(confirm_congestion_sec,默认60秒)
|
||||
2. 确认拥堵(窗口内平均车辆数 ≥ threshold)→ 触发告警
|
||||
3. 车辆减少 → 消散确认期(confirm_clear_sec,默认120秒)
|
||||
4. 确认消散(平均数 < threshold)→ 发送resolve事件 → 回到正常
|
||||
|
||||
使用滑动窗口(10秒)存储车辆计数,取平均值判断。
|
||||
"""
|
||||
|
||||
# 状态定义
|
||||
STATE_NORMAL = "NORMAL"
|
||||
STATE_CONFIRMING_CONGESTION = "CONFIRMING_CONGESTION"
|
||||
STATE_CONGESTED = "CONGESTED"
|
||||
STATE_CONFIRMING_CLEAR = "CONFIRMING_CLEAR"
|
||||
|
||||
# 告警级别(默认值,可通过 params 覆盖)
|
||||
DEFAULT_ALARM_LEVEL = 2 # 普通
|
||||
|
||||
# 滑动窗口参数
|
||||
WINDOW_SIZE_SEC = 10
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
count_threshold: int = 5,
|
||||
confirm_congestion_sec: int = 60,
|
||||
confirm_clear_sec: int = 180,
|
||||
cooldown_sec: int = 1800,
|
||||
target_classes: Optional[List[str]] = None,
|
||||
alarm_level: Optional[int] = None,
|
||||
):
|
||||
self.count_threshold = count_threshold
|
||||
self.confirm_congestion_sec = confirm_congestion_sec
|
||||
self.confirm_clear_sec = confirm_clear_sec
|
||||
self.cooldown_sec = cooldown_sec
|
||||
self.target_classes = target_classes or ["car", "truck", "bus", "motorcycle"]
|
||||
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
|
||||
|
||||
# 状态变量
|
||||
self.state: str = self.STATE_NORMAL
|
||||
self.state_start_time: Optional[datetime] = None
|
||||
|
||||
# 滑动窗口:存储 (timestamp, vehicle_count: int)
|
||||
self._count_window: deque = deque()
|
||||
|
||||
# 告警追踪
|
||||
self._last_alarm_id: Optional[str] = None
|
||||
self._congestion_start_time: Optional[datetime] = None
|
||||
|
||||
# 冷却期管理
|
||||
self.alert_cooldowns: Dict[str, datetime] = {}
|
||||
|
||||
def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
for roi in matched_rois:
|
||||
if roi.get("roi_id") == roi_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_target_classes(self, detection: Dict) -> bool:
|
||||
det_class = detection.get("class", "")
|
||||
return det_class in self.target_classes
|
||||
|
||||
def _count_vehicles_in_roi(self, tracks: List[Dict], roi_id: str) -> int:
|
||||
"""统计ROI内的车辆数量"""
|
||||
return sum(
|
||||
1 for det in tracks
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det)
|
||||
)
|
||||
|
||||
def _update_count_window(self, current_time: datetime, count: int):
|
||||
"""更新车辆计数滑动窗口"""
|
||||
self._count_window.append((current_time, count))
|
||||
cutoff = current_time - timedelta(seconds=self.WINDOW_SIZE_SEC)
|
||||
while self._count_window and self._count_window[0][0] < cutoff:
|
||||
self._count_window.popleft()
|
||||
|
||||
def _get_avg_count(self) -> float:
|
||||
"""获取滑动窗口内的平均车辆数"""
|
||||
if not self._count_window:
|
||||
return 0.0
|
||||
total = sum(c for _, c in self._count_window)
|
||||
return total / len(self._count_window)
|
||||
|
||||
def _get_max_confidence(self, tracks: List[Dict], roi_id: str) -> float:
|
||||
max_conf = 0.0
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
max_conf = max(max_conf, det.get("confidence", 0.0))
|
||||
return max_conf
|
||||
|
||||
def _get_latest_bbox(self, tracks: List[Dict], roi_id: str) -> List[float]:
|
||||
for det in tracks:
|
||||
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
|
||||
return det.get("bbox", [])
|
||||
return []
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
camera_id: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
"""处理单帧检测结果"""
|
||||
current_time = current_time or datetime.now()
|
||||
alerts = []
|
||||
|
||||
# 统计ROI内车辆数
|
||||
vehicle_count = self._count_vehicles_in_roi(tracks, roi_id)
|
||||
self._update_count_window(current_time, vehicle_count)
|
||||
avg_count = self._get_avg_count()
|
||||
|
||||
# === 状态机处理 ===
|
||||
|
||||
if self.state == self.STATE_NORMAL:
|
||||
if avg_count >= self.count_threshold:
|
||||
self.state = self.STATE_CONFIRMING_CONGESTION
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: NORMAL → CONFIRMING_CONGESTION (avg={avg_count:.1f}≥{self.count_threshold})")
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_CONGESTION:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_NORMAL
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if avg_count < self.count_threshold:
|
||||
# 车辆减少,回到正常
|
||||
self.state = self.STATE_NORMAL
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CONGESTION → NORMAL (avg={avg_count:.1f}<{self.count_threshold})")
|
||||
elif elapsed >= self.confirm_congestion_sec:
|
||||
# 确认拥堵,检查冷却期
|
||||
cooldown_key = f"{camera_id}_{roi_id}"
|
||||
if cooldown_key not in self.alert_cooldowns or \
|
||||
(current_time - self.alert_cooldowns[cooldown_key]).total_seconds() > self.cooldown_sec:
|
||||
|
||||
self._congestion_start_time = self.state_start_time
|
||||
bbox = self._get_latest_bbox(tracks, roi_id)
|
||||
confidence = self._get_max_confidence(tracks, roi_id)
|
||||
|
||||
alerts.append({
|
||||
"roi_id": roi_id,
|
||||
"camera_id": camera_id,
|
||||
"bbox": bbox,
|
||||
"alert_type": "vehicle_congestion",
|
||||
"alarm_level": self._alarm_level,
|
||||
"confidence": confidence,
|
||||
"message": f"检测到车辆拥堵(平均{avg_count:.0f}辆,持续{int(elapsed)}秒)",
|
||||
"first_frame_time": self._congestion_start_time.strftime('%Y-%m-%d %H:%M:%S') if self._congestion_start_time else None,
|
||||
"vehicle_count": int(avg_count),
|
||||
})
|
||||
|
||||
self.alert_cooldowns[cooldown_key] = current_time
|
||||
self.state = self.STATE_CONGESTED
|
||||
logger.warning(f"ROI {roi_id}: CONFIRMING_CONGESTION → CONGESTED (拥堵告警触发, avg={avg_count:.1f})")
|
||||
else:
|
||||
self.state = self.STATE_NORMAL
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CONGESTION → NORMAL (冷却期内)")
|
||||
|
||||
elif self.state == self.STATE_CONGESTED:
|
||||
# 车辆数降到阈值的一半以下才开始确认消散(避免抖动)
|
||||
if avg_count < self.count_threshold * 0.5:
|
||||
self.state = self.STATE_CONFIRMING_CLEAR
|
||||
self.state_start_time = current_time
|
||||
logger.debug(f"ROI {roi_id}: CONGESTED → CONFIRMING_CLEAR (avg={avg_count:.1f}<{self.count_threshold * 0.5:.1f})")
|
||||
|
||||
elif self.state == self.STATE_CONFIRMING_CLEAR:
|
||||
if self.state_start_time is None:
|
||||
self.state = self.STATE_NORMAL
|
||||
return alerts
|
||||
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if avg_count >= self.count_threshold:
|
||||
# 又拥堵了,回到CONGESTED
|
||||
self.state = self.STATE_CONGESTED
|
||||
self.state_start_time = None
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → CONGESTED (avg={avg_count:.1f}≥{self.count_threshold})")
|
||||
elif elapsed >= self.confirm_clear_sec:
|
||||
# 确认消散
|
||||
if self._last_alarm_id and self._congestion_start_time:
|
||||
duration_ms = int((current_time - self._congestion_start_time).total_seconds() * 1000)
|
||||
alerts.append({
|
||||
"alert_type": "alarm_resolve",
|
||||
"resolve_alarm_id": self._last_alarm_id,
|
||||
"duration_ms": duration_ms,
|
||||
"last_frame_time": current_time.strftime('%Y-%m-%d %H:%M:%S'),
|
||||
"resolve_type": "congestion_cleared",
|
||||
})
|
||||
logger.info(f"ROI {roi_id}: 拥堵告警已解决(车辆减少)")
|
||||
|
||||
self.state = self.STATE_NORMAL
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._congestion_start_time = None
|
||||
self.alert_cooldowns.clear() # 拥堵消散后清冷却,再次拥堵可正常告警
|
||||
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → NORMAL")
|
||||
|
||||
return alerts
|
||||
|
||||
def set_last_alarm_id(self, alarm_id: str):
|
||||
"""由 main.py 在告警生成后回填 alarm_id"""
|
||||
self._last_alarm_id = alarm_id
|
||||
|
||||
def reset(self):
|
||||
"""重置算法状态"""
|
||||
self.state = self.STATE_NORMAL
|
||||
self.state_start_time = None
|
||||
self._last_alarm_id = None
|
||||
self._congestion_start_time = None
|
||||
self._count_window.clear()
|
||||
self.alert_cooldowns.clear()
|
||||
|
||||
def get_state(self, current_time: Optional[datetime] = None) -> Dict[str, Any]:
|
||||
"""获取当前状态"""
|
||||
current_time = current_time or datetime.now()
|
||||
state_info = {
|
||||
"state": self.state,
|
||||
"state_start_time": self.state_start_time.isoformat() if self.state_start_time else None,
|
||||
"avg_vehicle_count": self._get_avg_count(),
|
||||
}
|
||||
if self.state in (self.STATE_CONGESTED, self.STATE_CONFIRMING_CLEAR) and self._congestion_start_time:
|
||||
state_info["congestion_duration_sec"] = (current_time - self._congestion_start_time).total_seconds()
|
||||
state_info["alarm_id"] = self._last_alarm_id
|
||||
return state_info
|
||||
|
||||
|
||||
class AlgorithmManager:
|
||||
def __init__(self, working_hours: Optional[List[Dict]] = None):
|
||||
self.algorithms: Dict[str, Dict[str, Any]] = {}
|
||||
@@ -1227,20 +724,11 @@ class AlgorithmManager:
|
||||
"confirm_seconds": 5,
|
||||
"target_class": None,
|
||||
},
|
||||
"illegal_parking": {
|
||||
"confirm_vehicle_sec": 15,
|
||||
"parking_countdown_sec": 300,
|
||||
"confirm_clear_sec": 30,
|
||||
"cooldown_sec": 600,
|
||||
"target_classes": ["car", "truck", "bus", "motorcycle"],
|
||||
},
|
||||
"vehicle_congestion": {
|
||||
"count_threshold": 3,
|
||||
"confirm_congestion_sec": 60,
|
||||
"confirm_clear_sec": 120,
|
||||
"cooldown_sec": 600,
|
||||
"target_classes": ["car", "truck", "bus", "motorcycle"],
|
||||
},
|
||||
# "crowd_detection": {
|
||||
# "max_count": 10,
|
||||
# "cooldown_seconds": 300,
|
||||
# "target_class": "person",
|
||||
# },
|
||||
}
|
||||
|
||||
self._pubsub = None
|
||||
@@ -1350,7 +838,6 @@ class AlgorithmManager:
|
||||
key = f"{roi_id}_{bind_id}"
|
||||
|
||||
if algo_code == "leave_post":
|
||||
configured_alarm_level = params.get("alarm_level")
|
||||
algo_params = {
|
||||
"confirm_on_duty_sec": params.get("confirm_on_duty_sec", 10),
|
||||
"confirm_leave_sec": params.get("confirm_leave_sec", 30),
|
||||
@@ -1364,8 +851,6 @@ class AlgorithmManager:
|
||||
algo.confirm_leave_sec = algo_params["confirm_leave_sec"]
|
||||
algo.cooldown_sec = algo_params["cooldown_sec"]
|
||||
algo.target_class = algo_params["target_class"]
|
||||
if configured_alarm_level is not None:
|
||||
algo._alarm_level = configured_alarm_level
|
||||
if algo_params["working_hours"]:
|
||||
algo.working_hours = algo_params["working_hours"]
|
||||
logger.info(f"已热更新算法参数: {key}")
|
||||
@@ -1374,15 +859,13 @@ class AlgorithmManager:
|
||||
self.algorithms[roi_id][key]["leave_post"] = LeavePostAlgorithm(
|
||||
confirm_on_duty_sec=algo_params["confirm_on_duty_sec"],
|
||||
confirm_leave_sec=algo_params["confirm_leave_sec"],
|
||||
leave_countdown_sec=algo_params.get("leave_countdown_sec", 300),
|
||||
leave_countdown_sec=algo_params.get("leave_countdown_sec", 300), # 离岗倒计时,默认5分钟
|
||||
cooldown_sec=algo_params["cooldown_sec"],
|
||||
working_hours=algo_params["working_hours"],
|
||||
target_class=algo_params["target_class"],
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
logger.info(f"已从Redis加载算法: {key}")
|
||||
elif algo_code == "intrusion":
|
||||
configured_alarm_level = params.get("alarm_level")
|
||||
algo_params = {
|
||||
"cooldown_seconds": params.get("cooldown_seconds", 300),
|
||||
"confirm_seconds": params.get("confirm_seconds", 5),
|
||||
@@ -1393,7 +876,6 @@ class AlgorithmManager:
|
||||
cooldown_seconds=algo_params["cooldown_seconds"],
|
||||
confirm_seconds=algo_params["confirm_seconds"],
|
||||
target_class=algo_params["target_class"],
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
logger.info(f"已从Redis加载算法: {key}")
|
||||
|
||||
@@ -1516,13 +998,9 @@ class AlgorithmManager:
|
||||
config_manager = get_config_sync_manager()
|
||||
bindings = config_manager.get_bindings_from_redis("")
|
||||
|
||||
# 收集当前配置中有效的 (roi_id, bind_id) 组合
|
||||
valid_keys = set()
|
||||
|
||||
for bind in bindings:
|
||||
bind_id = bind.get("bind_id")
|
||||
roi_id = bind.get("roi_id")
|
||||
valid_keys.add((roi_id, bind_id))
|
||||
|
||||
if preserve_state:
|
||||
# 仅更新参数,不重置状态
|
||||
@@ -1534,32 +1012,6 @@ class AlgorithmManager:
|
||||
if self.load_bind_from_redis(bind_id):
|
||||
count += 1
|
||||
|
||||
# 清理内存中已被删除的算法实例
|
||||
removed_count = 0
|
||||
with self._update_lock:
|
||||
for roi_id in list(self.algorithms.keys()):
|
||||
for key in list(self.algorithms[roi_id].keys()):
|
||||
# key 格式: "roi_id_bind_id"
|
||||
if key.startswith(roi_id + "_"):
|
||||
bind_id = key[len(roi_id) + 1:]
|
||||
if (roi_id, bind_id) not in valid_keys:
|
||||
for algo in self.algorithms[roi_id][key].values():
|
||||
algo.reset()
|
||||
del self.algorithms[roi_id][key]
|
||||
# 清除注册缓存
|
||||
self._registered_keys = {
|
||||
k for k in self._registered_keys
|
||||
if not (k[0] == roi_id and k[1] == bind_id)
|
||||
}
|
||||
removed_count += 1
|
||||
logger.info(f"清理已删除的算法实例: {key}")
|
||||
# 如果 roi 下已无算法实例,清理空字典
|
||||
if not self.algorithms[roi_id]:
|
||||
del self.algorithms[roi_id]
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"已清理 {removed_count} 个孤立算法实例")
|
||||
|
||||
logger.info(f"已重新加载 {count} 个算法配置 (preserve_state={preserve_state})")
|
||||
return count
|
||||
except Exception as e:
|
||||
@@ -1592,9 +1044,6 @@ class AlgorithmManager:
|
||||
if params:
|
||||
algo_params.update(params)
|
||||
|
||||
# 从 params 中提取告警等级(前端配置下发)
|
||||
configured_alarm_level = algo_params.get("alarm_level")
|
||||
|
||||
if algorithm_type == "leave_post":
|
||||
roi_working_hours = algo_params.get("working_hours") or self.working_hours
|
||||
self.algorithms[roi_id][key]["leave_post"] = LeavePostAlgorithm(
|
||||
@@ -1604,33 +1053,20 @@ class AlgorithmManager:
|
||||
cooldown_sec=algo_params.get("cooldown_sec", 600),
|
||||
working_hours=roi_working_hours,
|
||||
target_class=algo_params.get("target_class", "person"),
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
elif algorithm_type == "intrusion":
|
||||
self.algorithms[roi_id][key]["intrusion"] = IntrusionAlgorithm(
|
||||
cooldown_seconds=algo_params.get("cooldown_seconds", 300),
|
||||
confirm_seconds=algo_params.get("confirm_seconds", 5),
|
||||
target_class=algo_params.get("target_class"),
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
elif algorithm_type == "illegal_parking":
|
||||
self.algorithms[roi_id][key]["illegal_parking"] = IllegalParkingAlgorithm(
|
||||
confirm_vehicle_sec=algo_params.get("confirm_vehicle_sec", 15),
|
||||
parking_countdown_sec=algo_params.get("parking_countdown_sec", 300),
|
||||
confirm_clear_sec=algo_params.get("confirm_clear_sec", 30),
|
||||
cooldown_sec=algo_params.get("cooldown_sec", 600),
|
||||
target_classes=algo_params.get("target_classes", ["car", "truck", "bus", "motorcycle"]),
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
elif algorithm_type == "vehicle_congestion":
|
||||
self.algorithms[roi_id][key]["vehicle_congestion"] = VehicleCongestionAlgorithm(
|
||||
count_threshold=algo_params.get("count_threshold", 3),
|
||||
confirm_congestion_sec=algo_params.get("confirm_congestion_sec", 60),
|
||||
confirm_clear_sec=algo_params.get("confirm_clear_sec", 120),
|
||||
cooldown_sec=algo_params.get("cooldown_sec", 600),
|
||||
target_classes=algo_params.get("target_classes", ["car", "truck", "bus", "motorcycle"]),
|
||||
alarm_level=configured_alarm_level,
|
||||
)
|
||||
# elif algorithm_type == "crowd_detection":
|
||||
# from algorithms import CrowdDetectionAlgorithm
|
||||
# self.algorithms[roi_id][key]["crowd_detection"] = CrowdDetectionAlgorithm(
|
||||
# max_count=algo_params.get("max_count", 10),
|
||||
# cooldown_seconds=algo_params.get("cooldown_seconds", 300),
|
||||
# target_class=algo_params.get("target_class", "person"),
|
||||
# )
|
||||
|
||||
self._registered_keys.add(cache_key)
|
||||
|
||||
@@ -1723,8 +1159,6 @@ class AlgorithmManager:
|
||||
"state": getattr(algo, "state", "WAITING"),
|
||||
"alarm_sent": getattr(algo, "alarm_sent", False),
|
||||
}
|
||||
elif algo_type in ("illegal_parking", "vehicle_congestion"):
|
||||
status[f"{algo_type}_{bind_id}"] = algo.get_state()
|
||||
else:
|
||||
status[f"{algo_type}_{bind_id}"] = {
|
||||
"detection_count": len(getattr(algo, "detection_start", {})),
|
||||
|
||||
@@ -49,8 +49,7 @@ class CameraInfo:
|
||||
enabled: bool = True
|
||||
location: Optional[str] = None
|
||||
extra_params: Optional[Dict[str, Any]] = None
|
||||
area_id: Optional[int] = None
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
@@ -61,9 +60,8 @@ class CameraInfo:
|
||||
"enabled": self.enabled,
|
||||
"location": self.location,
|
||||
"extra_params": self.extra_params,
|
||||
"area_id": self.area_id,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'CameraInfo':
|
||||
"""从字典创建实例"""
|
||||
@@ -75,7 +73,6 @@ class CameraInfo:
|
||||
enabled=data.get("enabled", True),
|
||||
location=data.get("location"),
|
||||
extra_params=data.get("extra_params"),
|
||||
area_id=data.get("area_id"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -233,7 +233,6 @@ class SQLiteManager:
|
||||
location TEXT,
|
||||
roi_group_id TEXT,
|
||||
extra_params TEXT,
|
||||
area_id INTEGER,
|
||||
updated_at TEXT
|
||||
)
|
||||
""")
|
||||
@@ -251,14 +250,7 @@ class SQLiteManager:
|
||||
""")
|
||||
|
||||
self._conn.commit()
|
||||
|
||||
# 迁移:为已有数据库添加 area_id 列
|
||||
try:
|
||||
cursor.execute("ALTER TABLE camera_configs ADD COLUMN area_id INTEGER")
|
||||
self._conn.commit()
|
||||
except Exception:
|
||||
pass # 列已存在,忽略
|
||||
|
||||
|
||||
self._init_default_algorithms()
|
||||
|
||||
def _init_default_algorithms(self):
|
||||
@@ -563,8 +555,8 @@ class SQLiteManager:
|
||||
cursor.execute("""
|
||||
INSERT OR REPLACE INTO camera_configs (
|
||||
camera_id, rtsp_url, camera_name, status, enabled,
|
||||
location, roi_group_id, extra_params, area_id, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
location, roi_group_id, extra_params, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
camera_id, rtsp_url,
|
||||
kwargs.get('camera_name'),
|
||||
@@ -573,7 +565,6 @@ class SQLiteManager:
|
||||
kwargs.get('location'),
|
||||
kwargs.get('roi_group_id'),
|
||||
str(kwargs.get('extra_params')) if kwargs.get('extra_params') else None,
|
||||
kwargs.get('area_id'),
|
||||
now
|
||||
))
|
||||
self._conn.commit()
|
||||
@@ -589,7 +580,8 @@ class SQLiteManager:
|
||||
cursor.execute("SELECT * FROM camera_configs WHERE camera_id = ?", (camera_id,))
|
||||
row = cursor.fetchone()
|
||||
if row:
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
columns = ['camera_id', 'rtsp_url', 'camera_name', 'status',
|
||||
'enabled', 'location', 'roi_group_id', 'extra_params', 'updated_at']
|
||||
return dict(zip(columns, row))
|
||||
return None
|
||||
except Exception as e:
|
||||
@@ -601,7 +593,8 @@ class SQLiteManager:
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("SELECT * FROM camera_configs ORDER BY camera_id")
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
columns = ['camera_id', 'rtsp_url', 'camera_name', 'status',
|
||||
'enabled', 'location', 'roi_group_id', 'extra_params', 'updated_at']
|
||||
return [dict(zip(columns, row)) for row in cursor.fetchall()]
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有摄像头配置失败: {e}")
|
||||
@@ -906,17 +899,7 @@ class SQLiteManager:
|
||||
except Exception as e:
|
||||
logger.error(f"删除ROI算法绑定失败: {e}")
|
||||
return 0
|
||||
|
||||
def get_all_bind_ids(self) -> List[str]:
|
||||
"""获取所有算法绑定的 bind_id 列表(用于清理孤立绑定)"""
|
||||
try:
|
||||
cursor = self._conn.cursor()
|
||||
cursor.execute("SELECT bind_id FROM roi_algo_bind")
|
||||
return [row[0] for row in cursor.fetchall()]
|
||||
except Exception as e:
|
||||
logger.error(f"获取所有绑定ID失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def log_config_update(
|
||||
self,
|
||||
config_type: str,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -93,7 +93,6 @@ class COSConfig:
|
||||
class AlarmUploadConfig:
|
||||
"""告警上报配置"""
|
||||
cloud_api_url: str = "http://localhost:8000"
|
||||
wvp_api_url: str = "" # WVP 平台地址(心跳同步用)
|
||||
edge_token: str = ""
|
||||
retry_max: int = 3
|
||||
retry_interval: int = 5
|
||||
@@ -118,17 +117,10 @@ class InferenceConfig:
|
||||
input_width: int = 480
|
||||
input_height: int = 480
|
||||
batch_size: int = 1
|
||||
conf_threshold: float = 0.45
|
||||
nms_threshold: float = 0.5
|
||||
conf_threshold: float = 0.4
|
||||
nms_threshold: float = 0.45
|
||||
device_id: int = 0
|
||||
fp16_mode: bool = True
|
||||
# 按算法类型覆盖置信度阈值,key=algo_code, value=threshold
|
||||
# 未命中时回退到 conf_threshold
|
||||
algo_conf_thresholds: Dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def get_conf_threshold(self, algo_code: str) -> float:
|
||||
"""获取指定算法的置信度阈值,未配置则回退全局值"""
|
||||
return self.algo_conf_thresholds.get(algo_code, self.conf_threshold)
|
||||
|
||||
|
||||
# ===================== Debug / Local Sync =====================
|
||||
@@ -264,7 +256,6 @@ class Settings:
|
||||
|
||||
self.alarm_upload = AlarmUploadConfig(
|
||||
cloud_api_url=os.getenv("CLOUD_API_URL", "http://localhost:8000"),
|
||||
wvp_api_url=os.getenv("WVP_API_URL", ""),
|
||||
edge_token=os.getenv("EDGE_TOKEN", ""),
|
||||
retry_max=int(os.getenv("ALARM_RETRY_MAX", "3")),
|
||||
retry_interval=int(os.getenv("ALARM_RETRY_INTERVAL", "5")),
|
||||
@@ -282,9 +273,8 @@ class Settings:
|
||||
input_width=int(os.getenv("INPUT_WIDTH", "480")),
|
||||
input_height=int(os.getenv("INPUT_HEIGHT", "480")),
|
||||
batch_size=int(os.getenv("BATCH_SIZE", "4")),
|
||||
conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.45")),
|
||||
nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.5")),
|
||||
algo_conf_thresholds=self._parse_algo_conf_thresholds(),
|
||||
conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.4")),
|
||||
nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.45")),
|
||||
)
|
||||
|
||||
self.config_sync_mode = os.getenv("CONFIG_SYNC_MODE", "LOCAL").upper()
|
||||
@@ -307,25 +297,6 @@ class Settings:
|
||||
# 使用 COCO 类别名称
|
||||
self.class_names = COCO_CLASS_NAMES
|
||||
|
||||
@staticmethod
|
||||
def _parse_algo_conf_thresholds() -> Dict[str, float]:
|
||||
"""解析 ALGO_CONF_* 环境变量,返回 {algo_code: threshold} 字典
|
||||
|
||||
环境变量命名规则: ALGO_CONF_{ALGO_CODE},如:
|
||||
ALGO_CONF_LEAVE_POST=0.35
|
||||
ALGO_CONF_INTRUSION=0.55
|
||||
"""
|
||||
prefix = "ALGO_CONF_"
|
||||
result = {}
|
||||
for key, value in os.environ.items():
|
||||
if key.startswith(prefix) and value:
|
||||
algo_code = key[len(prefix):].lower()
|
||||
try:
|
||||
result[algo_code] = float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _parse_working_hours(self) -> List[dict]:
|
||||
"""解析工作时间配置"""
|
||||
working_hours_str = os.getenv("WORKING_HOURS", "")
|
||||
|
||||
@@ -17,9 +17,7 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
_BEIJING_TZ = timezone(timedelta(hours=8))
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import redis
|
||||
@@ -296,7 +294,7 @@ class AlarmUploadWorker:
|
||||
return None
|
||||
|
||||
# 生成 Object Key: alarms/{device_id}/{yyyy-MM-dd}/{alarm_id}.jpg
|
||||
date_str = datetime.now(_BEIJING_TZ).strftime("%Y-%m-%d")
|
||||
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
object_key = f"alarms/{device_id}/{date_str}/{alarm_id}.jpg"
|
||||
|
||||
try:
|
||||
@@ -394,7 +392,7 @@ class AlarmUploadWorker:
|
||||
if retry_count > max_retry:
|
||||
# 超过最大重试次数,写入死信队列
|
||||
alarm_data["_dead_reason"] = error
|
||||
alarm_data["_dead_at"] = datetime.now(_BEIJING_TZ).isoformat()
|
||||
alarm_data["_dead_at"] = datetime.now(timezone.utc).isoformat()
|
||||
dead_json = json.dumps(alarm_data, ensure_ascii=False)
|
||||
self._redis.lpush(REDIS_KEY_ALARM_DEAD, dead_json)
|
||||
self._stats["dead_lettered"] += 1
|
||||
|
||||
@@ -540,7 +540,6 @@ class ConfigSyncManager:
|
||||
camera_name=cam.get("camera_name", ""),
|
||||
enabled=cam.get("enabled", True),
|
||||
location=cam.get("location", ""),
|
||||
area_id=cam.get("area_id"),
|
||||
)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
@@ -641,15 +640,6 @@ class ConfigSyncManager:
|
||||
removed += 1 + bind_count
|
||||
logger.info(f"清理旧 ROI: {old_roi_id} (含 {bind_count} 条算法绑定)")
|
||||
|
||||
# 清理孤立的算法绑定(ROI 仍存在但绑定已被删除的情况)
|
||||
if incoming_bind_ids:
|
||||
existing_bind_ids = self._db_manager.get_all_bind_ids()
|
||||
for old_bind_id in existing_bind_ids:
|
||||
if old_bind_id not in incoming_bind_ids:
|
||||
self._db_manager.delete_roi_algo_bind(old_bind_id)
|
||||
removed += 1
|
||||
logger.info(f"清理孤立算法绑定: {old_bind_id}")
|
||||
|
||||
if removed > 0:
|
||||
logger.info(f"旧数据清理完成: 共删除 {removed} 条过期记录")
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class NMSProcessor:
|
||||
可选 GPU 加速 (torchvision.ops.nms)
|
||||
"""
|
||||
|
||||
def __init__(self, nms_threshold: float = 0.5, use_gpu: bool = False):
|
||||
def __init__(self, nms_threshold: float = 0.45, use_gpu: bool = False):
|
||||
self.nms_threshold = nms_threshold
|
||||
self.use_gpu = use_gpu and _HAS_TORCH
|
||||
self._logger = get_logger("postprocessor")
|
||||
@@ -162,7 +162,7 @@ class NMSProcessor:
|
||||
boxes: np.ndarray,
|
||||
scores: np.ndarray,
|
||||
class_ids: np.ndarray,
|
||||
conf_threshold: float = 0.45,
|
||||
conf_threshold: float = 0.5,
|
||||
max_output_size: int = 300
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""
|
||||
@@ -579,8 +579,8 @@ class PostProcessor:
|
||||
"alert_cooldown": 300,
|
||||
}
|
||||
|
||||
self.nms_threshold = config.get("nms_threshold", 0.5)
|
||||
self.conf_threshold = config.get("conf_threshold", 0.45)
|
||||
self.nms_threshold = config.get("nms_threshold", 0.45)
|
||||
self.conf_threshold = config.get("conf_threshold", 0.5)
|
||||
|
||||
MAX_DETECTIONS = 8400
|
||||
self._buffer_xyxy = np.zeros((MAX_DETECTIONS, 4), dtype=np.float32)
|
||||
@@ -707,24 +707,22 @@ class PostProcessor:
|
||||
batch_outputs: List[np.ndarray],
|
||||
batch_size: int,
|
||||
conf_threshold: Optional[float] = None,
|
||||
nms_threshold: Optional[float] = None,
|
||||
per_item_conf_thresholds: Optional[List[float]] = None,
|
||||
nms_threshold: Optional[float] = None
|
||||
) -> List[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
|
||||
"""
|
||||
批量处理多 ROI 检测结果
|
||||
|
||||
|
||||
Args:
|
||||
batch_outputs: 模型输出 (可能是 [output] 或单个 batch 数组)
|
||||
batch_size: batch 大小
|
||||
conf_threshold: 置信度阈值
|
||||
nms_threshold: NMS阈值
|
||||
per_item_conf_thresholds: 每个 batch item 独立的置信度阈值列表,
|
||||
长度必须等于 batch_size。传入时忽略 conf_threshold 参数。
|
||||
|
||||
|
||||
Returns:
|
||||
List of (检测框, 置信度, 类别ID) per ROI
|
||||
"""
|
||||
default_conf = conf_threshold if conf_threshold is not None else self.conf_threshold
|
||||
if conf_threshold is None:
|
||||
conf_threshold = self.conf_threshold
|
||||
if nms_threshold is None:
|
||||
nms_threshold = self.nms_threshold
|
||||
|
||||
@@ -795,16 +793,11 @@ class PostProcessor:
|
||||
self._buffer_xyxy[:valid_count, 2] = boxes[:, 0] + boxes[:, 2] / 2
|
||||
self._buffer_xyxy[:valid_count, 3] = boxes[:, 1] + boxes[:, 3] / 2
|
||||
|
||||
item_conf = (
|
||||
per_item_conf_thresholds[batch_idx]
|
||||
if per_item_conf_thresholds is not None
|
||||
else default_conf
|
||||
)
|
||||
keep_boxes, keep_scores, keep_classes = nms_processor.process_with_confidence_filter(
|
||||
self._buffer_xyxy[:valid_count],
|
||||
scores_coarse.astype(np.float32),
|
||||
class_ids,
|
||||
item_conf
|
||||
conf_threshold
|
||||
)
|
||||
|
||||
results.append((keep_boxes, keep_scores, keep_classes))
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -29,7 +29,7 @@ class AlarmInfo:
|
||||
device_id: str
|
||||
scene_id: str
|
||||
event_time: str # ISO8601
|
||||
alarm_level: int # 0紧急 1重要 2普通 3轻微
|
||||
alarm_level: int # 1-4
|
||||
snapshot_b64: Optional[str] = None # Base64 编码的 JPEG 截图
|
||||
algorithm_code: Optional[str] = None
|
||||
confidence_score: Optional[float] = None
|
||||
@@ -56,7 +56,7 @@ def generate_alarm_id(device_id: str) -> str:
|
||||
生成告警ID
|
||||
格式: edge_{device_id}_{YYYYMMDDHHmmss}_{6位uuid}
|
||||
"""
|
||||
timestamp = datetime.now(timezone(timedelta(hours=8))).strftime("%Y%m%d%H%M%S")
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S")
|
||||
uid = uuid.uuid4().hex[:6]
|
||||
return f"edge_{device_id}_{timestamp}_{uid}"
|
||||
|
||||
|
||||
@@ -326,8 +326,7 @@ class ScreenshotHandler:
|
||||
if callback_url:
|
||||
try:
|
||||
url = callback_url.rstrip("/") + SNAP_CALLBACK_PATH
|
||||
resp = requests.post(url, json=body, timeout=SNAP_CALLBACK_TIMEOUT,
|
||||
proxies={"http": None, "https": None})
|
||||
resp = requests.post(url, json=body, timeout=SNAP_CALLBACK_TIMEOUT)
|
||||
if resp.status_code < 300:
|
||||
logger.info("[截图] HTTP 回调成功: request_id=%s", request_id)
|
||||
return
|
||||
|
||||
155
main.py
155
main.py
@@ -54,7 +54,6 @@ class EdgeInferenceService:
|
||||
self._debug_reload_thread: Optional[threading.Thread] = None
|
||||
self._debug_http_server = None
|
||||
self._debug_http_thread: Optional[threading.Thread] = None
|
||||
self._heartbeat_thread: Optional[threading.Thread] = None
|
||||
|
||||
self._processing_threads: Dict[str, threading.Thread] = {}
|
||||
self._stop_event = threading.Event()
|
||||
@@ -185,27 +184,29 @@ class EdgeInferenceService:
|
||||
def _init_screenshot_handler(self):
|
||||
"""初始化截图处理器"""
|
||||
try:
|
||||
# 截图处理器必须使用独立的 Redis 连接(不能与 config_sync 共用,
|
||||
# 因为两者都做阻塞 XREAD/XREADGROUP,共用连接会互相干扰)
|
||||
cloud_redis = None
|
||||
try:
|
||||
import redis
|
||||
cfg = self._settings.cloud_redis
|
||||
cloud_redis = redis.Redis(
|
||||
host=cfg.host,
|
||||
port=cfg.port,
|
||||
db=cfg.db,
|
||||
password=cfg.password,
|
||||
decode_responses=cfg.decode_responses,
|
||||
socket_connect_timeout=5,
|
||||
socket_timeout=10,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
cloud_redis.ping()
|
||||
self._logger.info(f"截图处理器独立连接云端 Redis 成功: {cfg.host}:{cfg.port}/{cfg.db}")
|
||||
except Exception as e:
|
||||
self._logger.warning(f"截图处理器无法连接云端 Redis: {e}")
|
||||
cloud_redis = None
|
||||
# 优先从 config_manager 获取已有的云端 Redis 连接
|
||||
cloud_redis = getattr(self._config_manager, '_cloud_redis', None)
|
||||
|
||||
# LOCAL 模式下 config_manager 不初始化云端 Redis,需要独立创建
|
||||
if cloud_redis is None:
|
||||
try:
|
||||
import redis
|
||||
cfg = self._settings.cloud_redis
|
||||
cloud_redis = redis.Redis(
|
||||
host=cfg.host,
|
||||
port=cfg.port,
|
||||
db=cfg.db,
|
||||
password=cfg.password,
|
||||
decode_responses=cfg.decode_responses,
|
||||
socket_connect_timeout=5,
|
||||
socket_timeout=5,
|
||||
retry_on_timeout=True,
|
||||
)
|
||||
cloud_redis.ping()
|
||||
self._logger.info(f"截图处理器独立连接云端 Redis 成功: {cfg.host}:{cfg.port}/{cfg.db}")
|
||||
except Exception as e:
|
||||
self._logger.warning(f"截图处理器无法连接云端 Redis: {e}")
|
||||
cloud_redis = None
|
||||
|
||||
if cloud_redis and self._stream_manager:
|
||||
self._screenshot_handler = ScreenshotHandler(
|
||||
@@ -283,61 +284,6 @@ class EdgeInferenceService:
|
||||
)
|
||||
self._debug_http_thread.start()
|
||||
|
||||
def _start_heartbeat(self):
|
||||
"""启动心跳守护线程,每 30 秒向云端上报设备状态"""
|
||||
def worker():
|
||||
import requests
|
||||
base_url = self._settings.alarm_upload.cloud_api_url.rstrip("/")
|
||||
wvp_url = self._settings.alarm_upload.wvp_api_url.rstrip("/") if self._settings.alarm_upload.wvp_api_url else ""
|
||||
urls = [f"{base_url}/api/ai/device/heartbeat"]
|
||||
if wvp_url:
|
||||
urls.append(f"{wvp_url}/api/ai/device/heartbeat")
|
||||
device_id = self._settings.mqtt.device_id
|
||||
|
||||
self._logger.info(f"[心跳] 守护线程已启动, 目标: {urls}, device_id={device_id}")
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
start_time = self._performance_stats.get("start_time")
|
||||
uptime = (datetime.now() - start_time).total_seconds() if start_time else 0
|
||||
stream_count = len(self._stream_manager._streams) if self._stream_manager else 0
|
||||
config_version = self._config_manager.config_version if self._config_manager else None
|
||||
|
||||
payload = {
|
||||
"device_id": device_id,
|
||||
"status": {
|
||||
"uptime_seconds": int(uptime),
|
||||
"frames_processed": self._performance_stats.get("total_frames_processed", 0),
|
||||
"alerts_generated": self._performance_stats.get("total_alerts_generated", 0),
|
||||
"stream_count": stream_count,
|
||||
"config_version": config_version,
|
||||
"stream_stats": {
|
||||
"active_streams": stream_count,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for url in urls:
|
||||
try:
|
||||
resp = requests.post(url, json=payload, timeout=10)
|
||||
if resp.status_code == 200:
|
||||
self._logger.debug(f"[心跳] {url} 上报成功")
|
||||
else:
|
||||
self._logger.warning(f"[心跳] {url} 上报失败: HTTP {resp.status_code}")
|
||||
except Exception as e:
|
||||
self._logger.warning(f"[心跳] {url} 上报异常: {e}")
|
||||
except Exception as e:
|
||||
self._logger.warning(f"[心跳] 上报异常: {e}")
|
||||
|
||||
self._stop_event.wait(30)
|
||||
|
||||
self._heartbeat_thread = threading.Thread(
|
||||
target=worker,
|
||||
name="HeartbeatWorker",
|
||||
daemon=True,
|
||||
)
|
||||
self._heartbeat_thread.start()
|
||||
|
||||
def initialize(self):
|
||||
"""初始化所有组件"""
|
||||
self._logger.info("=" * 50)
|
||||
@@ -355,7 +301,6 @@ class EdgeInferenceService:
|
||||
self._init_screenshot_handler()
|
||||
self._start_debug_reload_watcher()
|
||||
self._start_debug_http_server()
|
||||
self._start_heartbeat()
|
||||
|
||||
self._performance_stats["start_time"] = datetime.now()
|
||||
|
||||
@@ -368,12 +313,7 @@ class EdgeInferenceService:
|
||||
)
|
||||
|
||||
self._logger.info("所有组件初始化完成")
|
||||
|
||||
def _get_camera_area_id(self, camera_id: str) -> Optional[int]:
|
||||
"""获取摄像头的 area_id"""
|
||||
cam = self._get_camera_config_by_id(camera_id)
|
||||
return cam.area_id if cam else None
|
||||
|
||||
|
||||
def _get_camera_ids_with_roi(self) -> set:
|
||||
"""获取有ROI配置的摄像头ID集合
|
||||
|
||||
@@ -658,24 +598,14 @@ class EdgeInferenceService:
|
||||
self._logger.debug(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
|
||||
|
||||
batch_size = len(chunk)
|
||||
|
||||
# 按算法类型获取每个 item 的独立置信度阈值
|
||||
per_item_conf = [
|
||||
self._settings.inference.get_conf_threshold(item[2].algo_code)
|
||||
for item in chunk
|
||||
]
|
||||
|
||||
batch_results = self._postprocessor.batch_process_detections(
|
||||
outputs,
|
||||
batch_size,
|
||||
per_item_conf_thresholds=per_item_conf,
|
||||
conf_threshold=self._settings.inference.conf_threshold
|
||||
)
|
||||
|
||||
total_detections = sum(len(r[0]) for r in batch_results)
|
||||
self._logger.debug(
|
||||
f"[推理] batch_size={batch_size}, 总检测数={total_detections}, "
|
||||
f"conf_thresholds={per_item_conf}"
|
||||
)
|
||||
self._logger.debug(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
|
||||
|
||||
for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(chunk):
|
||||
boxes, scores, class_ids = batch_results[idx]
|
||||
@@ -823,16 +753,12 @@ class EdgeInferenceService:
|
||||
self._camera_alert_cooldown[dedup_key] = now
|
||||
self._performance_stats["total_alerts_generated"] += 1
|
||||
|
||||
# 获取算法的事件开始时间(泛化:支持所有算法类型)
|
||||
first_frame_time = None
|
||||
algo = self._algorithm_manager.algorithms.get(roi_id, {}).get(f"{roi_id}_{bind.bind_id}", {}).get(alert_type)
|
||||
if algo:
|
||||
# 各算法使用不同的内部变量名存储开始时间
|
||||
for attr in ('_leave_start_time', '_parking_start_time', '_congestion_start_time', '_intrusion_start_time'):
|
||||
val = getattr(algo, attr, None)
|
||||
if val:
|
||||
first_frame_time = val.isoformat()
|
||||
break
|
||||
# 获取算法的离岗开始时间
|
||||
leave_start_time = None
|
||||
if alert_type == "leave_post":
|
||||
algo = self._algorithm_manager.algorithms.get(roi_id, {}).get(f"{roi_id}_{bind.bind_id}", {}).get("leave_post")
|
||||
if algo and hasattr(algo, '_leave_start_time') and algo._leave_start_time:
|
||||
leave_start_time = algo._leave_start_time.isoformat()
|
||||
|
||||
from core.result_reporter import AlarmInfo, generate_alarm_id
|
||||
alarm_info = AlarmInfo(
|
||||
@@ -852,17 +778,20 @@ class EdgeInferenceService:
|
||||
"bind_id": bind.bind_id,
|
||||
"message": alert.get("message", ""),
|
||||
"edge_node_id": self._settings.mqtt.device_id,
|
||||
"first_frame_time": first_frame_time,
|
||||
"vehicle_count": alert.get("vehicle_count"),
|
||||
"area_id": self._get_camera_area_id(camera_id),
|
||||
"first_frame_time": leave_start_time,
|
||||
},
|
||||
)
|
||||
self._reporter.report_alarm(alarm_info, screenshot=frame.image)
|
||||
|
||||
# 回填 alarm_id 到算法实例(用于后续 resolve 追踪,泛化支持所有算法类型)
|
||||
algo = self._algorithm_manager.algorithms.get(roi_id, {}).get(f"{roi_id}_{bind.bind_id}", {}).get(alert_type)
|
||||
if algo and hasattr(algo, 'set_last_alarm_id'):
|
||||
algo.set_last_alarm_id(alarm_info.alarm_id)
|
||||
# 回填 alarm_id 到算法实例(用于后续 resolve 追踪)
|
||||
if alert_type == "leave_post":
|
||||
algo = self._algorithm_manager.algorithms.get(roi_id, {}).get(f"{roi_id}_{bind.bind_id}", {}).get("leave_post")
|
||||
if algo and hasattr(algo, 'set_last_alarm_id'):
|
||||
algo.set_last_alarm_id(alarm_info.alarm_id)
|
||||
elif alert_type == "intrusion":
|
||||
algo = self._algorithm_manager.algorithms.get(roi_id, {}).get(f"{roi_id}_{bind.bind_id}", {}).get("intrusion")
|
||||
if algo and hasattr(algo, 'set_last_alarm_id'):
|
||||
algo.set_last_alarm_id(alarm_info.alarm_id)
|
||||
|
||||
self._logger.info(
|
||||
f"告警已生成: type={alert_type}, "
|
||||
|
||||
Reference in New Issue
Block a user