Compare commits

...

25 Commits

Author SHA1 Message Date
cdx
8e2b285893 revert afa9e122a5
Some checks failed
Python Test / test (push) Has been cancelled
revert feat: add model (cancel)
2026-01-24 23:09:45 +08:00
afa9e122a5 feat: add model
2026-01-23 15:27:06 +08:00
7a10a983c8 feat: add camera configuration and status synchronization
2026-01-23 10:35:33 +08:00
98c741cb2b fix(cache): fix _load_rois_from_db not loading the working_hours field
- Read the working_hours field from the database when building roi_config
- Ensure the pipeline passes each ROI's working-hours configuration through to the algorithm
- After the fix, LeavePostAlgorithm.is_in_working_hours() uses the ROI-level configuration instead of the global default
2026-01-22 18:26:57 +08:00
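The loading step this commit describes can be sketched as follows (the helper name and row shape are illustrative, not the actual _load_rois_from_db signature): the working_hours JSON column is deserialized alongside the other ROI fields so the per-ROI schedule reaches the algorithm instead of silently falling back to the global default.

```python
import json

def load_roi_config(row: dict) -> dict:
    """Build an ROI config dict from a database row (illustrative shape)."""
    return {
        "roi_id": row["roi_id"],
        "points": json.loads(row["points"]),
        # The fix: also read working_hours, falling back to None
        # (meaning "use the global default") when the column is NULL.
        "working_hours": json.loads(row["working_hours"]) if row["working_hours"] else None,
    }

row = {"roi_id": "r1", "points": "[[0, 0], [1, 0], [1, 1]]",
       "working_hours": '[{"start": [8, 30], "end": [17, 30]}]'}
cfg = load_roi_config(row)
```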
44b6c70a4b fix(roi): fix working_hours being misread as None when it is an empty array [], which disabled all-day mode.
- Change the check in `api/roi.py` to use `is not None` instead of a truthy check, so `[]` serializes correctly and is stored in the database
- Update the global default in `config.yaml` to `working_hours: []`, meaning "active all day"
2026-01-22 18:02:19 +08:00
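The []-vs-None pitfall this commit fixes can be sketched in two lines (function names are illustrative, not from `api/roi.py`): an empty list is falsy, so a truthy check collapses the "all day" configuration into NULL.

```python
import json

def serialize_working_hours_buggy(working_hours):
    # Truthy check: [] is falsy, so an "all day" config ([])
    # is silently dropped and stored as NULL.
    return json.dumps(working_hours) if working_hours else None

def serialize_working_hours_fixed(working_hours):
    # Explicit None check: [] survives and round-trips as "[]".
    return json.dumps(working_hours) if working_hours is not None else None

assert serialize_working_hours_buggy([]) is None   # all-day setting lost
assert serialize_working_hours_fixed([]) == "[]"   # all-day setting preserved
```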
3af7a0f805 fix(roi-editor): fix the time-range picker (TimePicker.RangePicker) in the ROI editor being cleared by two back-to-back state updates.
- Add an `updateWorkingHoursRange` batch-update function that writes start/end together as one atomic state update
- Add an `Array.isArray(dates) && dates.length >= 2` type check in the onChange callback
- Prevent conflicting asynchronous React setState calls from unexpectedly resetting workingHoursList
2026-01-22 17:26:28 +08:00
cb46d12cfa fix: fix ROI configuration failing because the database lacked the working_hours column.
- Manually ran the SQL: ALTER TABLE rois ADD COLUMN working_hours TEXT
- Brings the existing SQLite database (security_monitor.db) schema in line with the model definition
- Prevents the missing column from causing read errors in the API or algorithms
2026-01-22 16:44:26 +08:00
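The manual ALTER TABLE above could also be applied as an idempotent startup migration; a minimal sketch (helper name assumed, not from the codebase):

```python
import sqlite3

def ensure_working_hours_column(conn: sqlite3.Connection) -> bool:
    """Add rois.working_hours if missing.

    Returns True if the column was added, False if it already existed,
    so repeated runs are safe no-ops.
    """
    cols = [row[1] for row in conn.execute("PRAGMA table_info(rois)")]
    if "working_hours" in cols:
        return False
    conn.execute("ALTER TABLE rois ADD COLUMN working_hours TEXT")
    conn.commit()
    return True
```

Running it once adds the column; a second run detects it via PRAGMA table_info and does nothing.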
123903950b refactor(model): remove model engine files built against a mismatched version 2026-01-22 15:55:30 +08:00
2d5ada2909 fix: unify configuration parameters 2026-01-22 15:53:31 +08:00
6fc17ccf64 fix: fix wrong parameter order, broken ROI matching, and a missing INIT state in the state machine
- Reorder the process() function parameters so roi_id and camera_id map correctly
- Rework the ROI matching logic to use an explicit roi_id for region assignment
- Introduce an INIT state: if the ROI is empty at startup, enter INIT first,
  and only start the OFF_DUTY countdown after off_duty_confirm_sec of confirmed absence
2026-01-22 15:08:28 +08:00
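The INIT behaviour in the last bullet can be sketched as a minimal state machine (the state names INIT/ON_DUTY/OFF_DUTY follow the commit message; the class and method names are assumed for illustration):

```python
class LeavePostStates:
    """Sketch of the startup flow: empty ROI -> INIT, not an instant alarm."""

    def __init__(self, off_duty_confirm_sec: float):
        self.off_duty_confirm_sec = off_duty_confirm_sec
        self.state = "INIT"        # empty at startup starts here, not OFF_DUTY
        self._empty_since = None   # timestamp when the ROI became empty

    def update(self, person_in_roi: bool, now: float) -> str:
        if person_in_roi:
            self.state = "ON_DUTY"
            self._empty_since = None
        elif self.state in ("INIT", "ON_DUTY"):
            if self._empty_since is None:
                self._empty_since = now
            # Only after off_duty_confirm_sec of continuous absence does
            # the OFF_DUTY countdown begin.
            if now - self._empty_since >= self.off_duty_confirm_sec:
                self.state = "OFF_DUTY"
        return self.state
```

The point of the extra state is that a cold start with an empty ROI is indistinguishable from a fresh departure, so both must pass the same confirmation window.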
6116f0b982 fix: fix leave-post alarms never firing because the ROI polygon was not passed through and the empty-ROI check was wrong.
Root cause:
1. pipeline.py did not pass roi_polygon when calling register_algorithm, leaving roi_polygon empty inside the algorithm
2. is_point_in_roi wrongly returned True when roi_polygon was empty or had fewer than 3 points, so the system assumed someone was on post
3. As a result the algorithm never entered the leave-post countdown, even with nobody inside the ROI

Fixes:
- Pass the ROI polygon coordinates correctly when registering the algorithm
- Correct is_point_in_roi to return False (nobody present) when the ROI is invalid
- Ensure the state-machine timeout logic still runs when there are no detection boxes
2026-01-22 13:34:04 +08:00
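A sketch of the corrected guard, paired with a standard ray-casting point-in-polygon test (the real signature in the codebase may differ): an invalid or empty ROI must read as "nobody on post" (False), never as "someone on post" (True).

```python
def is_point_in_roi(point, roi_polygon) -> bool:
    # Fewer than 3 vertices cannot form a polygon; the fix is to treat
    # that as "no one present" instead of defaulting to True.
    if not roi_polygon or len(roi_polygon) < 3:
        return False
    x, y = point
    inside = False
    j = len(roi_polygon) - 1
    for i in range(len(roi_polygon)):  # standard ray-casting test
        xi, yi = roi_polygon[i]
        xj, yj = roi_polygon[j]
        if (yi > y) != (yj > y) and x < (xj - xi) * (y - yi) / (yj - yi) + xi:
            inside = not inside
        j = i
    return inside
```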
20f295a491 Fix alarms not firing after personnel had left the ROI for over ten minutes.
Possible causes investigated:
1. Still inside the confirm_sec sliding window (confirmation not yet complete)
2. threshold_sec set too high (check the value actually configured in the database)
3. The new algorithm not being invoked correctly
2026-01-22 12:22:26 +08:00
cc4f33c0fd refactor: remove LeavePostAlgorithm and switch to ROI intrusion detection
- Fix the sort import issue
2026-01-22 11:43:40 +08:00
2e9bf2b50c ci: add Gitea Actions for Python 2026-01-22 11:35:02 +08:00
248a240524 ci: add drone pipeline for python tests 2026-01-22 11:23:12 +08:00
10b9fb1804 refactor: rework the leave-post detection logic around a state machine and remove the sorting-related algorithms 2026-01-22 11:03:01 +08:00
1a94854c52 chore: delete unrelated legacy algorithm Python files 2026-01-22 10:57:19 +08:00
13afc654ab fix(timezone): fix inconsistent timestamps on triggered alarm records 2026-01-22 09:08:44 +08:00
804c6a60e9 fix(alarm): update alarm status endpoint to accept JSON body 2026-01-21 17:50:19 +08:00
20634c2ad4 fix(dashboard): align camera status field with backend API response 2026-01-21 17:13:05 +08:00
46ee360d46 fix(api): update camera creation endpoint to accept JSON body 2026-01-21 17:03:59 +08:00
6712a311f8 fix(camera): resolve camera status display issue 2026-01-21 16:06:39 +08:00
294b0e1abb feat(alarm): add alarm snapshot viewing functionality 2026-01-21 15:16:25 +08:00
1c7190bbb0 fix(inference): resolve multiple YOLO inference and API issues 2026-01-21 14:48:01 +08:00
1b344aeb2e feat(db): add missing columns for cloud sync and alarm region data 2026-01-21 13:49:01 +08:00
25 changed files with 1594 additions and 1703 deletions

View File

@@ -0,0 +1,34 @@
name: Python Test
on:
  push:
    branches: [ "master", "main" ]
  pull_request:
jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          if [ -f "requirements.txt" ]; then
            pip install -r requirements.txt
          fi
          pip install pytest
      - name: Run tests
        run: |
          if [ -d "tests" ]; then
            pytest tests/ --verbose
          else
            echo "No tests directory found, skipping tests."
          fi

View File

@@ -1,7 +1,8 @@
from datetime import datetime
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from pydantic import BaseModel
from sqlalchemy.orm import Session
from db.crud import (
@@ -16,6 +17,41 @@ from inference.pipeline import get_pipeline
router = APIRouter(prefix="/api/alarms", tags=["告警管理"])
class AlarmUpdateRequest(BaseModel):
llm_checked: Optional[bool] = None
llm_result: Optional[str] = None
processed: Optional[bool] = None
def convert_to_china_time(dt: Optional[datetime]) -> Optional[str]:
"""将 UTC 时间转换为中国时间 (UTC+8)"""
if dt is None:
return None
try:
china_tz = timezone(timedelta(hours=8))
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt.astimezone(china_tz).isoformat()
except Exception:
return dt.isoformat() if dt else None
def format_alarm_response(alarm) -> dict:
"""格式化告警响应,将 UTC 时间转换为中国时间"""
return {
"id": alarm.id,
"camera_id": alarm.camera_id,
"roi_id": alarm.roi_id,
"event_type": alarm.event_type,
"confidence": alarm.confidence,
"snapshot_path": alarm.snapshot_path,
"llm_checked": alarm.llm_checked,
"llm_result": alarm.llm_result,
"processed": alarm.processed,
"created_at": convert_to_china_time(alarm.created_at),
}
@router.get("", response_model=List[dict])
def list_alarms(
camera_id: Optional[int] = None,
@@ -25,21 +61,7 @@ def list_alarms(
db: Session = Depends(get_db),
):
alarms = get_alarms(db, camera_id=camera_id, event_type=event_type, limit=limit, offset=offset)
return [
{
"id": alarm.id,
"camera_id": alarm.camera_id,
"roi_id": alarm.roi_id,
"event_type": alarm.event_type,
"confidence": alarm.confidence,
"snapshot_path": alarm.snapshot_path,
"llm_checked": alarm.llm_checked,
"llm_result": alarm.llm_result,
"processed": alarm.processed,
"created_at": alarm.created_at.isoformat() if alarm.created_at else None,
}
for alarm in alarms
]
return [format_alarm_response(alarm) for alarm in alarms]
@router.get("/stats")
@@ -55,29 +77,16 @@ def get_alarm(alarm_id: int, db: Session = Depends(get_db)):
alarm = next((a for a in alarms if a.id == alarm_id), None)
if not alarm:
raise HTTPException(status_code=404, detail="告警不存在")
return {
"id": alarm.id,
"camera_id": alarm.camera_id,
"roi_id": alarm.roi_id,
"event_type": alarm.event_type,
"confidence": alarm.confidence,
"snapshot_path": alarm.snapshot_path,
"llm_checked": alarm.llm_checked,
"llm_result": alarm.llm_result,
"processed": alarm.processed,
"created_at": alarm.created_at.isoformat() if alarm.created_at else None,
}
return format_alarm_response(alarm)
@router.put("/{alarm_id}")
def update_alarm_status(
alarm_id: int,
llm_checked: Optional[bool] = None,
llm_result: Optional[str] = None,
processed: Optional[bool] = None,
request: AlarmUpdateRequest = Body(...),
db: Session = Depends(get_db),
):
alarm = update_alarm(db, alarm_id, llm_checked=llm_checked, llm_result=llm_result, processed=processed)
alarm = update_alarm(db, alarm_id, llm_checked=request.llm_checked, llm_result=request.llm_result, processed=request.processed)
if not alarm:
raise HTTPException(status_code=404, detail="告警不存在")
return {"message": "更新成功"}

View File

@@ -1,3 +1,4 @@
from datetime import datetime, timezone, timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Body
@@ -15,6 +16,7 @@ from db.models import get_db
from inference.pipeline import get_pipeline
router = APIRouter(prefix="/api/cameras", tags=["摄像头管理"])
router2 = APIRouter(prefix="/api/camera", tags=["摄像头状态"])
class CameraUpdateRequest(BaseModel):
@@ -25,6 +27,19 @@ class CameraUpdateRequest(BaseModel):
enabled: Optional[bool] = None
def convert_to_china_time(dt: Optional[datetime]) -> Optional[str]:
"""将 UTC 时间转换为中国时间 (UTC+8)"""
if dt is None:
return None
try:
china_tz = timezone(timedelta(hours=8))
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt.astimezone(china_tz).isoformat()
except Exception:
return dt.isoformat() if dt else None
@router.get("", response_model=List[dict])
def list_cameras(
enabled_only: bool = True,
@@ -39,7 +54,7 @@ def list_cameras(
"enabled": cam.enabled,
"fps_limit": cam.fps_limit,
"process_every_n_frames": cam.process_every_n_frames,
"created_at": cam.created_at.isoformat() if cam.created_at else None,
"created_at": convert_to_china_time(cam.created_at),
}
for cam in cameras
]
@@ -57,24 +72,21 @@ def get_camera(camera_id: int, db: Session = Depends(get_db)):
"enabled": camera.enabled,
"fps_limit": camera.fps_limit,
"process_every_n_frames": camera.process_every_n_frames,
"created_at": camera.created_at.isoformat() if camera.created_at else None,
"created_at": convert_to_china_time(camera.created_at),
}
@router.post("", response_model=dict)
def add_camera(
name: str,
rtsp_url: str,
fps_limit: int = 30,
process_every_n_frames: int = 3,
request: CameraUpdateRequest = Body(...),
db: Session = Depends(get_db),
):
camera = create_camera(
db,
name=name,
rtsp_url=rtsp_url,
fps_limit=fps_limit,
process_every_n_frames=process_every_n_frames,
name=request.name,
rtsp_url=request.rtsp_url,
fps_limit=request.fps_limit or 30,
process_every_n_frames=request.process_every_n_frames or 3,
)
if camera.enabled:
@@ -163,3 +175,42 @@ def get_camera_status(camera_id: int, db: Session = Depends(get_db)):
"last_check_time": None,
"stream": stream_info,
}
router2 = APIRouter(prefix="/api/camera", tags=["摄像头状态"])
@router2.get("/status/all")
def get_all_camera_status(db: Session = Depends(get_db)):
from db.crud import get_all_cameras, get_camera_status as get_status
cameras = get_all_cameras(db, enabled_only=False)
pipeline = get_pipeline()
result = []
for cam in cameras:
status = get_status(db, cam.id)
stream = pipeline.stream_manager.get_stream(str(cam.id))
stream_info = stream.get_info() if stream else None
if status:
result.append({
"camera_id": cam.id,
"is_running": status.is_running,
"fps": status.fps,
"error_message": status.error_message,
"last_check_time": status.last_check_time.isoformat() if status.last_check_time else None,
"stream": stream_info,
})
else:
result.append({
"camera_id": cam.id,
"is_running": False,
"fps": 0.0,
"error_message": None,
"last_check_time": None,
"stream": stream_info,
})
return result

View File

@@ -1,7 +1,8 @@
import json
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Body
from pydantic import BaseModel
from sqlalchemy.orm import Session
from db.crud import (
@@ -20,6 +21,34 @@ from inference.roi.roi_filter import ROIFilter
router = APIRouter(prefix="/api/camera", tags=["ROI管理"])
class CreateROIRequest(BaseModel):
roi_id: str
name: str
roi_type: str
points: List[List[float]]
rule_type: str
direction: Optional[str] = None
stay_time: Optional[int] = None
threshold_sec: int = 300
confirm_sec: int = 10
return_sec: int = 30
working_hours: Optional[List[dict]] = None
class UpdateROIRequest(BaseModel):
name: Optional[str] = None
roi_type: Optional[str] = None
points: Optional[List[List[float]]] = None
rule_type: Optional[str] = None
direction: Optional[str] = None
stay_time: Optional[int] = None
enabled: Optional[bool] = None
threshold_sec: Optional[int] = None
confirm_sec: Optional[int] = None
return_sec: Optional[int] = None
working_hours: Optional[List[dict]] = None
def _invalidate_roi_cache(camera_id: int):
pipeline = get_pipeline()
pipeline.roi_filter.clear_cache(camera_id)
@@ -43,6 +72,7 @@ def list_rois(camera_id: int, db: Session = Depends(get_db)):
"threshold_sec": roi.threshold_sec,
"confirm_sec": roi.confirm_sec,
"return_sec": roi.return_sec,
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
}
for roi in roi_configs
]
@@ -66,37 +96,34 @@ def get_roi(camera_id: int, roi_id: int, db: Session = Depends(get_db)):
"threshold_sec": roi.threshold_sec,
"confirm_sec": roi.confirm_sec,
"return_sec": roi.return_sec,
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
}
@router.post("/{camera_id}/roi", response_model=dict)
def add_roi(
camera_id: int,
roi_id: str,
name: str,
roi_type: str,
points: List[List[float]],
rule_type: str,
direction: Optional[str] = None,
stay_time: Optional[int] = None,
threshold_sec: int = 360,
confirm_sec: int = 30,
return_sec: int = 5,
request: CreateROIRequest,
db: Session = Depends(get_db),
):
import json
working_hours_json = json.dumps(request.working_hours) if request.working_hours is not None else None
roi = create_roi(
db,
camera_id=camera_id,
roi_id=roi_id,
name=name,
roi_type=roi_type,
points=points,
rule_type=rule_type,
direction=direction,
stay_time=stay_time,
threshold_sec=threshold_sec,
confirm_sec=confirm_sec,
return_sec=return_sec,
roi_id=request.roi_id,
name=request.name,
roi_type=request.roi_type,
points=request.points,
rule_type=request.rule_type,
direction=request.direction,
stay_time=request.stay_time,
threshold_sec=request.threshold_sec,
confirm_sec=request.confirm_sec,
return_sec=request.return_sec,
working_hours=working_hours_json,
)
_invalidate_roi_cache(camera_id)
@@ -106,9 +133,15 @@ def add_roi(
"roi_id": roi.roi_id,
"name": roi.name,
"type": roi.roi_type,
"points": points,
"points": request.points,
"rule": roi.rule_type,
"direction": roi.direction,
"stay_time": roi.stay_time,
"enabled": roi.enabled,
"threshold_sec": roi.threshold_sec,
"confirm_sec": roi.confirm_sec,
"return_sec": roi.return_sec,
"working_hours": request.working_hours,
}
@@ -116,29 +149,25 @@ def add_roi(
def modify_roi(
camera_id: int,
roi_id: int,
name: Optional[str] = None,
points: Optional[List[List[float]]] = None,
rule_type: Optional[str] = None,
direction: Optional[str] = None,
stay_time: Optional[int] = None,
enabled: Optional[bool] = None,
threshold_sec: Optional[int] = None,
confirm_sec: Optional[int] = None,
return_sec: Optional[int] = None,
request: UpdateROIRequest,
db: Session = Depends(get_db),
):
import json
working_hours_json = json.dumps(request.working_hours) if request.working_hours is not None else None
roi = update_roi(
db,
roi_id=roi_id,
name=name,
points=points,
rule_type=rule_type,
direction=direction,
stay_time=stay_time,
enabled=enabled,
threshold_sec=threshold_sec,
confirm_sec=confirm_sec,
return_sec=return_sec,
name=request.name,
points=request.points,
rule_type=request.rule_type,
direction=request.direction,
stay_time=request.stay_time,
enabled=request.enabled,
threshold_sec=request.threshold_sec,
confirm_sec=request.confirm_sec,
return_sec=request.return_sec,
working_hours=working_hours_json,
)
if not roi:
raise HTTPException(status_code=404, detail="ROI不存在")
@@ -153,6 +182,7 @@ def modify_roi(
"points": json.loads(roi.points),
"rule": roi.rule_type,
"enabled": roi.enabled,
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
}

api/sync.py (new file, 116 lines)
View File

@@ -0,0 +1,116 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional

from db.models import get_db
from services.sync_service import get_sync_service

router = APIRouter(prefix="/api/sync", tags=["同步"])


@router.get("/status")
def get_sync_status(db: Session = Depends(get_db)):
    """获取同步服务状态"""
    from sqlalchemy import text

    service = get_sync_service()
    status = service.get_status()
    pending_cameras = db.execute(
        text("SELECT COUNT(*) FROM cameras WHERE pending_sync = 1")
    ).scalar() or 0
    pending_rois = db.execute(
        text("SELECT COUNT(*) FROM rois WHERE pending_sync = 1")
    ).scalar() or 0
    pending_alarms = db.execute(
        text("SELECT COUNT(*) FROM alarms WHERE upload_status = 'pending' OR upload_status = 'retry'")
    ).scalar() or 0
    return {
        "running": status["running"],
        "cloud_enabled": status["cloud_enabled"],
        "network_status": status["network_status"],
        "device_id": status["device_id"],
        "pending_sync": pending_cameras + pending_rois,
        "pending_alarms": pending_alarms,
        "details": {
            "pending_cameras": pending_cameras,
            "pending_rois": pending_rois,
            "pending_alarms": pending_alarms
        }
    }


@router.get("/pending")
def get_pending_syncs(db: Session = Depends(get_db)):
    """获取待同步项列表"""
    from sqlalchemy import text
    from db.models import Camera, ROI, Alarm

    pending_cameras = db.query(Camera).filter(Camera.pending_sync == True).all()
    pending_rois = db.query(ROI).filter(ROI.pending_sync == True).all()
    from db.session import SessionLocal
    temp_db = SessionLocal()
    try:
        pending_alarms = temp_db.query(Alarm).filter(
            Alarm.upload_status.in_(['pending', 'retry'])
        ).limit(100).all()
    finally:
        temp_db.close()
    return {
        "cameras": [{"id": c.id, "name": c.name} for c in pending_cameras],
        "rois": [{"id": r.id, "name": r.name, "camera_id": r.camera_id} for r in pending_rois],
        "alarms": [{"id": a.id, "camera_id": a.camera_id, "type": a.event_type} for a in pending_alarms]
    }


@router.post("/trigger")
def trigger_sync():
    """手动触发同步"""
    service = get_sync_service()
    from db.session import SessionLocal
    from db.crud import get_all_cameras, get_all_rois
    from db.models import Camera, ROI

    db = SessionLocal()
    try:
        cameras = get_all_cameras(db)
        for camera in cameras:
            service.queue_camera_sync(camera.id, 'update', {
                'name': camera.name,
                'rtsp_url': camera.rtsp_url,
                'enabled': camera.enabled
            })
            db.query(Camera).filter(Camera.id == camera.id).update({'pending_sync': False})
        db.commit()
        rois = get_all_rois(db)
        for roi in rois:
            service.queue_roi_sync(roi.id, 'update', {
                'name': roi.name,
                'roi_type': roi.roi_type,
                'points': roi.points,
                'enabled': roi.enabled
            })
            db.query(ROI).filter(ROI.id == roi.id).update({'pending_sync': False})
        db.commit()
        return {"message": "同步任务已加入队列", "count": len(cameras) + len(rois)}
    finally:
        db.close()


@router.post("/clear-failed")
def clear_failed_syncs(db: Session = Depends(get_db)):
    """清除失败的同步标记"""
    from sqlalchemy import text
    db.execute(text("UPDATE cameras SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0"))
    db.execute(text("UPDATE rois SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0"))
    db.commit()
    return {"message": "已清除所有失败的同步标记"}

View File

@@ -23,16 +23,17 @@ class DatabaseConfig(BaseModel):
class ModelConfig(BaseModel):
engine_path: str = "models/yolo11s.engine"
onnx_path: str = "models/yolo11s.onnx"
pt_model_path: str = "models/yolo11s.pt"
engine_path: str = "models/yolo11n.engine"
onnx_path: str = "models/yolo11n.onnx"
pt_model_path: str = "models/yolo11n.pt"
imgsz: List[int] = [640, 640]
conf_threshold: float = 0.5
iou_threshold: float = 0.45
device: int = 0
batch_size: int = 8
half: bool = True
use_onnx: bool = True
half: bool = False
use_onnx: bool = False
use_trt: bool = False
class StreamConfig(BaseModel):
@@ -65,9 +66,9 @@ class WorkingHours(BaseModel):
class AlgorithmsConfig(BaseModel):
leave_post_threshold_sec: int = 360
leave_post_confirm_sec: int = 30
leave_post_return_sec: int = 5
leave_post_threshold_sec: int = 300
leave_post_confirm_sec: int = 10
leave_post_return_sec: int = 30
intrusion_check_interval_sec: float = 1.0
intrusion_direction_sensitive: bool = False

View File

@@ -60,22 +60,17 @@ roi:
- "line"
max_points: 50 # 多边形最大顶点数
# 工作时间配置(全局默认)
working_hours:
- start: [8, 30] # 8:30
end: [11, 0] # 11:00
- start: [12, 0] # 12:00
end: [17, 30] # 17:30
# 工作时间配置(全局默认,空数组表示全天开启
working_hours: []
# 算法默认参数
algorithms:
leave_post:
default_threshold_sec: 360 # 离岗超时(6分钟)
confirm_sec: 30 # 岗确认时间
return_sec: 5 # 上岗确认时间
threshold_sec: 300 # 离岗超时(5分钟)
confirm_sec: 10 # 岗确认时间10秒
return_sec: 30 # 离岗缓冲时间30秒
intrusion:
check_interval_sec: 1.0 # 检测间隔
direction_sensitive: false # 方向敏感
cooldown_seconds: 300 # 入侵检测冷却时间(秒)
# 日志配置
logging:
@@ -94,7 +89,7 @@ monitoring:
# 大模型配置(预留)
llm:
enabled: false
api_key: ""
base_url: ""
api_key: "sk-21e61bef09074682b589da3bdbfe07a2"
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
model: "qwen3-vl-max"
timeout: 30

View File

@@ -136,9 +136,10 @@ def create_roi(
rule_type: str,
direction: Optional[str] = None,
stay_time: Optional[int] = None,
threshold_sec: int = 360,
confirm_sec: int = 30,
return_sec: int = 5,
threshold_sec: int = 300,
confirm_sec: int = 10,
return_sec: int = 30,
working_hours: Optional[str] = None,
) -> ROI:
import json
@@ -154,6 +155,7 @@ def create_roi(
threshold_sec=threshold_sec,
confirm_sec=confirm_sec,
return_sec=return_sec,
working_hours=working_hours,
)
db.add(roi)
db.commit()
@@ -173,6 +175,7 @@ def update_roi(
threshold_sec: Optional[int] = None,
confirm_sec: Optional[int] = None,
return_sec: Optional[int] = None,
working_hours: Optional[str] = None,
) -> Optional[ROI]:
import json
@@ -198,6 +201,8 @@ def update_roi(
roi.confirm_sec = confirm_sec
if return_sec is not None:
roi.return_sec = return_sec
if working_hours is not None:
roi.working_hours = working_hours
db.commit()
db.refresh(roi)
@@ -232,6 +237,7 @@ def get_roi_points(db: Session, camera_id: int) -> List[dict]:
"threshold_sec": roi.threshold_sec,
"confirm_sec": roi.confirm_sec,
"return_sec": roi.return_sec,
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
}
for roi in rois
]

View File

@@ -90,9 +90,10 @@ class ROI(Base):
direction: Mapped[Optional[str]] = mapped_column(String(32))
stay_time: Mapped[Optional[int]] = mapped_column(Integer)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
threshold_sec: Mapped[int] = mapped_column(Integer, default=360)
confirm_sec: Mapped[int] = mapped_column(Integer, default=30)
return_sec: Mapped[int] = mapped_column(Integer, default=5)
threshold_sec: Mapped[int] = mapped_column(Integer, default=300)
confirm_sec: Mapped[int] = mapped_column(Integer, default=10)
return_sec: Mapped[int] = mapped_column(Integer, default=30)
working_hours: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
pending_sync: Mapped[bool] = mapped_column(Boolean, default=False)
sync_version: Mapped[int] = mapped_column(Integer, default=0)
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)

View File

@@ -38,14 +38,48 @@ const CameraManagement: React.FC = () => {
const fetchCameras = async () => {
setLoading(true);
try {
const [camerasRes, statusRes] = await Promise.all([
const [camerasRes, statusRes, pipelineRes] = await Promise.all([
axios.get('/api/cameras?enabled_only=false'),
axios.get('/api/camera/status/all'),
axios.get('/api/pipeline/status')
]);
setCameras(camerasRes.data);
setCameraStatus(statusRes.data.cameras || {});
const statusMap: Record<number, CameraStatus> = {};
for (const cam of camerasRes.data) {
const camId = cam.id;
let status: CameraStatus = {
is_running: false,
fps: 0,
error_message: null,
last_check_time: null,
};
const pipelineStatus = pipelineRes.data.cameras?.[String(camId)];
if (pipelineStatus) {
status.is_running = pipelineStatus.is_running || false;
status.fps = pipelineStatus.fps || 0;
status.last_check_time = pipelineStatus.last_check_time;
}
const dbStatus = statusRes.data.find((s: any) => s.camera_id === camId);
if (dbStatus) {
if (!status.is_running) {
status.is_running = dbStatus.is_running || false;
}
status.fps = dbStatus.fps || status.fps;
status.error_message = dbStatus.error_message;
status.last_check_time = status.last_check_time || dbStatus.last_check_time;
}
statusMap[camId] = status;
}
setCameraStatus(statusMap);
} catch (err) {
message.error('获取摄像头列表失败');
console.error('获取摄像头状态失败', err);
} finally {
setLoading(false);
}
@@ -59,8 +93,8 @@ const CameraManagement: React.FC = () => {
const extractIP = (url: string): string => {
try {
const match = url.match(/:\/\/([^:]+):?(\d+)?\//);
return match ? match[1] : '未知';
const ipMatch = url.match(/(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})/);
return ipMatch ? ipMatch[1] : '未知';
} catch {
return '未知';
}

View File

@@ -24,6 +24,10 @@ const Dashboard: React.FC = () => {
const [recentAlerts, setRecentAlerts] = useState<Alert[]>([]);
const [cameraStatus, setCameraStatus] = useState<any[]>([]);
const handleViewSnapshot = (alert: Alert) => {
window.open(`/api/alarms/${alert.id}/snapshot`, '_blank');
};
useEffect(() => {
fetchStats();
fetchAlerts();
@@ -60,7 +64,8 @@ const Dashboard: React.FC = () => {
const res = await axios.get('/api/pipeline/status');
const cameras = Object.entries(res.data.cameras || {}).map(([id, info]) => ({
id,
...info as any,
is_running: (info as any).is_running || false,
fps: (info as any).fps || 0,
}));
setCameraStatus(cameras);
} catch (err) {
@@ -139,7 +144,7 @@ const Dashboard: React.FC = () => {
description={formatTime(alert.created_at)}
/>
{alert.snapshot_path && (
<Button type="link" size="small">
<Button type="link" size="small" onClick={() => handleViewSnapshot(alert)}>
</Button>
)}
@@ -159,8 +164,8 @@ const Dashboard: React.FC = () => {
title={`摄像头 ${cam.id}`}
description={
<Space>
<Tag color={cam.running ? 'green' : 'red'}>
{cam.running ? '运行中' : '已停止'}
<Tag color={cam.is_running ? 'green' : 'red'}>
{cam.is_running ? '运行中' : '已停止'}
</Tag>
<span>{cam.fps?.toFixed(1) || 0} FPS</span>
</Space>

View File

@@ -1,7 +1,14 @@
import React, { useEffect, useState, useRef } from 'react';
import { Card, Button, Space, Select, message, Drawer, Form, Input, InputNumber, Switch } from 'antd';
import { Card, Button, Space, Select, message, Drawer, Form, Input, InputNumber, Switch, TimePicker, Divider } from 'antd';
import { Stage, Layer, Rect, Line, Circle, Text as KonvaText } from 'react-konva';
import { RangePickerProps } from 'antd/es/date-picker';
import axios from 'axios';
import dayjs from 'dayjs';
interface WorkingHours {
start: number[];
end: number[];
}
interface ROI {
id: number;
@@ -13,6 +20,7 @@ interface ROI {
threshold_sec: number;
confirm_sec: number;
return_sec: number;
working_hours: WorkingHours[] | null;
}
interface Camera {
@@ -29,12 +37,38 @@ const ROIEditor: React.FC = () => {
const [selectedROI, setSelectedROI] = useState<ROI | null>(null);
const [drawerVisible, setDrawerVisible] = useState(false);
const [form] = Form.useForm();
const [workingHoursList, setWorkingHoursList] = useState<{start: dayjs.Dayjs | null, end: dayjs.Dayjs | null}[]>([]);
const [isDrawing, setIsDrawing] = useState(false);
const [tempPoints, setTempPoints] = useState<number[][]>([]);
const [backgroundImage, setBackgroundImage] = useState<HTMLImageElement | null>(null);
const stageRef = useRef<any>(null);
const addWorkingHours = () => {
setWorkingHoursList([...workingHoursList, { start: null, end: null }]);
};
const removeWorkingHours = (index: number) => {
const newList = workingHoursList.filter((_, i) => i !== index);
setWorkingHoursList(newList);
};
const updateWorkingHours = (index: number, field: 'start' | 'end', value: dayjs.Dayjs | null) => {
const newList = [...workingHoursList];
newList[index] = { ...newList[index], [field]: value };
setWorkingHoursList(newList);
};
const updateWorkingHoursRange = (index: number, start: dayjs.Dayjs | null, end: dayjs.Dayjs | null) => {
setWorkingHoursList(prev => {
const newList = [...prev];
if (newList[index]) {
newList[index] = { start, end };
}
return newList;
});
};
const fetchCameras = async () => {
try {
const res = await axios.get('/api/cameras?enabled_only=true');
@@ -95,16 +129,25 @@ const ROIEditor: React.FC = () => {
const handleSaveROI = async (values: any) => {
if (!selectedCamera || !selectedROI) return;
try {
const workingHours = workingHoursList
.filter(item => item.start && item.end)
.map(item => ({
start: [item.start!.hour(), item.start!.minute()],
end: [item.end!.hour(), item.end!.minute()],
}));
await axios.put(`/api/camera/${selectedCamera}/roi/${selectedROI.id}`, {
name: values.name,
roi_type: values.roi_type,
rule_type: values.rule_type,
threshold_sec: values.threshold_sec,
confirm_sec: values.confirm_sec,
working_hours: workingHours,
enabled: values.enabled,
});
message.success('保存成功');
setDrawerVisible(false);
setWorkingHoursList([]);
fetchROIs();
} catch (err: any) {
message.error(`保存失败: ${err.response?.data?.detail || '未知错误'}`);
@@ -150,6 +193,7 @@ const ROIEditor: React.FC = () => {
threshold_sec: 60,
confirm_sec: 5,
return_sec: 5,
working_hours: [],
})
.then(() => {
message.success('ROI添加成功');
@@ -212,6 +256,10 @@ const ROIEditor: React.FC = () => {
confirm_sec: roi.confirm_sec,
enabled: roi.enabled,
});
setWorkingHoursList(roi.working_hours?.map((wh: WorkingHours) => ({
start: wh.start ? dayjs().hour(wh.start[0]).minute(wh.start[1]) : null,
end: wh.end ? dayjs().hour(wh.end[0]).minute(wh.end[1]) : null,
})) || []);
setDrawerVisible(true);
}}
onMouseEnter={(e) => {
@@ -369,6 +417,10 @@ const ROIEditor: React.FC = () => {
confirm_sec: roi.confirm_sec,
enabled: roi.enabled,
});
setWorkingHoursList(roi.working_hours?.map((wh: WorkingHours) => ({
start: wh.start ? dayjs().hour(wh.start[0]).minute(wh.start[1]) : null,
end: wh.end ? dayjs().hour(wh.end[0]).minute(wh.end[1]) : null,
})) || []);
setDrawerVisible(true);
}}
>
@@ -403,6 +455,7 @@ const ROIEditor: React.FC = () => {
onClose={() => {
setDrawerVisible(false);
setSelectedROI(null);
setWorkingHoursList([]);
}}
width={400}
>
@@ -434,6 +487,39 @@ const ROIEditor: React.FC = () => {
<Form.Item name="confirm_sec" label="确认时间(秒)" rules={[{ required: true }]}>
<InputNumber min={5} style={{ width: '100%' }} />
</Form.Item>
<Divider></Divider>
<div>
{workingHoursList.map((item, index) => (
<Space key={index} align="baseline" style={{ display: 'flex', marginBottom: 8 }}>
<Form.Item label={index === 0 ? '时间段' : ''} style={{ marginBottom: 0 }}>
<TimePicker.RangePicker
format="HH:mm"
value={item.start && item.end ? [item.start, item.end] : null}
onChange={(dates) => {
if (dates && Array.isArray(dates) && dates.length >= 2 && dates[0] && dates[1]) {
updateWorkingHoursRange(index, dates[0], dates[1]);
} else {
updateWorkingHoursRange(index, null, null);
}
}}
/>
</Form.Item>
<Button
type="link"
danger
onClick={() => removeWorkingHours(index)}
>
</Button>
</Space>
))}
<Button type="dashed" onClick={addWorkingHours} block>
</Button>
</div>
<Form.Item style={{ fontSize: 12, color: '#999' }}>
使
</Form.Item>
</>
)}
<Form.Item name="enabled" label="启用状态" valuePropName="checked">
@@ -447,6 +533,7 @@ const ROIEditor: React.FC = () => {
<Button onClick={() => {
setDrawerVisible(false);
setSelectedROI(null);
setWorkingHoursList([]);
}}>
</Button>

View File

@@ -49,6 +49,10 @@ class ONNXEngine:
return img
def postprocess(self, output: np.ndarray, orig_img: np.ndarray) -> List[Results]:
import torch
import numpy as np
from ultralytics.engine.results import Boxes as BoxesObj, Results
c, n = output.shape
output = output.T
@@ -74,6 +78,9 @@ class ONNXEngine:
orig_h, orig_w = orig_img.shape[:2]
scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
if len(indices) == 0:
return [Results(orig_img=orig_img, path="", names={0: "person"})]
filtered_boxes = []
for idx in indices:
if idx >= len(boxes):
@@ -82,30 +89,30 @@ class ONNXEngine:
x1, y1, x2, y2 = box
w, h = x2 - x1, y2 - y1
filtered_boxes.append([
int(x1 * scale_x),
int(y1 * scale_y),
int(w * scale_x),
int(h * scale_y),
float(x1 * scale_x),
float(y1 * scale_y),
float(w * scale_x),
float(h * scale_y),
float(scores[idx]),
int(classes[idx])
])
from ultralytics.engine.results import Boxes as BoxesObj
if filtered_boxes:
box_tensor = torch.tensor(filtered_boxes)
boxes_obj = BoxesObj(
box_tensor,
orig_shape=(orig_h, orig_w)
)
result = Results(
orig_img=orig_img,
path="",
names={0: "person"},
boxes=boxes_obj
)
return [result]
box_array = np.array(filtered_boxes, dtype=np.float32)
else:
box_array = np.zeros((0, 6), dtype=np.float32)
return [Results(orig_img=orig_img, path="", names={0: "person"})]
boxes_obj = BoxesObj(
torch.from_numpy(box_array),
orig_shape=(orig_h, orig_w)
)
result = Results(
orig_img=orig_img,
path="",
names={0: "person"},
boxes=boxes_obj
)
return [result]
def inference(self, images: List[np.ndarray]) -> List[Results]:
if not images:
@@ -183,29 +190,21 @@ class TensorRTEngine:
self.context = self.engine.create_execution_context()
self.stream = torch.cuda.Stream(device=self.device)
self.batch_size = 1
for i in range(self.engine.num_io_tensors):
name = self.engine.get_tensor_name(i)
dtype = self.engine.get_tensor_dtype(name)
shape = list(self.engine.get_tensor_shape(name))
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
if -1 in shape:
shape = [self.batch_size if d == -1 else d for d in shape]
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
self.input_buffer = buffer
self.input_name = name
else:
if -1 in shape:
shape = [self.batch_size if d == -1 else d for d in shape]
if dtype == trt.float16:
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
else:
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
self.output_buffers.append(buffer)
if self.output_name is None:
self.output_name = name
@@ -215,8 +214,6 @@ class TensorRTEngine:
stream_handle = torch.cuda.current_stream(self.device).cuda_stream
self.context.set_optimization_profile_async(0, stream_handle)
self.batch_size = 1
def preprocess(self, frame: np.ndarray) -> torch.Tensor:
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, self.imgsz)
@@ -247,9 +244,6 @@ class TensorRTEngine:
self.input_name, input_tensor.contiguous().data_ptr()
)
input_shape = list(input_tensor.shape)
self.context.set_input_shape(self.input_name, input_shape)
torch.cuda.synchronize(self.stream)
self.context.execute_async_v3(self.stream.cuda_stream)
torch.cuda.synchronize(self.stream)
@@ -336,6 +330,10 @@ class Boxes:
self.orig_shape = orig_shape
self.is_track = is_track
@property
def ndim(self) -> int:
return self.data.ndim
@property
def xyxy(self):
if self.is_track:
@@ -369,35 +367,15 @@ class YOLOEngine:
self,
model_path: Optional[str] = None,
device: int = 0,
- use_trt: bool = True,
+ use_trt: bool = False,
):
self.use_trt = False
self.onnx_engine = None
self.trt_engine = None
self.model = None
self.device = device
config = get_config()
if use_trt:
try:
self.trt_engine = TensorRTEngine(device=device)
self.trt_engine.warmup()
self.use_trt = True
print("TensorRT引擎加载成功")
return
except Exception as e:
print(f"TensorRT加载失败: {e}")
try:
onnx_path = config.model.onnx_path
if os.path.exists(onnx_path):
self.onnx_engine = ONNXEngine(device=device)
self.onnx_engine.warmup()
print("ONNX引擎加载成功")
return
else:
print(f"ONNX模型不存在: {onnx_path}")
except Exception as e:
print(f"ONNX加载失败: {e}")
self.config = config
try:
pt_path = model_path or config.model.pt_model_path
@@ -409,26 +387,17 @@ class YOLOEngine:
raise FileNotFoundError(f"PT文件无效或不存在: {pt_path}")
except Exception as e:
print(f"PyTorch加载失败: {e}")
- raise RuntimeError("所有模型加载方式均失败")
+ raise RuntimeError("无法加载模型")
def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
if self.use_trt and self.trt_engine:
if self.model is not None:
try:
return self.trt_engine.inference_single(frame)
return self.model(frame, imgsz=self.config.model.imgsz, conf=self.config.model.conf_threshold, iou=self.config.model.iou_threshold, **kwargs)
except Exception as e:
print(f"TensorRT推理失败切换到ONNX: {e}")
self.use_trt = False
if self.onnx_engine:
return self.onnx_engine.inference_single(frame)
elif self.model:
return self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
else:
return []
elif self.onnx_engine:
return self.onnx_engine.inference_single(frame)
else:
results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
return results
print(f"PyTorch推理失败: {e}")
print("警告: 模型不可用,返回空结果")
return []
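The rewritten `__call__` above catches inference errors and returns an empty list so one bad frame does not kill a camera thread. A stripped-down sketch of that load-or-raise / infer-or-degrade pattern (`FallbackEngine` and its loader are hypothetical names, not the project's API):

```python
from typing import Any, Callable, List

class FallbackEngine:
    """Fail loudly at construction if the model cannot load, but degrade
    to an empty detection list on per-frame inference failure, keeping
    downstream ROI/alert logic alive."""

    def __init__(self, loader: Callable[[], Callable[[Any], List[Any]]]):
        try:
            self.model = loader()
        except Exception as e:
            # Startup failure is unrecoverable: surface it immediately.
            raise RuntimeError(f"model load failed: {e}")

    def __call__(self, frame) -> List[Any]:
        try:
            return self.model(frame)
        except Exception:
            # Runtime failure is transient: return no detections and
            # let the pipeline try again on the next frame.
            return []

good = FallbackEngine(lambda: (lambda frame: [frame.upper()]))
print(good("person"))
bad = FallbackEngine(lambda: (lambda frame: 1 / 0))
print(bad("person"))  # inference raises, so the engine yields []
```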
def __del__(self):
if self.trt_engine:


@@ -7,6 +7,7 @@ from collections import deque
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
from config import get_config
@@ -186,9 +187,10 @@ class InferencePipeline:
roi_id,
rule_type,
{
- "threshold_sec": roi_config.get("threshold_sec", 360),
- "confirm_sec": roi_config.get("confirm_sec", 30),
- "return_sec": roi_config.get("return_sec", 5),
+ "threshold_sec": roi_config.get("threshold_sec", 300),
+ "confirm_sec": roi_config.get("confirm_sec", 10),
+ "return_sec": roi_config.get("return_sec", 30),
+ "working_hours": roi_config.get("working_hours"),
},
)
@@ -216,22 +218,30 @@ class InferencePipeline:
else:
filtered_detections = detections
roi_detections: Dict[str, List[Dict]] = {}
for detection in filtered_detections:
matched_rois = detection.get("matched_rois", [])
for roi_conf in matched_rois:
roi_id = roi_conf["roi_id"]
rule_type = roi_conf["rule"]
if roi_id not in roi_detections:
roi_detections[roi_id] = []
roi_detections[roi_id].append(detection)
alerts = self.algo_manager.process(
roi_id,
str(camera_id),
rule_type,
[detection],
datetime.now(),
)
for roi_config in roi_configs:
roi_id = roi_config["roi_id"]
rule_type = roi_config["rule"]
roi_dets = roi_detections.get(roi_id, [])
for alert in alerts:
self._handle_alert(camera_id, alert, frame, roi_conf)
alerts = self.algo_manager.process(
roi_id,
str(camera_id),
rule_type,
roi_dets,
datetime.now(),
)
for alert in alerts:
self._handle_alert(camera_id, alert, frame, roi_config)
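The rewritten loop above first buckets detections by ROI, then invokes the algorithm once per configured ROI, so an ROI with zero detections still gets a `process()` call (which the leave-post state machine needs in order to start its countdown). A minimal sketch of that grouping, with hypothetical dict shapes:

```python
from collections import defaultdict

def group_by_roi(detections, roi_configs):
    """Bucket detections by roi_id, then yield (roi_config, detections)
    for EVERY configured ROI. An empty list means 'nobody in this ROI',
    which must still reach the per-ROI algorithm."""
    buckets = defaultdict(list)
    for det in detections:
        for matched in det.get("matched_rois", []):
            buckets[matched["roi_id"]].append(det)
    for cfg in roi_configs:
        yield cfg, buckets.get(cfg["roi_id"], [])

dets = [{"bbox": [0, 0, 10, 10], "matched_rois": [{"roi_id": "roi_1"}]}]
cfgs = [{"roi_id": "roi_1", "rule": "leave_post"},
        {"roi_id": "roi_2", "rule": "intrusion"}]
for cfg, roi_dets in group_by_roi(dets, cfgs):
    print(cfg["roi_id"], len(roi_dets))  # roi_2 is reported even when empty
```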
def _handle_alert(
self,
@@ -322,20 +332,23 @@ class InferencePipeline:
print("推理pipeline已停止")
def get_status(self) -> Dict[str, Any]:
return {
result = {
"running": self.running,
"camera_count": len(self.camera_threads),
"cameras": {
cid: {
"running": self.camera_stop_events[cid] is not None and not self.camera_stop_events[cid].is_set(),
"fps": self.get_camera_fps(cid),
"frame_time": self.camera_frame_times.get(cid).isoformat() if self.camera_frame_times.get(cid) else None,
}
for cid in self.camera_threads
},
"cameras": {},
"event_queue_size": len(self.event_queue),
}
for cid in self.camera_threads:
frame_time = self.camera_frame_times.get(cid)
result["cameras"][str(cid)] = {
"is_running": self.camera_stop_events[cid] is not None and not self.camera_stop_events[cid].is_set(),
"fps": self.get_camera_fps(cid),
"last_check_time": frame_time.isoformat() if frame_time else None,
}
return result
_pipeline: Optional[InferencePipeline] = None


@@ -0,0 +1,167 @@
import json
import threading
import time
from typing import Dict, List, Optional, Callable
from collections import deque
class ROICacheManager:
_instance = None
_lock = threading.Lock()
def __new__(cls):
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._initialized = True
self._cache: Dict[int, List[Dict]] = {}
self._cache_timestamps: Dict[int, float] = {}
self._refresh_interval = 10.0
self._db_session_factory = None
self._refresh_thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._last_refresh_time = 0
self._on_cache_update: Optional[Callable[[int], None]] = None
self._update_callbacks: Dict[int, List[Callable]] = {}
def initialize(self, session_factory, refresh_interval: float = 10.0):
self._db_session_factory = session_factory
self._refresh_interval = refresh_interval
def start_background_refresh(self):
if self._refresh_thread is not None and self._refresh_thread.is_alive():
return
self._stop_event.clear()
self._refresh_thread = threading.Thread(target=self._background_refresh_loop, daemon=True)
self._refresh_thread.start()
def stop_background_refresh(self):
self._stop_event.set()
if self._refresh_thread is not None:
self._refresh_thread.join(timeout=2)
self._refresh_thread = None
def _background_refresh_loop(self):
while not self._stop_event.is_set():
try:
self.refresh_all()
except Exception:
pass
self._stop_event.wait(self._refresh_interval)
def _load_rois_from_db(self, camera_id: int) -> List[Dict]:
if self._db_session_factory is None:
return []
session = self._db_session_factory()
try:
from db.crud import get_all_rois
rois = get_all_rois(session, camera_id)
roi_configs = []
for roi in rois:
try:
points = json.loads(roi.points) if isinstance(roi.points, str) else roi.points
except (json.JSONDecodeError, TypeError):
points = []
roi_config = {
"id": roi.id,
"roi_id": roi.roi_id,
"name": roi.name,
"type": roi.roi_type,
"points": points,
"rule": roi.rule_type,
"direction": roi.direction,
"enabled": roi.enabled,
"threshold_sec": roi.threshold_sec,
"confirm_sec": roi.confirm_sec,
"return_sec": roi.return_sec,
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
}
roi_configs.append(roi_config)
return roi_configs
finally:
session.close()
def refresh_all(self):
if self._db_session_factory is None:
return
current_time = time.time()
if current_time - self._last_refresh_time < 1.0:
return
self._last_refresh_time = current_time
camera_ids = list(self._cache.keys())
for camera_id in camera_ids:
try:
new_rois = self._load_rois_from_db(camera_id)
old_rois_str = str(self._cache.get(camera_id, []))
new_rois_str = str(new_rois)
if old_rois_str != new_rois_str:
self._cache[camera_id] = new_rois
self._cache_timestamps[camera_id] = current_time
self._notify_update(camera_id)
except Exception:
pass
def get_rois(self, camera_id: int, force_refresh: bool = False) -> List[Dict]:
if force_refresh or camera_id not in self._cache:
self._cache[camera_id] = self._load_rois_from_db(camera_id)
self._cache_timestamps[camera_id] = time.time()
return self._cache.get(camera_id, [])
def get_rois_by_rule(self, camera_id: int, rule_type: str) -> List[Dict]:
rois = self.get_rois(camera_id)
return [roi for roi in rois if roi.get("rule") == rule_type and roi.get("enabled", True)]
def invalidate(self, camera_id: Optional[int] = None):
if camera_id is None:
self._cache.clear()
self._cache_timestamps.clear()
elif camera_id in self._cache:
del self._cache[camera_id]
if camera_id in self._cache_timestamps:
del self._cache_timestamps[camera_id]
def register_update_callback(self, camera_id: int, callback: Callable):
if camera_id not in self._update_callbacks:
self._update_callbacks[camera_id] = []
self._update_callbacks[camera_id].append(callback)
def _notify_update(self, camera_id: int):
if camera_id in self._update_callbacks:
for callback in self._update_callbacks[camera_id]:
try:
callback(camera_id)
except Exception:
pass
def get_cache_info(self) -> Dict:
return {
"camera_count": len(self._cache),
"refresh_interval": self._refresh_interval,
"cameras": {
cam_id: {
"roi_count": len(rois),
"last_update": self._cache_timestamps.get(cam_id, 0),
}
for cam_id, rois in self._cache.items()
},
}
def get_roi_cache() -> ROICacheManager:
return ROICacheManager()
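`ROICacheManager` uses double-checked locking in `__new__` so every caller of `get_roi_cache()` shares a single cache across threads. A condensed sketch of that singleton pattern (generic `Singleton` name; the `_initialized` flag mirrors the class above):

```python
import threading

class Singleton:
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # First check avoids the lock on the hot path; the second check
        # inside the lock stops two threads that raced past the first.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every construction call, so guard one-time
        # setup behind the _initialized flag set in __new__.
        if self._initialized:
            return
        self._initialized = True
        self.cache = {}

a, b = Singleton(), Singleton()
print(a is b)  # both names refer to the same instance
```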


@@ -1,5 +1,7 @@
import os
import sys
import time
from collections import deque
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
@@ -8,15 +10,18 @@ import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sort import Sort
class LeavePostAlgorithm:
STATE_ON_DUTY = "ON_DUTY"
STATE_OFF_DUTY_COUNTDOWN = "OFF_DUTY_COUNTDOWN"
STATE_NON_WORK_TIME = "NON_WORK_TIME"
STATE_INIT = "INIT"
def __init__(
self,
- threshold_sec: int = 360,
- confirm_sec: int = 30,
- return_sec: int = 5,
+ threshold_sec: int = 300,
+ confirm_sec: int = 10,
+ return_sec: int = 30,
+ working_hours: Optional[List[Dict]] = None,
):
self.threshold_sec = threshold_sec
@@ -24,12 +29,17 @@ class LeavePostAlgorithm:
self.return_sec = return_sec
self.working_hours = working_hours or []
self.track_states: Dict[str, Dict[str, Any]] = {}
self.tracker = Sort(max_age=10, min_hits=2, iou_threshold=0.3)
self.alert_cooldowns: Dict[str, datetime] = {}
self.cooldown_seconds = 300
self.state: str = self.STATE_INIT
self.state_start_time: Optional[datetime] = None
self.on_duty_window = deque()
self.alarm_sent: bool = False
self.last_person_seen_time: Optional[datetime] = None
self.last_detection_time: Optional[datetime] = None
self.init_start_time: Optional[datetime] = None
def is_in_working_hours(self, dt: Optional[datetime] = None) -> bool:
if not self.working_hours:
return True
@@ -45,159 +55,199 @@ class LeavePostAlgorithm:
return False
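Per the `working_hours` fixes in this commit range, an empty list means "active all day" and must not be coerced to `None`. A sketch of the time-range check, assuming entries shaped like `{"start": "09:00", "end": "18:00"}` (zero-padded `HH:MM` strings compare correctly as plain strings):

```python
from datetime import datetime

def in_working_hours(working_hours, dt=None):
    """Empty list (or None) means the rule runs all day; otherwise the
    current HH:MM must fall inside at least one [start, end] range."""
    if not working_hours:
        return True
    now = (dt or datetime.now()).strftime("%H:%M")
    for span in working_hours:
        # Lexicographic comparison is valid for zero-padded HH:MM.
        if span["start"] <= now <= span["end"]:
            return True
    return False

noon = datetime(2026, 1, 22, 12, 0)
print(in_working_hours([], noon))                                     # all-day mode
print(in_working_hours([{"start": "09:00", "end": "18:00"}], noon))   # inside range
print(in_working_hours([{"start": "19:00", "end": "23:00"}], noon))   # outside range
```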
def check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
matched_rois = detection.get("matched_rois", [])
for roi in matched_rois:
if roi.get("roi_id") == roi_id:
return True
return False
def process(
self,
roi_id: str,
camera_id: str,
tracks: List[Dict],
current_time: Optional[datetime] = None,
) -> List[Dict]:
if not self.is_in_working_hours(current_time):
return []
if not tracks:
return []
detections = []
for track in tracks:
bbox = track.get("bbox", [])
if len(bbox) >= 4:
detections.append(bbox + [track.get("conf", 0.0)])
if not detections:
return []
detections = np.array(detections)
tracked = self.tracker.update(detections)
alerts = []
current_time = current_time or datetime.now()
for track_data in tracked:
x1, y1, x2, y2, track_id = track_data
track_id = str(int(track_id))
roi_has_person = False
for det in tracks:
if self.check_detection_in_roi(det, roi_id):
roi_has_person = True
break
if track_id not in self.track_states:
self.track_states[track_id] = {
"first_seen": current_time,
"last_seen": current_time,
"off_duty_start": None,
"alerted": False,
"last_position": (x1, y1, x2, y2),
}
in_work = self.is_in_working_hours(current_time)
alerts = []
state = self.track_states[track_id]
state["last_seen"] = current_time
state["last_position"] = (x1, y1, x2, y2)
if not in_work:
self.state = self.STATE_NON_WORK_TIME
self.last_person_seen_time = None
self.last_detection_time = None
self.on_duty_window.clear()
self.alarm_sent = False
self.init_start_time = None
else:
if self.state == self.STATE_NON_WORK_TIME:
self.state = self.STATE_INIT
self.init_start_time = current_time
self.on_duty_window.clear()
self.alarm_sent = False
if state["off_duty_start"] is None:
off_duty_duration = (current_time - state["first_seen"]).total_seconds()
if off_duty_duration > self.confirm_sec:
state["off_duty_start"] = current_time
else:
elapsed = (current_time - state["off_duty_start"]).total_seconds()
if elapsed > self.threshold_sec:
if not state["alerted"]:
cooldown_key = f"{camera_id}_{track_id}"
now = datetime.now()
if self.state == self.STATE_INIT:
if roi_has_person:
self.state = self.STATE_ON_DUTY
self.state_start_time = current_time
self.on_duty_window.clear()
self.on_duty_window.append((current_time, True))
self.last_person_seen_time = current_time
self.last_detection_time = current_time
self.init_start_time = None
else:
if self.init_start_time is None:
self.init_start_time = current_time
elapsed_since_init = (current_time - self.init_start_time).total_seconds()
if elapsed_since_init >= self.threshold_sec:
self.state = self.STATE_OFF_DUTY_COUNTDOWN
self.state_start_time = current_time
self.alarm_sent = False
elif self.state == self.STATE_ON_DUTY:
if roi_has_person:
self.last_person_seen_time = current_time
self.last_detection_time = current_time
self.on_duty_window.append((current_time, True))
while self.on_duty_window and (current_time - self.on_duty_window[0][0]).total_seconds() > self.confirm_sec:
self.on_duty_window.popleft()
else:
self.on_duty_window.append((current_time, False))
while self.on_duty_window and (current_time - self.on_duty_window[0][0]).total_seconds() > self.confirm_sec:
self.on_duty_window.popleft()
hit_ratio = sum(1 for t, detected in self.on_duty_window if detected) / max(len(self.on_duty_window), 1)
if hit_ratio == 0:
self.state = self.STATE_OFF_DUTY_COUNTDOWN
self.state_start_time = current_time
self.alarm_sent = False
elif self.state == self.STATE_OFF_DUTY_COUNTDOWN:
elapsed = (current_time - self.state_start_time).total_seconds()
if roi_has_person:
self.state = self.STATE_ON_DUTY
self.state_start_time = current_time
self.on_duty_window.clear()
self.on_duty_window.append((current_time, True))
self.last_person_seen_time = current_time
self.alarm_sent = False
elif elapsed >= self.threshold_sec:
if not self.alarm_sent:
cooldown_key = f"{roi_id}"
if cooldown_key not in self.alert_cooldowns or (
- now - self.alert_cooldowns[cooldown_key]
+ current_time - self.alert_cooldowns[cooldown_key]
).total_seconds() > self.cooldown_seconds:
bbox = self.get_latest_bbox_in_roi(tracks, roi_id)
elapsed_minutes = int(elapsed / 60)
alerts.append({
- "track_id": track_id,
- "bbox": [x1, y1, x2, y2],
+ "track_id": roi_id,
+ "bbox": bbox,
"off_duty_duration": elapsed,
"alert_type": "leave_post",
- "message": f"离岗超过 {int(elapsed / 60)} 分钟",
+ "message": f"离岗超过 {elapsed_minutes} 分钟",
})
state["alerted"] = True
self.alert_cooldowns[cooldown_key] = now
else:
if elapsed < self.return_sec:
state["off_duty_start"] = None
state["alerted"] = False
cleanup_time = current_time - timedelta(minutes=5)
for track_id, state in list(self.track_states.items()):
if state["last_seen"] < cleanup_time:
del self.track_states[track_id]
self.alarm_sent = True
self.alert_cooldowns[cooldown_key] = current_time
return alerts
def get_latest_bbox_in_roi(self, tracks: List[Dict], roi_id: str) -> List[float]:
for det in tracks:
if self.check_detection_in_roi(det, roi_id):
return det.get("bbox", [])
return []
def reset(self):
self.track_states.clear()
self.state = self.STATE_INIT
self.state_start_time = None
self.on_duty_window.clear()
self.alarm_sent = False
self.last_person_seen_time = None
self.last_detection_time = None
self.init_start_time = None
self.alert_cooldowns.clear()
def get_state(self, track_id: str) -> Optional[Dict[str, Any]]:
- return self.track_states.get(track_id)
+ return {
+ "state": self.state,
+ "alarm_sent": self.alarm_sent,
+ "last_person_seen_time": self.last_person_seen_time,
+ }
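The per-track state dicts above are replaced by one ROI-level machine: `INIT` (wait for a first confirmed presence), `ON_DUTY`, and `OFF_DUTY_COUNTDOWN`, alerting once per absence that outlasts `threshold_sec`. A compressed sketch of those transitions, driven by `(timestamp, person_present)` samples; the `INIT` confirmation window, working-hours gating, and cooldowns are omitted for brevity, and all names here are hypothetical:

```python
from datetime import datetime, timedelta

class LeavePostSketch:
    """ROI-level leave-post state machine: alert once when the ROI stays
    empty longer than threshold_sec; any return resets the countdown."""

    def __init__(self, threshold_sec=300):
        self.threshold_sec = threshold_sec
        self.state = "INIT"
        self.state_start = None
        self.alarm_sent = False

    def step(self, now, person_present):
        if person_present:
            # Presence always returns to ON_DUTY and re-arms the alarm.
            self.state, self.state_start, self.alarm_sent = "ON_DUTY", now, False
            return None
        if self.state == "ON_DUTY" or self.state_start is None:
            # Absence just began: start the countdown.
            self.state, self.state_start = "OFF_DUTY_COUNTDOWN", now
            return None
        if self.state == "OFF_DUTY_COUNTDOWN" and not self.alarm_sent:
            if (now - self.state_start).total_seconds() >= self.threshold_sec:
                self.alarm_sent = True  # fire once per absence episode
                return {"alert_type": "leave_post"}
        return None

t0 = datetime(2026, 1, 22, 10, 0, 0)
sm = LeavePostSketch(threshold_sec=300)
sm.step(t0, True)                                   # guard present
print(sm.step(t0 + timedelta(seconds=60), False))   # countdown starts, no alert yet
print(sm.step(t0 + timedelta(seconds=400), False))  # 340 s absent: alert fires
```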
class IntrusionAlgorithm:
def __init__(
self,
check_interval_sec: float = 1.0,
direction_sensitive: bool = False,
):
self.check_interval_sec = check_interval_sec
self.direction_sensitive = direction_sensitive
def __init__(self, cooldown_seconds: int = 300):
self.cooldown_seconds = cooldown_seconds
self.last_alert_time: Dict[str, float] = {}
self.alert_triggered: Dict[str, bool] = {}
self.last_check_times: Dict[str, float] = {}
self.tracker = Sort(max_age=5, min_hits=1, iou_threshold=0.3)
def is_roi_has_person(self, tracks: List[Dict], roi_id: str) -> bool:
for det in tracks:
matched_rois = det.get("matched_rois", [])
for roi in matched_rois:
if roi.get("roi_id") == roi_id:
return True
return False
self.alert_cooldowns: Dict[str, datetime] = {}
self.cooldown_seconds = 300
def get_latest_bbox_in_roi(self, tracks: List[Dict], roi_id: str) -> List[float]:
for det in tracks:
matched_rois = det.get("matched_rois", [])
for roi in matched_rois:
if roi.get("roi_id") == roi_id:
return det.get("bbox", [])
return []
def process(
self,
roi_id: str,
camera_id: str,
tracks: List[Dict],
current_time: Optional[datetime] = None,
) -> List[Dict]:
if not tracks:
roi_has_person = self.is_roi_has_person(tracks, roi_id)
if not roi_has_person:
return []
detections = []
for track in tracks:
bbox = track.get("bbox", [])
if len(bbox) >= 4:
detections.append(bbox + [track.get("conf", 0.0)])
now = time.monotonic()
key = f"{camera_id}_{roi_id}"
if not detections:
if key not in self.last_alert_time:
self.last_alert_time[key] = 0
self.alert_triggered[key] = False
if now - self.last_alert_time[key] >= self.cooldown_seconds:
self.last_alert_time[key] = now
self.alert_triggered[key] = False
if self.alert_triggered[key]:
return []
current_ts = current_time.timestamp() if current_time else datetime.now().timestamp()
bbox = self.get_latest_bbox_in_roi(tracks, roi_id)
self.alert_triggered[key] = True
if camera_id in self.last_check_times:
if current_ts - self.last_check_times[camera_id] < self.check_interval_sec:
return []
self.last_check_times[camera_id] = current_ts
detections = np.array(detections)
tracked = self.tracker.update(detections)
alerts = []
now = datetime.now()
for track_data in tracked:
x1, y1, x2, y2, track_id = track_data
cooldown_key = f"{camera_id}_{int(track_id)}"
if cooldown_key not in self.alert_cooldowns or (
now - self.alert_cooldowns[cooldown_key]
).total_seconds() > self.cooldown_seconds:
alerts.append({
"track_id": str(int(track_id)),
"bbox": [x1, y1, x2, y2],
"alert_type": "intrusion",
"confidence": track_data[4] if len(track_data) > 4 else 0.0,
"message": "检测到周界入侵",
})
self.alert_cooldowns[cooldown_key] = now
return alerts
return [{
"roi_id": roi_id,
"bbox": bbox,
"alert_type": "intrusion",
"message": "检测到周界入侵",
}]
def reset(self):
self.last_check_times.clear()
self.alert_cooldowns.clear()
self.last_alert_time.clear()
self.alert_triggered.clear()
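`IntrusionAlgorithm` is simplified here from per-track SORT cooldowns to one monotonic-clock cooldown per (camera, ROI) key: the first person in the ROI triggers an alert, and repeats are suppressed until `cooldown_seconds` elapse. A sketch of that gate with an injectable clock so it is testable without sleeping (`CooldownGate` is a hypothetical name):

```python
import time

class CooldownGate:
    """Allow one event per key, then suppress repeats for cooldown_seconds.
    The clock is injected (defaulting to time.monotonic) so tests can
    advance time without sleeping."""

    def __init__(self, cooldown_seconds=300, clock=None):
        self.cooldown_seconds = cooldown_seconds
        self.clock = clock or time.monotonic
        self.last_fired = {}

    def fire(self, key):
        now = self.clock()
        last = self.last_fired.get(key)
        if last is not None and now - last < self.cooldown_seconds:
            return False  # still cooling down: suppress the repeat alert
        self.last_fired[key] = now
        return True

fake_now = [0.0]
gate = CooldownGate(cooldown_seconds=300, clock=lambda: fake_now[0])
print(gate.fire("cam1_roi1"))   # first intrusion alerts
print(gate.fire("cam1_roi1"))   # suppressed during cooldown
fake_now[0] = 301.0
print(gate.fire("cam1_roi1"))   # cooldown expired, alerts again
```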
class AlgorithmManager:
@@ -207,13 +257,12 @@ class AlgorithmManager:
self.default_params = {
"leave_post": {
- "threshold_sec": 360,
- "confirm_sec": 30,
- "return_sec": 5,
+ "threshold_sec": 300,
+ "confirm_sec": 10,
+ "return_sec": 30,
},
"intrusion": {
- "check_interval_sec": 1.0,
- "direction_sensitive": False,
+ "cooldown_seconds": 300,
},
}
@@ -235,16 +284,16 @@ class AlgorithmManager:
algo_params.update(params)
if algorithm_type == "leave_post":
roi_working_hours = algo_params.get("working_hours") or self.working_hours
self.algorithms[roi_id]["leave_post"] = LeavePostAlgorithm(
- threshold_sec=algo_params.get("threshold_sec", 360),
- confirm_sec=algo_params.get("confirm_sec", 30),
- return_sec=algo_params.get("return_sec", 5),
- working_hours=self.working_hours,
+ threshold_sec=algo_params.get("threshold_sec", 300),
+ confirm_sec=algo_params.get("confirm_sec", 10),
+ return_sec=algo_params.get("return_sec", 30),
+ working_hours=roi_working_hours,
)
elif algorithm_type == "intrusion":
self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm(
- check_interval_sec=algo_params.get("check_interval_sec", 1.0),
- direction_sensitive=algo_params.get("direction_sensitive", False),
+ cooldown_seconds=algo_params.get("cooldown_seconds", 300),
)
def process(
@@ -258,7 +307,7 @@ class AlgorithmManager:
algo = self.algorithms.get(roi_id, {}).get(algorithm_type)
if algo is None:
return []
- return algo.process(camera_id, tracks, current_time)
+ return algo.process(roi_id, camera_id, tracks, current_time)
def update_roi_params(
self,
@@ -297,7 +346,13 @@ class AlgorithmManager:
status = {}
if roi_id in self.algorithms:
for algo_type, algo in self.algorithms[roi_id].items():
status[algo_type] = {
"track_count": len(getattr(algo, "track_states", {})),
}
if algo_type == "leave_post":
status[algo_type] = {
"state": getattr(algo, "state", "INIT_STATE"),
"alarm_sent": getattr(algo, "alarm_sent", False),
}
else:
status[algo_type] = {
"track_count": len(getattr(algo, "track_states", {})),
}
return status


@@ -133,3 +133,109 @@
2026-01-21 13:18:55,795 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 13:18:55,809 - security_monitor - INFO - 数据库初始化完成
2026-01-21 13:19:08,492 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:01:21,015 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:01:21,257 - security_monitor - INFO - 系统已关闭
2026-01-21 14:03:48,547 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:03:48,563 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:04:01,197 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:04:20,191 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:04:20,414 - security_monitor - INFO - 系统已关闭
2026-01-21 14:05:48,342 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:05:48,355 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:06:00,984 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:07:24,065 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:07:24,222 - security_monitor - INFO - 系统已关闭
2026-01-21 14:08:10,073 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:08:10,088 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:08:22,715 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:09:05,249 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:09:05,480 - security_monitor - INFO - 系统已关闭
2026-01-21 14:11:29,491 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:11:29,513 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:11:42,900 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:14:04,974 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:14:05,161 - security_monitor - INFO - 系统已关闭
2026-01-21 14:14:41,203 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:14:41,220 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:14:54,380 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:15:30,975 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:15:31,180 - security_monitor - INFO - 系统已关闭
2026-01-21 14:16:24,472 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:16:24,485 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:16:37,611 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:17:01,178 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:17:01,420 - security_monitor - INFO - 系统已关闭
2026-01-21 14:18:00,008 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:18:00,022 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:18:13,126 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:18:13,128 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:18:21,683 - security_monitor - INFO - 系统已关闭
2026-01-21 14:20:04,985 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:20:04,999 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:20:18,151 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:21:24,782 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:21:24,927 - security_monitor - INFO - 系统已关闭
2026-01-21 14:22:48,064 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:22:48,078 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:23:01,270 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:23:13,509 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:23:13,628 - security_monitor - INFO - 系统已关闭
2026-01-21 14:24:16,374 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:24:16,386 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:24:29,425 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:24:42,751 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:24:42,846 - security_monitor - INFO - 系统已关闭
2026-01-21 14:25:25,549 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:25:25,562 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:25:38,636 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:26:02,871 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:26:03,124 - security_monitor - INFO - 系统已关闭
2026-01-21 14:26:45,885 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:26:45,899 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:26:59,042 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:27:26,873 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:27:26,980 - security_monitor - INFO - 系统已关闭
2026-01-21 14:31:38,376 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:31:38,390 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:31:51,594 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:32:17,471 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:32:17,536 - security_monitor - INFO - 系统已关闭
2026-01-21 14:32:53,841 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:32:53,855 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:33:06,946 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:34:30,645 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:34:30,818 - security_monitor - INFO - 系统已关闭
2026-01-21 14:38:24,673 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:38:24,685 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:38:37,183 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:39:04,359 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:39:04,486 - security_monitor - INFO - 系统已关闭
2026-01-21 14:40:07,246 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:40:07,259 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:40:19,745 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 14:40:33,742 - security_monitor - INFO - 正在关闭系统...
2026-01-21 14:40:33,863 - security_monitor - INFO - 系统已关闭
2026-01-21 14:41:27,191 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 14:41:27,205 - security_monitor - INFO - 数据库初始化完成
2026-01-21 14:41:39,701 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 15:03:14,674 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 15:03:14,688 - security_monitor - INFO - 数据库初始化完成
2026-01-21 15:03:27,230 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 15:06:28,976 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 15:06:28,990 - security_monitor - INFO - 数据库初始化完成
2026-01-21 15:06:41,537 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 15:06:41,539 - security_monitor - INFO - 正在关闭系统...
2026-01-21 15:06:49,686 - security_monitor - INFO - 系统已关闭
2026-01-21 15:07:27,870 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 15:07:27,884 - security_monitor - INFO - 数据库初始化完成
2026-01-21 15:07:40,380 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 15:07:58,160 - security_monitor - INFO - 正在关闭系统...
2026-01-21 15:07:58,299 - security_monitor - INFO - 系统已关闭
2026-01-21 15:08:28,521 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 15:08:28,533 - security_monitor - INFO - 数据库初始化完成
2026-01-21 15:08:41,019 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2
2026-01-21 15:09:16,894 - security_monitor - INFO - 正在关闭系统...
2026-01-21 15:09:17,139 - security_monitor - INFO - 系统已关闭
2026-01-21 15:09:41,042 - security_monitor - INFO - 启动安保异常行为识别系统
2026-01-21 15:09:41,055 - security_monitor - INFO - 数据库初始化完成
2026-01-21 15:09:53,555 - security_monitor - INFO - 推理Pipeline启动活跃摄像头数: 2

main.py (35 lines changed)

@@ -11,15 +11,28 @@ os.environ["TENSORRT_DISABLE_MYELIN"] = "1"
import cv2
import numpy as np
- from fastapi import FastAPI, HTTPException
+ from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, StreamingResponse
from sqlalchemy.orm import Session
from prometheus_client import start_http_server
from db.models import get_db
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from ultralytics.engine.results import Boxes as UltralyticsBoxes
def _patch_boxes_ndim():
if not hasattr(UltralyticsBoxes, 'ndim'):
@property
def ndim(self):
return self.data.ndim
UltralyticsBoxes.ndim = ndim
_patch_boxes_ndim()
from api.alarm import router as alarm_router
- from api.camera import router as camera_router
+ from api.camera import router as camera_router, router2 as camera_status_router
from api.roi import router as roi_router
from api.sync import router as sync_router
from config import get_config, load_config
@@ -82,6 +95,7 @@ app.add_middleware(
)
app.include_router(camera_router)
app.include_router(camera_status_router)
app.include_router(roi_router)
app.include_router(alarm_router)
app.include_router(sync_router)
@@ -134,6 +148,23 @@ async def get_snapshot_base64(camera_id: int):
return {"image": base64.b64encode(buffer).decode("utf-8")}
@app.get("/api/alarms/{alarm_id}/snapshot")
async def get_alarm_snapshot(alarm_id: int, db: Session = Depends(get_db)):
from db.models import Alarm
alarm = db.query(Alarm).filter(Alarm.id == alarm_id).first()
if not alarm:
raise HTTPException(status_code=404, detail="告警不存在")
if not alarm.snapshot_path:
raise HTTPException(status_code=404, detail="该告警没有截图")
if not os.path.exists(alarm.snapshot_path):
raise HTTPException(status_code=404, detail="截图文件不存在")
return FileResponse(alarm.snapshot_path, media_type="image/jpeg")
@app.get("/api/camera/{camera_id}/detect")
async def get_detection_with_overlay(camera_id: int):
pipeline = get_pipeline()

migrate_db.py (new file, 71 lines)

@@ -0,0 +1,71 @@
import sqlite3
db_path = 'security_monitor.db'
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
def add_column(table_name, col_name, col_type, default_value=None):
try:
# Use "is not None" so falsy defaults such as '0' are still applied
if default_value is not None:
cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {col_name} {col_type} DEFAULT {default_value}')
else:
cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {col_name} {col_type}')
print(f'Added column {table_name}.{col_name}')
return True
except sqlite3.OperationalError as e:
if 'duplicate column name' in str(e):
print(f'{table_name}.{col_name} already exists')
return True
else:
print(f'Failed to add column {table_name}.{col_name}: {e}')
return False
print('=== Database migration script ===')
print()
# cameras table
print('Updating cameras table:')
add_column('cameras', 'cloud_id', 'INTEGER')
add_column('cameras', 'pending_sync', 'BOOLEAN', '0')
add_column('cameras', 'sync_failed_at', 'TIMESTAMP')
add_column('cameras', 'sync_retry_count', 'INTEGER', '0')
print()
# rois table
print('Updating rois table:')
add_column('rois', 'cloud_id', 'INTEGER')
add_column('rois', 'pending_sync', 'BOOLEAN', '0')
add_column('rois', 'sync_failed_at', 'TIMESTAMP')
add_column('rois', 'sync_retry_count', 'INTEGER', '0')
add_column('rois', 'sync_version', 'INTEGER', '0')
print()
# alarms table
print('Updating alarms table:')
add_column('alarms', 'cloud_id', 'INTEGER')
add_column('alarms', 'upload_status', "TEXT", "'pending_upload'")
add_column('alarms', 'upload_retry_count', 'INTEGER', '0')
add_column('alarms', 'error_message', 'TEXT')
add_column('alarms', 'region_data', 'TEXT')
add_column('alarms', 'llm_checked', 'BOOLEAN', '0')
add_column('alarms', 'llm_result', 'TEXT')
add_column('alarms', 'processed', 'BOOLEAN', '0')
print()
# camera_status table
print('Updating camera_status table:')
add_column('camera_status', 'last_frame_time', 'TIMESTAMP')
print()
conn.commit()
# Verify table structure
print('=== Verifying table structure ===')
for table in ['cameras', 'rois', 'alarms', 'camera_status']:
cursor.execute(f'PRAGMA table_info({table})')
cols = [col[1] for col in cursor.fetchall()]
print(f'{table}: {len(cols)} columns')
conn.close()
print()
print('Database migration complete!')
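The script above relies on catching the `duplicate column name` error from SQLite; an equivalent pre-check can inspect `PRAGMA table_info` before altering. A minimal sketch (the `has_column` helper name is illustrative):

```python
import sqlite3

def has_column(conn, table, col):
    # PRAGMA table_info yields one row per column; index 1 is the column name
    rows = conn.execute(f'PRAGMA table_info({table})').fetchall()
    return any(row[1] == col for row in rows)

conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE rois (id INTEGER PRIMARY KEY)')
print(has_column(conn, 'rois', 'id'))             # True
print(has_column(conn, 'rois', 'working_hours'))  # False
```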


monitor.py (1137 lines)

File diff suppressed because it is too large.


services/sync_service.py (new file, 461 lines)

@@ -0,0 +1,461 @@
"""
云端同步服务
实现"云端为主、本地为辅"的双层数据存储架构:
- 配置双向同步
- 报警单向上报
- 设备状态上报
- 断网容错机制
"""
import os
import sys
import time
import threading
import logging
from datetime import datetime
from typing import Optional, List, Dict, Any
from queue import Queue, Empty
from dataclasses import dataclass
from enum import Enum
import requests
from sqlalchemy.orm import Session
# Add the project root to sys.path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from config import get_config
logger = logging.getLogger(__name__)
class SyncStatus(Enum):
"""同步状态"""
PENDING = "pending"
SYNCING = "syncing"
SUCCESS = "success"
FAILED = "failed"
RETRY = "retry"
class EntityType(Enum):
"""实体类型"""
CAMERA = "camera"
ROI = "roi"
ALARM = "alarm"
STATUS = "status"
@dataclass
class SyncTask:
"""同步任务"""
entity_type: EntityType
entity_id: int
operation: str # create, update, delete
data: Optional[Dict[str, Any]] = None
status: SyncStatus = SyncStatus.PENDING
retry_count: int = 0
error_message: Optional[str] = None
created_at: Optional[datetime] = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow()
class CloudAPIClient:
"""云端 API 客户端"""
def __init__(self, base_url: str, api_key: str, device_id: str):
self.base_url = base_url.rstrip('/')
self.api_key = api_key
self.device_id = device_id
self.session = requests.Session()
self.session.headers.update({
'Authorization': f'Bearer {api_key}',
'Content-Type': 'application/json',
'X-Device-ID': device_id
})
def request(self, method: str, path: str, **kwargs) -> requests.Response:
"""发送 API 请求"""
url = f"{self.base_url}{path}"
response = self.session.request(method, url, **kwargs)
response.raise_for_status()
return response
def get(self, path: str, **kwargs):
return self.request('GET', path, **kwargs)
def post(self, path: str, **kwargs):
return self.request('POST', path, **kwargs)
def put(self, path: str, **kwargs):
return self.request('PUT', path, **kwargs)
def delete(self, path: str, **kwargs):
return self.request('DELETE', path, **kwargs)
class SyncService:
"""云端同步服务"""
def __init__(self):
config = get_config()
self.config = config
# Cloud settings
self.cloud_enabled = config.cloud.enabled
self.cloud_url = config.cloud.api_url
self.api_key = config.cloud.api_key
self.device_id = config.cloud.device_id
# Sync settings
self.sync_interval = config.cloud.sync_interval
self.alarm_retry_interval = config.cloud.alarm_retry_interval
self.status_report_interval = config.cloud.status_report_interval
self.max_retries = config.cloud.max_retries
# API client
self.client: Optional[CloudAPIClient] = None
if self.cloud_enabled:
self.client = CloudAPIClient(
self.cloud_url,
self.api_key,
self.device_id
)
# Task queues
self.sync_queue: Queue = Queue()
self.alarm_queue: Queue = Queue()
# State
self.running = False
self.threads: List[threading.Thread] = []
self.network_status = "disconnected"
# Retry back-off
self.retry_delays = [60, 300, 900, 3600]  # 1 min, 5 min, 15 min, 1 hour
def start(self):
"""启动同步服务"""
if self.running:
logger.warning("同步服务已在运行")
return
self.running = True
if self.cloud_enabled:
logger.info(f"启动云端同步服务设备ID: {self.device_id}")
else:
logger.info("云端同步已禁用,使用本地模式")
# 启动工作线程
self.threads.append(threading.Thread(target=self._sync_worker, daemon=True))
self.threads.append(threading.Thread(target=self._alarm_worker, daemon=True))
self.threads.append(threading.Thread(target=self._status_worker, daemon=True))
for thread in self.threads:
thread.start()
logger.info("同步服务已启动")
def stop(self):
"""停止同步服务"""
self.running = False
for thread in self.threads:
if thread.is_alive():
thread.join(timeout=5)
logger.info("同步服务已停止")
def _sync_worker(self):
"""配置同步工作线程"""
while self.running:
try:
task = self.sync_queue.get(timeout=1)
self._execute_sync(task)
except Empty:
self._check_network_status()
except Exception as e:
logger.error(f"同步工作线程异常: {e}")
def _alarm_worker(self):
"""报警上报工作线程"""
while self.running:
try:
alarm_id = self.alarm_queue.get(timeout=1)
self._upload_alarm(alarm_id)
except Empty:
continue
except Exception as e:
logger.error(f"报警上报工作线程异常: {e}")
def _status_worker(self):
"""状态上报工作线程"""
while self.running:
try:
if self.network_status == "connected":
self._report_status()
except Exception as e:
logger.error(f"状态上报失败: {e}")
time.sleep(self.status_report_interval)
def _check_network_status(self):
"""检查网络状态"""
if not self.cloud_enabled:
self.network_status = "disabled"
return
try:
self.client.get('/health')
self.network_status = "connected"
except:
self.network_status = "disconnected"
def _execute_sync(self, task: SyncTask):
"""执行同步任务"""
logger.info(f"执行同步任务: {task.entity_type.value}/{task.entity_id} ({task.operation})")
task.status = SyncStatus.SYNCING
try:
if task.entity_type == EntityType.CAMERA:
self._sync_camera(task)
elif task.entity_type == EntityType.ROI:
self._sync_roi(task)
task.status = SyncStatus.SUCCESS
logger.info(f"同步成功: {task.entity_type.value}/{task.entity_id}")
except requests.exceptions.RequestException as e:
self._handle_sync_error(task, str(e))
except Exception as e:
task.status = SyncStatus.FAILED
task.error_message = str(e)
logger.error(f"同步失败: {task.entity_type.value}/{task.entity_id}: {e}")
def _handle_sync_error(self, task: SyncTask, error: str):
"""处理同步错误"""
task.retry_count += 1
if task.retry_count < self.max_retries:
task.status = SyncStatus.RETRY
delay = self.retry_delays[task.retry_count - 1]
task.error_message = f"{task.retry_count}次失败: {error}"
logger.warning(f"同步重试 ({task.retry_count}/{self.max_retries}): {task.entity_type.value}/{task.entity_id}")
# 重新入队
time.sleep(delay)
self.sync_queue.put(task)
else:
task.status = SyncStatus.FAILED
task.error_message = f"已超过最大重试次数: {error}"
logger.error(f"同步失败,已达最大重试次数: {task.entity_type.value}/{task.entity_id}")
def _sync_camera(self, task: SyncTask):
"""同步摄像头配置"""
if task.operation == 'update':
self.client.put(f"/api/v1/cameras/{task.entity_id}", json=task.data)
elif task.operation == 'delete':
self.client.delete(f"/api/v1/cameras/{task.entity_id}")
def _sync_roi(self, task: SyncTask):
"""同步 ROI 配置"""
if task.operation == 'update':
self.client.put(f"/api/v1/rois/{task.entity_id}", json=task.data)
elif task.operation == 'delete':
self.client.delete(f"/api/v1/rois/{task.entity_id}")
def _upload_alarm(self, alarm_id: int):
"""上传报警记录"""
from db.crud import get_alarm_by_id, update_alarm_status
from db.models import get_session_factory
SessionLocal = get_session_factory()
db = SessionLocal()
try:
alarm = get_alarm_by_id(db, alarm_id)
if not alarm:
logger.warning(f"报警记录不存在: {alarm_id}")
return
# 准备数据
alarm_data = {
'device_id': self.device_id,
'camera_id': alarm.camera_id,
'alarm_type': alarm.event_type,
'confidence': alarm.confidence,
'timestamp': alarm.created_at.isoformat() if alarm.created_at else None,
'region': alarm.region_data
}
# Upload the snapshot image
if alarm.snapshot_path and os.path.exists(alarm.snapshot_path):
with open(alarm.snapshot_path, 'rb') as f:
files = {'file': f}
response = self.client.post('/api/v1/alarms/images', files=files)
alarm_data['image_url'] = response.json().get('data', {}).get('url')
# Report the alarm
response = self.client.post('/api/v1/alarms/report', json=alarm_data)
cloud_id = response.json().get('data', {}).get('alarm_id')
# Update local state
update_alarm_status(db, alarm_id, status='uploaded', cloud_id=cloud_id)
logger.info(f"报警上报成功: {alarm_id} -> 云端ID: {cloud_id}")
except requests.exceptions.RequestException as e:
update_alarm_status(db, alarm_id, status='retry', error_message=str(e))
self.alarm_queue.put(alarm_id)  # retry
except Exception as e:
update_alarm_status(db, alarm_id, status='failed', error_message=str(e))
logger.error(f"报警处理失败: {alarm_id}: {e}")
finally:
db.close()
def _report_status(self):
"""上报设备状态"""
import psutil
from db.crud import get_active_camera_count
try:
metrics = {
'device_id': self.device_id,
'cpu_percent': psutil.cpu_percent(),
'memory_percent': psutil.virtual_memory().percent,
'disk_usage': psutil.disk_usage('/').percent,
'timestamp': datetime.utcnow().isoformat()
}
self.client.post('/api/v1/devices/status', json=metrics)
logger.debug(f"设备状态上报成功: CPU={metrics['cpu_percent']}%")
except requests.exceptions.RequestException as e:
logger.warning(f"设备状态上报失败: {e}")
# 公共接口
def queue_camera_sync(self, camera_id: int, operation: str = 'update', data: Dict[str, Any] = None):
"""将摄像头同步加入队列"""
task = SyncTask(
entity_type=EntityType.CAMERA,
entity_id=camera_id,
operation=operation,
data=data
)
self.sync_queue.put(task)
def queue_roi_sync(self, roi_id: int, operation: str = 'update', data: Dict[str, Any] = None):
"""将 ROI 同步加入队列"""
task = SyncTask(
entity_type=EntityType.ROI,
entity_id=roi_id,
operation=operation,
data=data
)
self.sync_queue.put(task)
def queue_alarm_upload(self, alarm_id: int):
"""将报警上传加入队列"""
self.alarm_queue.put(alarm_id)
def sync_config_from_cloud(self, db: Session) -> Dict[str, int]:
"""从云端拉取配置"""
result = {'cameras': 0, 'rois': 0}
if not self.cloud_enabled:
logger.info("云端同步已禁用,跳过配置拉取")
return result
try:
logger.info("从云端拉取配置...")
# 拉取设备配置
response = self.client.get(f"/api/v1/devices/{self.device_id}/config")
config = response.json().get('data', {})
# Merge cameras
cameras = config.get('cameras', [])
for cloud_cam in cameras:
self._merge_camera(db, cloud_cam)
result['cameras'] += 1
logger.info(f"从云端拉取配置完成: {result['cameras']} 个摄像头")
except requests.exceptions.RequestException as e:
logger.error(f"从云端拉取配置失败: {e}")
except Exception as e:
logger.error(f"处理云端配置时出错: {e}")
return result
def _merge_camera(self, db: Session, cloud_data: Dict[str, Any]):
"""合并摄像头配置"""
from db.crud import get_camera_by_cloud_id, create_camera, update_camera
from db.models import Camera
cloud_id = cloud_data.get('id')
existing = get_camera_by_cloud_id(db, cloud_id)
if existing:
# Update the existing record
if not existing.pending_sync:
update_camera(db, existing.id, {
'name': cloud_data.get('name'),
'rtsp_url': cloud_data.get('rtsp_url'),
'enabled': cloud_data.get('enabled', True),
'fps_limit': cloud_data.get('fps_limit', 30),
'process_every_n_frames': cloud_data.get('process_every_n_frames', 3),
})
else:
# Create a new record
camera = create_camera(db, {
'name': cloud_data.get('name'),
'rtsp_url': cloud_data.get('rtsp_url'),
'enabled': cloud_data.get('enabled', True),
'fps_limit': cloud_data.get('fps_limit', 30),
'process_every_n_frames': cloud_data.get('process_every_n_frames', 3),
})
# Set cloud_id
camera.cloud_id = cloud_id
db.commit()
def get_status(self) -> Dict[str, Any]:
"""获取同步服务状态"""
return {
'running': self.running,
'cloud_enabled': self.cloud_enabled,
'network_status': self.network_status,
'device_id': self.device_id,
'pending_sync': self.sync_queue.qsize(),
'pending_alarms': self.alarm_queue.qsize(),
}
# Singleton
_sync_service: Optional[SyncService] = None
def get_sync_service() -> SyncService:
"""获取同步服务单例"""
global _sync_service
if _sync_service is None:
_sync_service = SyncService()
return _sync_service
def start_sync_service():
"""启动同步服务"""
service = get_sync_service()
service.start()
def stop_sync_service():
"""停止同步服务"""
global _sync_service
if _sync_service:
_sync_service.stop()
_sync_service = None
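The bounded back-off used by `_handle_sync_error` can be sketched in isolation; clamping the index keeps the lookup valid even when `max_retries` exceeds the length of the delay table:

```python
RETRY_DELAYS = [60, 300, 900, 3600]  # seconds: 1 min, 5 min, 15 min, 1 hour

def retry_delay(retry_count, delays=RETRY_DELAYS):
    # Attempt 1 -> first delay; attempts beyond the table reuse the last delay
    return delays[min(retry_count, len(delays)) - 1]

print(retry_delay(1))   # 60
print(retry_delay(4))   # 3600
print(retry_delay(10))  # 3600
```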

sort.py (214 lines, deleted)

@@ -1,214 +0,0 @@
import numpy as np
from scipy.optimize import linear_sum_assignment
from filterpy.kalman import KalmanFilter
def linear_assignment(cost_matrix):
x, y = linear_sum_assignment(cost_matrix)
return np.array(list(zip(x, y)))
def iou_batch(bb_test, bb_gt):
"""
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
"""
bb_test = np.expand_dims(bb_test, 1)
bb_gt = np.expand_dims(bb_gt, 0)
xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
w = np.maximum(0., xx2 - xx1)
h = np.maximum(0., yy2 - yy1)
wh = w * h
o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
+ (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
return o
def convert_bbox_to_z(bbox):
"""
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
the aspect ratio
"""
w = bbox[2] - bbox[0]
h = bbox[3] - bbox[1]
x = bbox[0] + w / 2.
y = bbox[1] + h / 2.
s = w * h # scale is just area
r = w / float(h)
return np.array([x, y, s, r]).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
"""
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
"""
w = np.sqrt(x[2] * x[3])
h = x[2] / w
if score is None:
return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))
else:
return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5))
class KalmanBoxTracker(object):
"""
This class represents the internal state of individual tracked objects observed as bbox.
"""
count = 0
def __init__(self, bbox):
"""
Initialises a tracker using initial bounding box.
"""
self.kf = KalmanFilter(dim_x=7, dim_z=4)
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
[0, 1, 0, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 0, 1],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1]])
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0]])
self.kf.R[2:, 2:] *= 10.
self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities
self.kf.P *= 10.
self.kf.Q[-1, -1] *= 0.01
self.kf.Q[4:, 4:] *= 0.01
self.kf.x[:4] = convert_bbox_to_z(bbox)
self.time_since_update = 0
self.id = KalmanBoxTracker.count
KalmanBoxTracker.count += 1
self.history = []
self.hits = 0
self.hit_streak = 0
self.age = 0
def update(self, bbox):
"""
Updates the state vector with observed bbox.
"""
self.time_since_update = 0
self.history = []
self.hits += 1
self.hit_streak += 1
self.kf.update(convert_bbox_to_z(bbox))
def predict(self):
"""
Advances the state vector and returns the predicted bounding box estimate.
"""
if (self.kf.x[6] + self.kf.x[2]) <= 0:
self.kf.x[6] *= 0.0
self.kf.predict()
self.age += 1
if self.time_since_update > 0:
self.hit_streak = 0
self.time_since_update += 1
self.history.append(convert_x_to_bbox(self.kf.x))
return self.history[-1]
def get_state(self):
"""
Returns the current bounding box estimate.
"""
return convert_x_to_bbox(self.kf.x)
class Sort(object):
def __init__(self, max_age=5, min_hits=2, iou_threshold=0.3):
"""
Sets key parameters for SORT
"""
self.max_age = max_age
self.min_hits = min_hits
self.iou_threshold = iou_threshold
self.trackers = []
self.frame_count = 0
def update(self, dets=np.empty((0, 5))):
"""
Params:
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],...]
Requires: this method must be called once for each frame even with empty detections.
Returns a similar array, where the last column is the object ID.
NOTE: The number of objects returned may differ from the number of detections provided.
"""
self.frame_count += 1
trks = np.zeros((len(self.trackers), 5))
to_del = []
ret = []
for t, trk in enumerate(trks):
pos = self.trackers[t].predict()[0]
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
if np.any(np.isnan(pos)):
to_del.append(t)
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
for t in reversed(to_del):
self.trackers.pop(t)
matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)
# update matched trackers with assigned detections
for m in matched:
self.trackers[m[1]].update(dets[m[0], :])
# create and initialise new trackers for unmatched detections
for i in unmatched_dets:
trk = KalmanBoxTracker(dets[i, :])
self.trackers.append(trk)
i = len(self.trackers)
for trk in reversed(self.trackers):
d = trk.get_state()[0]
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) # +1 as MOT benchmark requires positive
i -= 1
# remove dead tracklet
if trk.time_since_update > self.max_age:
self.trackers.pop(i)
if len(ret) > 0:
return np.concatenate(ret)
return np.empty((0, 5))
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
"""
Assigns detections to tracked object (both represented as bounding boxes)
Returns 3 lists of matches, unmatched_detections, unmatched_trackers
"""
if len(trackers) == 0:
return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 1), dtype=int)
iou_matrix = iou_batch(detections, trackers)
if min(iou_matrix.shape) > 0:
a = (iou_matrix > iou_threshold).astype(np.int32)
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
matched_indices = np.stack(np.where(a), axis=1)
else:
matched_indices = linear_assignment(-iou_matrix)
else:
matched_indices = np.empty(shape=(0, 2))
unmatched_detections = []
for d, det in enumerate(detections):
if d not in matched_indices[:, 0]:
unmatched_detections.append(d)
unmatched_trackers = []
for t, trk in enumerate(trackers):
if t not in matched_indices[:, 1]:
unmatched_trackers.append(t)
matches = []
for m in matched_indices:
if iou_matrix[m[0], m[1]] < iou_threshold:
unmatched_detections.append(m[0])
unmatched_trackers.append(m[1])
else:
matches.append(m.reshape(1, 2))
if len(matches) == 0:
matches = np.empty((0, 2), dtype=int)
else:
matches = np.concatenate(matches, axis=0)
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
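The `convert_bbox_to_z` / `convert_x_to_bbox` pair above is a lossless round trip for valid boxes; a self-contained check (conversion functions copied from the file, score handling omitted):

```python
import numpy as np

def convert_bbox_to_z(bbox):
    # [x1,y1,x2,y2] -> [cx,cy,area,aspect] as a (4,1) column vector
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    x = bbox[0] + w / 2.
    y = bbox[1] + h / 2.
    return np.array([x, y, w * h, w / float(h)]).reshape((4, 1))

def convert_x_to_bbox(x):
    # [cx,cy,area,aspect] -> [x1,y1,x2,y2]
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    return np.array([x[0] - w / 2., x[1] - h / 2.,
                     x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))

z = convert_bbox_to_z([0., 0., 10., 20.])
bbox = convert_x_to_bbox(z)
print(bbox)  # recovers [0, 0, 10, 20]
```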


@@ -58,7 +58,7 @@ def test_leave_post_algorithm_process():
{"bbox": [100, 100, 200, 200], "conf": 0.9, "cls": 0},
]
alerts = algo.process("test_cam", tracks, datetime.now())
alerts = algo.process("roi_1", "test_cam", tracks, datetime.now())
assert isinstance(alerts, list)