Compare commits
25 Commits
2c00b5afe3
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| 8e2b285893 | |||
| afa9e122a5 | |||
| 7a10a983c8 | |||
| 98c741cb2b | |||
| 44b6c70a4b | |||
| 3af7a0f805 | |||
| cb46d12cfa | |||
| 123903950b | |||
| 2d5ada2909 | |||
| 6fc17ccf64 | |||
| 6116f0b982 | |||
| 20f295a491 | |||
| cc4f33c0fd | |||
| 2e9bf2b50c | |||
| 248a240524 | |||
| 10b9fb1804 | |||
| 1a94854c52 | |||
| 13afc654ab | |||
| 804c6a60e9 | |||
| 20634c2ad4 | |||
| 46ee360d46 | |||
| 6712a311f8 | |||
| 294b0e1abb | |||
| 1c7190bbb0 | |||
| 1b344aeb2e |
34
.gitea/workflows/python-test.yml
Normal file
34
.gitea/workflows/python-test.yml
Normal file
@@ -0,0 +1,34 @@
|
||||
name: Python Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "master", "main" ]
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
if [ -f "requirements.txt" ]; then
|
||||
pip install -r requirements.txt
|
||||
fi
|
||||
pip install pytest
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
if [ -d "tests" ]; then
|
||||
pytest tests/ --verbose
|
||||
else
|
||||
echo "No tests directory found, skipping tests."
|
||||
fi
|
||||
75
api/alarm.py
75
api/alarm.py
@@ -1,7 +1,8 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Body
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.crud import (
|
||||
@@ -16,6 +17,41 @@ from inference.pipeline import get_pipeline
|
||||
router = APIRouter(prefix="/api/alarms", tags=["告警管理"])
|
||||
|
||||
|
||||
class AlarmUpdateRequest(BaseModel):
|
||||
llm_checked: Optional[bool] = None
|
||||
llm_result: Optional[str] = None
|
||||
processed: Optional[bool] = None
|
||||
|
||||
|
||||
def convert_to_china_time(dt: Optional[datetime]) -> Optional[str]:
|
||||
"""将 UTC 时间转换为中国时间 (UTC+8)"""
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
china_tz = timezone(timedelta(hours=8))
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(china_tz).isoformat()
|
||||
except Exception:
|
||||
return dt.isoformat() if dt else None
|
||||
|
||||
|
||||
def format_alarm_response(alarm) -> dict:
|
||||
"""格式化告警响应,将 UTC 时间转换为中国时间"""
|
||||
return {
|
||||
"id": alarm.id,
|
||||
"camera_id": alarm.camera_id,
|
||||
"roi_id": alarm.roi_id,
|
||||
"event_type": alarm.event_type,
|
||||
"confidence": alarm.confidence,
|
||||
"snapshot_path": alarm.snapshot_path,
|
||||
"llm_checked": alarm.llm_checked,
|
||||
"llm_result": alarm.llm_result,
|
||||
"processed": alarm.processed,
|
||||
"created_at": convert_to_china_time(alarm.created_at),
|
||||
}
|
||||
|
||||
|
||||
@router.get("", response_model=List[dict])
|
||||
def list_alarms(
|
||||
camera_id: Optional[int] = None,
|
||||
@@ -25,21 +61,7 @@ def list_alarms(
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
alarms = get_alarms(db, camera_id=camera_id, event_type=event_type, limit=limit, offset=offset)
|
||||
return [
|
||||
{
|
||||
"id": alarm.id,
|
||||
"camera_id": alarm.camera_id,
|
||||
"roi_id": alarm.roi_id,
|
||||
"event_type": alarm.event_type,
|
||||
"confidence": alarm.confidence,
|
||||
"snapshot_path": alarm.snapshot_path,
|
||||
"llm_checked": alarm.llm_checked,
|
||||
"llm_result": alarm.llm_result,
|
||||
"processed": alarm.processed,
|
||||
"created_at": alarm.created_at.isoformat() if alarm.created_at else None,
|
||||
}
|
||||
for alarm in alarms
|
||||
]
|
||||
return [format_alarm_response(alarm) for alarm in alarms]
|
||||
|
||||
|
||||
@router.get("/stats")
|
||||
@@ -55,29 +77,16 @@ def get_alarm(alarm_id: int, db: Session = Depends(get_db)):
|
||||
alarm = next((a for a in alarms if a.id == alarm_id), None)
|
||||
if not alarm:
|
||||
raise HTTPException(status_code=404, detail="告警不存在")
|
||||
return {
|
||||
"id": alarm.id,
|
||||
"camera_id": alarm.camera_id,
|
||||
"roi_id": alarm.roi_id,
|
||||
"event_type": alarm.event_type,
|
||||
"confidence": alarm.confidence,
|
||||
"snapshot_path": alarm.snapshot_path,
|
||||
"llm_checked": alarm.llm_checked,
|
||||
"llm_result": alarm.llm_result,
|
||||
"processed": alarm.processed,
|
||||
"created_at": alarm.created_at.isoformat() if alarm.created_at else None,
|
||||
}
|
||||
return format_alarm_response(alarm)
|
||||
|
||||
|
||||
@router.put("/{alarm_id}")
|
||||
def update_alarm_status(
|
||||
alarm_id: int,
|
||||
llm_checked: Optional[bool] = None,
|
||||
llm_result: Optional[str] = None,
|
||||
processed: Optional[bool] = None,
|
||||
request: AlarmUpdateRequest = Body(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
alarm = update_alarm(db, alarm_id, llm_checked=llm_checked, llm_result=llm_result, processed=processed)
|
||||
alarm = update_alarm(db, alarm_id, llm_checked=request.llm_checked, llm_result=request.llm_result, processed=request.processed)
|
||||
if not alarm:
|
||||
raise HTTPException(status_code=404, detail="告警不存在")
|
||||
return {"message": "更新成功"}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
@@ -15,6 +16,7 @@ from db.models import get_db
|
||||
from inference.pipeline import get_pipeline
|
||||
|
||||
router = APIRouter(prefix="/api/cameras", tags=["摄像头管理"])
|
||||
router2 = APIRouter(prefix="/api/camera", tags=["摄像头状态"])
|
||||
|
||||
|
||||
class CameraUpdateRequest(BaseModel):
|
||||
@@ -25,6 +27,19 @@ class CameraUpdateRequest(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
|
||||
|
||||
def convert_to_china_time(dt: Optional[datetime]) -> Optional[str]:
|
||||
"""将 UTC 时间转换为中国时间 (UTC+8)"""
|
||||
if dt is None:
|
||||
return None
|
||||
try:
|
||||
china_tz = timezone(timedelta(hours=8))
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(china_tz).isoformat()
|
||||
except Exception:
|
||||
return dt.isoformat() if dt else None
|
||||
|
||||
|
||||
@router.get("", response_model=List[dict])
|
||||
def list_cameras(
|
||||
enabled_only: bool = True,
|
||||
@@ -39,7 +54,7 @@ def list_cameras(
|
||||
"enabled": cam.enabled,
|
||||
"fps_limit": cam.fps_limit,
|
||||
"process_every_n_frames": cam.process_every_n_frames,
|
||||
"created_at": cam.created_at.isoformat() if cam.created_at else None,
|
||||
"created_at": convert_to_china_time(cam.created_at),
|
||||
}
|
||||
for cam in cameras
|
||||
]
|
||||
@@ -57,24 +72,21 @@ def get_camera(camera_id: int, db: Session = Depends(get_db)):
|
||||
"enabled": camera.enabled,
|
||||
"fps_limit": camera.fps_limit,
|
||||
"process_every_n_frames": camera.process_every_n_frames,
|
||||
"created_at": camera.created_at.isoformat() if camera.created_at else None,
|
||||
"created_at": convert_to_china_time(camera.created_at),
|
||||
}
|
||||
|
||||
|
||||
@router.post("", response_model=dict)
|
||||
def add_camera(
|
||||
name: str,
|
||||
rtsp_url: str,
|
||||
fps_limit: int = 30,
|
||||
process_every_n_frames: int = 3,
|
||||
request: CameraUpdateRequest = Body(...),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
camera = create_camera(
|
||||
db,
|
||||
name=name,
|
||||
rtsp_url=rtsp_url,
|
||||
fps_limit=fps_limit,
|
||||
process_every_n_frames=process_every_n_frames,
|
||||
name=request.name,
|
||||
rtsp_url=request.rtsp_url,
|
||||
fps_limit=request.fps_limit or 30,
|
||||
process_every_n_frames=request.process_every_n_frames or 3,
|
||||
)
|
||||
|
||||
if camera.enabled:
|
||||
@@ -163,3 +175,42 @@ def get_camera_status(camera_id: int, db: Session = Depends(get_db)):
|
||||
"last_check_time": None,
|
||||
"stream": stream_info,
|
||||
}
|
||||
|
||||
|
||||
router2 = APIRouter(prefix="/api/camera", tags=["摄像头状态"])
|
||||
|
||||
|
||||
@router2.get("/status/all")
|
||||
def get_all_camera_status(db: Session = Depends(get_db)):
|
||||
from db.crud import get_all_cameras, get_camera_status as get_status
|
||||
|
||||
cameras = get_all_cameras(db, enabled_only=False)
|
||||
pipeline = get_pipeline()
|
||||
|
||||
result = []
|
||||
for cam in cameras:
|
||||
status = get_status(db, cam.id)
|
||||
|
||||
stream = pipeline.stream_manager.get_stream(str(cam.id))
|
||||
stream_info = stream.get_info() if stream else None
|
||||
|
||||
if status:
|
||||
result.append({
|
||||
"camera_id": cam.id,
|
||||
"is_running": status.is_running,
|
||||
"fps": status.fps,
|
||||
"error_message": status.error_message,
|
||||
"last_check_time": status.last_check_time.isoformat() if status.last_check_time else None,
|
||||
"stream": stream_info,
|
||||
})
|
||||
else:
|
||||
result.append({
|
||||
"camera_id": cam.id,
|
||||
"is_running": False,
|
||||
"fps": 0.0,
|
||||
"error_message": None,
|
||||
"last_check_time": None,
|
||||
"stream": stream_info,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
110
api/roi.py
110
api/roi.py
@@ -1,7 +1,8 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Body
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.crud import (
|
||||
@@ -20,6 +21,34 @@ from inference.roi.roi_filter import ROIFilter
|
||||
router = APIRouter(prefix="/api/camera", tags=["ROI管理"])
|
||||
|
||||
|
||||
class CreateROIRequest(BaseModel):
|
||||
roi_id: str
|
||||
name: str
|
||||
roi_type: str
|
||||
points: List[List[float]]
|
||||
rule_type: str
|
||||
direction: Optional[str] = None
|
||||
stay_time: Optional[int] = None
|
||||
threshold_sec: int = 300
|
||||
confirm_sec: int = 10
|
||||
return_sec: int = 30
|
||||
working_hours: Optional[List[dict]] = None
|
||||
|
||||
|
||||
class UpdateROIRequest(BaseModel):
|
||||
name: Optional[str] = None
|
||||
roi_type: Optional[str] = None
|
||||
points: Optional[List[List[float]]] = None
|
||||
rule_type: Optional[str] = None
|
||||
direction: Optional[str] = None
|
||||
stay_time: Optional[int] = None
|
||||
enabled: Optional[bool] = None
|
||||
threshold_sec: Optional[int] = None
|
||||
confirm_sec: Optional[int] = None
|
||||
return_sec: Optional[int] = None
|
||||
working_hours: Optional[List[dict]] = None
|
||||
|
||||
|
||||
def _invalidate_roi_cache(camera_id: int):
|
||||
pipeline = get_pipeline()
|
||||
pipeline.roi_filter.clear_cache(camera_id)
|
||||
@@ -43,6 +72,7 @@ def list_rois(camera_id: int, db: Session = Depends(get_db)):
|
||||
"threshold_sec": roi.threshold_sec,
|
||||
"confirm_sec": roi.confirm_sec,
|
||||
"return_sec": roi.return_sec,
|
||||
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
|
||||
}
|
||||
for roi in roi_configs
|
||||
]
|
||||
@@ -66,37 +96,34 @@ def get_roi(camera_id: int, roi_id: int, db: Session = Depends(get_db)):
|
||||
"threshold_sec": roi.threshold_sec,
|
||||
"confirm_sec": roi.confirm_sec,
|
||||
"return_sec": roi.return_sec,
|
||||
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{camera_id}/roi", response_model=dict)
|
||||
def add_roi(
|
||||
camera_id: int,
|
||||
roi_id: str,
|
||||
name: str,
|
||||
roi_type: str,
|
||||
points: List[List[float]],
|
||||
rule_type: str,
|
||||
direction: Optional[str] = None,
|
||||
stay_time: Optional[int] = None,
|
||||
threshold_sec: int = 360,
|
||||
confirm_sec: int = 30,
|
||||
return_sec: int = 5,
|
||||
request: CreateROIRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
import json
|
||||
|
||||
working_hours_json = json.dumps(request.working_hours) if request.working_hours is not None else None
|
||||
|
||||
roi = create_roi(
|
||||
db,
|
||||
camera_id=camera_id,
|
||||
roi_id=roi_id,
|
||||
name=name,
|
||||
roi_type=roi_type,
|
||||
points=points,
|
||||
rule_type=rule_type,
|
||||
direction=direction,
|
||||
stay_time=stay_time,
|
||||
threshold_sec=threshold_sec,
|
||||
confirm_sec=confirm_sec,
|
||||
return_sec=return_sec,
|
||||
roi_id=request.roi_id,
|
||||
name=request.name,
|
||||
roi_type=request.roi_type,
|
||||
points=request.points,
|
||||
rule_type=request.rule_type,
|
||||
direction=request.direction,
|
||||
stay_time=request.stay_time,
|
||||
threshold_sec=request.threshold_sec,
|
||||
confirm_sec=request.confirm_sec,
|
||||
return_sec=request.return_sec,
|
||||
working_hours=working_hours_json,
|
||||
)
|
||||
|
||||
_invalidate_roi_cache(camera_id)
|
||||
@@ -106,9 +133,15 @@ def add_roi(
|
||||
"roi_id": roi.roi_id,
|
||||
"name": roi.name,
|
||||
"type": roi.roi_type,
|
||||
"points": points,
|
||||
"points": request.points,
|
||||
"rule": roi.rule_type,
|
||||
"direction": roi.direction,
|
||||
"stay_time": roi.stay_time,
|
||||
"enabled": roi.enabled,
|
||||
"threshold_sec": roi.threshold_sec,
|
||||
"confirm_sec": roi.confirm_sec,
|
||||
"return_sec": roi.return_sec,
|
||||
"working_hours": request.working_hours,
|
||||
}
|
||||
|
||||
|
||||
@@ -116,29 +149,25 @@ def add_roi(
|
||||
def modify_roi(
|
||||
camera_id: int,
|
||||
roi_id: int,
|
||||
name: Optional[str] = None,
|
||||
points: Optional[List[List[float]]] = None,
|
||||
rule_type: Optional[str] = None,
|
||||
direction: Optional[str] = None,
|
||||
stay_time: Optional[int] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
threshold_sec: Optional[int] = None,
|
||||
confirm_sec: Optional[int] = None,
|
||||
return_sec: Optional[int] = None,
|
||||
request: UpdateROIRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
import json
|
||||
working_hours_json = json.dumps(request.working_hours) if request.working_hours is not None else None
|
||||
|
||||
roi = update_roi(
|
||||
db,
|
||||
roi_id=roi_id,
|
||||
name=name,
|
||||
points=points,
|
||||
rule_type=rule_type,
|
||||
direction=direction,
|
||||
stay_time=stay_time,
|
||||
enabled=enabled,
|
||||
threshold_sec=threshold_sec,
|
||||
confirm_sec=confirm_sec,
|
||||
return_sec=return_sec,
|
||||
name=request.name,
|
||||
points=request.points,
|
||||
rule_type=request.rule_type,
|
||||
direction=request.direction,
|
||||
stay_time=request.stay_time,
|
||||
enabled=request.enabled,
|
||||
threshold_sec=request.threshold_sec,
|
||||
confirm_sec=request.confirm_sec,
|
||||
return_sec=request.return_sec,
|
||||
working_hours=working_hours_json,
|
||||
)
|
||||
if not roi:
|
||||
raise HTTPException(status_code=404, detail="ROI不存在")
|
||||
@@ -153,6 +182,7 @@ def modify_roi(
|
||||
"points": json.loads(roi.points),
|
||||
"rule": roi.rule_type,
|
||||
"enabled": roi.enabled,
|
||||
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
|
||||
}
|
||||
|
||||
|
||||
|
||||
116
api/sync.py
Normal file
116
api/sync.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
|
||||
from db.models import get_db
|
||||
from services.sync_service import get_sync_service
|
||||
|
||||
router = APIRouter(prefix="/api/sync", tags=["同步"])
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
def get_sync_status(db: Session = Depends(get_db)):
|
||||
"""获取同步服务状态"""
|
||||
from sqlalchemy import text
|
||||
|
||||
service = get_sync_service()
|
||||
status = service.get_status()
|
||||
|
||||
pending_cameras = db.execute(
|
||||
text("SELECT COUNT(*) FROM cameras WHERE pending_sync = 1")
|
||||
).scalar() or 0
|
||||
|
||||
pending_rois = db.execute(
|
||||
text("SELECT COUNT(*) FROM rois WHERE pending_sync = 1")
|
||||
).scalar() or 0
|
||||
|
||||
pending_alarms = db.execute(
|
||||
text("SELECT COUNT(*) FROM alarms WHERE upload_status = 'pending' OR upload_status = 'retry'")
|
||||
).scalar() or 0
|
||||
|
||||
return {
|
||||
"running": status["running"],
|
||||
"cloud_enabled": status["cloud_enabled"],
|
||||
"network_status": status["network_status"],
|
||||
"device_id": status["device_id"],
|
||||
"pending_sync": pending_cameras + pending_rois,
|
||||
"pending_alarms": pending_alarms,
|
||||
"details": {
|
||||
"pending_cameras": pending_cameras,
|
||||
"pending_rois": pending_rois,
|
||||
"pending_alarms": pending_alarms
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/pending")
|
||||
def get_pending_syncs(db: Session = Depends(get_db)):
|
||||
"""获取待同步项列表"""
|
||||
from sqlalchemy import text
|
||||
from db.models import Camera, ROI, Alarm
|
||||
|
||||
pending_cameras = db.query(Camera).filter(Camera.pending_sync == True).all()
|
||||
pending_rois = db.query(ROI).filter(ROI.pending_sync == True).all()
|
||||
|
||||
from db.session import SessionLocal
|
||||
temp_db = SessionLocal()
|
||||
try:
|
||||
pending_alarms = temp_db.query(Alarm).filter(
|
||||
Alarm.upload_status.in_(['pending', 'retry'])
|
||||
).limit(100).all()
|
||||
finally:
|
||||
temp_db.close()
|
||||
|
||||
return {
|
||||
"cameras": [{"id": c.id, "name": c.name} for c in pending_cameras],
|
||||
"rois": [{"id": r.id, "name": r.name, "camera_id": r.camera_id} for r in pending_rois],
|
||||
"alarms": [{"id": a.id, "camera_id": a.camera_id, "type": a.event_type} for a in pending_alarms]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/trigger")
|
||||
def trigger_sync():
|
||||
"""手动触发同步"""
|
||||
service = get_sync_service()
|
||||
from db.session import SessionLocal
|
||||
from db.crud import get_all_cameras, get_all_rois
|
||||
from db.models import Camera, ROI
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cameras = get_all_cameras(db)
|
||||
for camera in cameras:
|
||||
service.queue_camera_sync(camera.id, 'update', {
|
||||
'name': camera.name,
|
||||
'rtsp_url': camera.rtsp_url,
|
||||
'enabled': camera.enabled
|
||||
})
|
||||
db.query(Camera).filter(Camera.id == camera.id).update({'pending_sync': False})
|
||||
db.commit()
|
||||
|
||||
rois = get_all_rois(db)
|
||||
for roi in rois:
|
||||
service.queue_roi_sync(roi.id, 'update', {
|
||||
'name': roi.name,
|
||||
'roi_type': roi.roi_type,
|
||||
'points': roi.points,
|
||||
'enabled': roi.enabled
|
||||
})
|
||||
db.query(ROI).filter(ROI.id == roi.id).update({'pending_sync': False})
|
||||
db.commit()
|
||||
|
||||
return {"message": "同步任务已加入队列", "count": len(cameras) + len(rois)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/clear-failed")
|
||||
def clear_failed_syncs(db: Session = Depends(get_db)):
|
||||
"""清除失败的同步标记"""
|
||||
from sqlalchemy import text
|
||||
|
||||
db.execute(text("UPDATE cameras SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0"))
|
||||
db.execute(text("UPDATE rois SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0"))
|
||||
db.commit()
|
||||
|
||||
return {"message": "已清除所有失败的同步标记"}
|
||||
17
config.py
17
config.py
@@ -23,16 +23,17 @@ class DatabaseConfig(BaseModel):
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
engine_path: str = "models/yolo11s.engine"
|
||||
onnx_path: str = "models/yolo11s.onnx"
|
||||
pt_model_path: str = "models/yolo11s.pt"
|
||||
engine_path: str = "models/yolo11n.engine"
|
||||
onnx_path: str = "models/yolo11n.onnx"
|
||||
pt_model_path: str = "models/yolo11n.pt"
|
||||
imgsz: List[int] = [640, 640]
|
||||
conf_threshold: float = 0.5
|
||||
iou_threshold: float = 0.45
|
||||
device: int = 0
|
||||
batch_size: int = 8
|
||||
half: bool = True
|
||||
use_onnx: bool = True
|
||||
half: bool = False
|
||||
use_onnx: bool = False
|
||||
use_trt: bool = False
|
||||
|
||||
|
||||
class StreamConfig(BaseModel):
|
||||
@@ -65,9 +66,9 @@ class WorkingHours(BaseModel):
|
||||
|
||||
|
||||
class AlgorithmsConfig(BaseModel):
|
||||
leave_post_threshold_sec: int = 360
|
||||
leave_post_confirm_sec: int = 30
|
||||
leave_post_return_sec: int = 5
|
||||
leave_post_threshold_sec: int = 300
|
||||
leave_post_confirm_sec: int = 10
|
||||
leave_post_return_sec: int = 30
|
||||
intrusion_check_interval_sec: float = 1.0
|
||||
intrusion_direction_sensitive: bool = False
|
||||
|
||||
|
||||
21
config.yaml
21
config.yaml
@@ -60,22 +60,17 @@ roi:
|
||||
- "line"
|
||||
max_points: 50 # 多边形最大顶点数
|
||||
|
||||
# 工作时间配置(全局默认)
|
||||
working_hours:
|
||||
- start: [8, 30] # 8:30
|
||||
end: [11, 0] # 11:00
|
||||
- start: [12, 0] # 12:00
|
||||
end: [17, 30] # 17:30
|
||||
# 工作时间配置(全局默认,空数组表示全天开启)
|
||||
working_hours: []
|
||||
|
||||
# 算法默认参数
|
||||
algorithms:
|
||||
leave_post:
|
||||
default_threshold_sec: 360 # 离岗超时(6分钟)
|
||||
confirm_sec: 30 # 离岗确认时间
|
||||
return_sec: 5 # 上岗确认时间
|
||||
threshold_sec: 300 # 离岗超时(5分钟)
|
||||
confirm_sec: 10 # 上岗确认时间(10秒)
|
||||
return_sec: 30 # 离岗缓冲时间(30秒)
|
||||
intrusion:
|
||||
check_interval_sec: 1.0 # 检测间隔
|
||||
direction_sensitive: false # 方向敏感
|
||||
cooldown_seconds: 300 # 入侵检测冷却时间(秒)
|
||||
|
||||
# 日志配置
|
||||
logging:
|
||||
@@ -94,7 +89,7 @@ monitoring:
|
||||
# 大模型配置(预留)
|
||||
llm:
|
||||
enabled: false
|
||||
api_key: ""
|
||||
base_url: ""
|
||||
api_key: "sk-21e61bef09074682b589da3bdbfe07a2"
|
||||
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
model: "qwen3-vl-max"
|
||||
timeout: 30
|
||||
|
||||
12
db/crud.py
12
db/crud.py
@@ -136,9 +136,10 @@ def create_roi(
|
||||
rule_type: str,
|
||||
direction: Optional[str] = None,
|
||||
stay_time: Optional[int] = None,
|
||||
threshold_sec: int = 360,
|
||||
confirm_sec: int = 30,
|
||||
return_sec: int = 5,
|
||||
threshold_sec: int = 300,
|
||||
confirm_sec: int = 10,
|
||||
return_sec: int = 30,
|
||||
working_hours: Optional[str] = None,
|
||||
) -> ROI:
|
||||
import json
|
||||
|
||||
@@ -154,6 +155,7 @@ def create_roi(
|
||||
threshold_sec=threshold_sec,
|
||||
confirm_sec=confirm_sec,
|
||||
return_sec=return_sec,
|
||||
working_hours=working_hours,
|
||||
)
|
||||
db.add(roi)
|
||||
db.commit()
|
||||
@@ -173,6 +175,7 @@ def update_roi(
|
||||
threshold_sec: Optional[int] = None,
|
||||
confirm_sec: Optional[int] = None,
|
||||
return_sec: Optional[int] = None,
|
||||
working_hours: Optional[str] = None,
|
||||
) -> Optional[ROI]:
|
||||
import json
|
||||
|
||||
@@ -198,6 +201,8 @@ def update_roi(
|
||||
roi.confirm_sec = confirm_sec
|
||||
if return_sec is not None:
|
||||
roi.return_sec = return_sec
|
||||
if working_hours is not None:
|
||||
roi.working_hours = working_hours
|
||||
|
||||
db.commit()
|
||||
db.refresh(roi)
|
||||
@@ -232,6 +237,7 @@ def get_roi_points(db: Session, camera_id: int) -> List[dict]:
|
||||
"threshold_sec": roi.threshold_sec,
|
||||
"confirm_sec": roi.confirm_sec,
|
||||
"return_sec": roi.return_sec,
|
||||
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
|
||||
}
|
||||
for roi in rois
|
||||
]
|
||||
|
||||
@@ -90,9 +90,10 @@ class ROI(Base):
|
||||
direction: Mapped[Optional[str]] = mapped_column(String(32))
|
||||
stay_time: Mapped[Optional[int]] = mapped_column(Integer)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
threshold_sec: Mapped[int] = mapped_column(Integer, default=360)
|
||||
confirm_sec: Mapped[int] = mapped_column(Integer, default=30)
|
||||
return_sec: Mapped[int] = mapped_column(Integer, default=5)
|
||||
threshold_sec: Mapped[int] = mapped_column(Integer, default=300)
|
||||
confirm_sec: Mapped[int] = mapped_column(Integer, default=10)
|
||||
return_sec: Mapped[int] = mapped_column(Integer, default=30)
|
||||
working_hours: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
pending_sync: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
sync_version: Mapped[int] = mapped_column(Integer, default=0)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@@ -38,14 +38,48 @@ const CameraManagement: React.FC = () => {
|
||||
const fetchCameras = async () => {
|
||||
setLoading(true);
|
||||
try {
|
||||
const [camerasRes, statusRes] = await Promise.all([
|
||||
const [camerasRes, statusRes, pipelineRes] = await Promise.all([
|
||||
axios.get('/api/cameras?enabled_only=false'),
|
||||
axios.get('/api/camera/status/all'),
|
||||
axios.get('/api/pipeline/status')
|
||||
]);
|
||||
setCameras(camerasRes.data);
|
||||
setCameraStatus(statusRes.data.cameras || {});
|
||||
|
||||
const statusMap: Record<number, CameraStatus> = {};
|
||||
|
||||
for (const cam of camerasRes.data) {
|
||||
const camId = cam.id;
|
||||
|
||||
let status: CameraStatus = {
|
||||
is_running: false,
|
||||
fps: 0,
|
||||
error_message: null,
|
||||
last_check_time: null,
|
||||
};
|
||||
|
||||
const pipelineStatus = pipelineRes.data.cameras?.[String(camId)];
|
||||
if (pipelineStatus) {
|
||||
status.is_running = pipelineStatus.is_running || false;
|
||||
status.fps = pipelineStatus.fps || 0;
|
||||
status.last_check_time = pipelineStatus.last_check_time;
|
||||
}
|
||||
|
||||
const dbStatus = statusRes.data.find((s: any) => s.camera_id === camId);
|
||||
if (dbStatus) {
|
||||
if (!status.is_running) {
|
||||
status.is_running = dbStatus.is_running || false;
|
||||
}
|
||||
status.fps = dbStatus.fps || status.fps;
|
||||
status.error_message = dbStatus.error_message;
|
||||
status.last_check_time = status.last_check_time || dbStatus.last_check_time;
|
||||
}
|
||||
|
||||
statusMap[camId] = status;
|
||||
}
|
||||
|
||||
setCameraStatus(statusMap);
|
||||
} catch (err) {
|
||||
message.error('获取摄像头列表失败');
|
||||
console.error('获取摄像头状态失败', err);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
@@ -59,8 +93,8 @@ const CameraManagement: React.FC = () => {
|
||||
|
||||
const extractIP = (url: string): string => {
|
||||
try {
|
||||
const match = url.match(/:\/\/([^:]+):?(\d+)?\//);
|
||||
return match ? match[1] : '未知';
|
||||
const ipMatch = url.match(/(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})/);
|
||||
return ipMatch ? ipMatch[1] : '未知';
|
||||
} catch {
|
||||
return '未知';
|
||||
}
|
||||
|
||||
@@ -24,6 +24,10 @@ const Dashboard: React.FC = () => {
|
||||
const [recentAlerts, setRecentAlerts] = useState<Alert[]>([]);
|
||||
const [cameraStatus, setCameraStatus] = useState<any[]>([]);
|
||||
|
||||
const handleViewSnapshot = (alert: Alert) => {
|
||||
window.open(`/api/alarms/${alert.id}/snapshot`, '_blank');
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchStats();
|
||||
fetchAlerts();
|
||||
@@ -60,7 +64,8 @@ const Dashboard: React.FC = () => {
|
||||
const res = await axios.get('/api/pipeline/status');
|
||||
const cameras = Object.entries(res.data.cameras || {}).map(([id, info]) => ({
|
||||
id,
|
||||
...info as any,
|
||||
is_running: (info as any).is_running || false,
|
||||
fps: (info as any).fps || 0,
|
||||
}));
|
||||
setCameraStatus(cameras);
|
||||
} catch (err) {
|
||||
@@ -139,7 +144,7 @@ const Dashboard: React.FC = () => {
|
||||
description={formatTime(alert.created_at)}
|
||||
/>
|
||||
{alert.snapshot_path && (
|
||||
<Button type="link" size="small">
|
||||
<Button type="link" size="small" onClick={() => handleViewSnapshot(alert)}>
|
||||
查看截图
|
||||
</Button>
|
||||
)}
|
||||
@@ -159,8 +164,8 @@ const Dashboard: React.FC = () => {
|
||||
title={`摄像头 ${cam.id}`}
|
||||
description={
|
||||
<Space>
|
||||
<Tag color={cam.running ? 'green' : 'red'}>
|
||||
{cam.running ? '运行中' : '已停止'}
|
||||
<Tag color={cam.is_running ? 'green' : 'red'}>
|
||||
{cam.is_running ? '运行中' : '已停止'}
|
||||
</Tag>
|
||||
<span>{cam.fps?.toFixed(1) || 0} FPS</span>
|
||||
</Space>
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
import React, { useEffect, useState, useRef } from 'react';
|
||||
import { Card, Button, Space, Select, message, Drawer, Form, Input, InputNumber, Switch } from 'antd';
|
||||
import { Card, Button, Space, Select, message, Drawer, Form, Input, InputNumber, Switch, TimePicker, Divider } from 'antd';
|
||||
import { Stage, Layer, Rect, Line, Circle, Text as KonvaText } from 'react-konva';
|
||||
import { RangePickerProps } from 'antd/es/date-picker';
|
||||
import axios from 'axios';
|
||||
import dayjs from 'dayjs';
|
||||
|
||||
interface WorkingHours {
|
||||
start: number[];
|
||||
end: number[];
|
||||
}
|
||||
|
||||
interface ROI {
|
||||
id: number;
|
||||
@@ -13,6 +20,7 @@ interface ROI {
|
||||
threshold_sec: number;
|
||||
confirm_sec: number;
|
||||
return_sec: number;
|
||||
working_hours: WorkingHours[] | null;
|
||||
}
|
||||
|
||||
interface Camera {
|
||||
@@ -29,12 +37,38 @@ const ROIEditor: React.FC = () => {
|
||||
const [selectedROI, setSelectedROI] = useState<ROI | null>(null);
|
||||
const [drawerVisible, setDrawerVisible] = useState(false);
|
||||
const [form] = Form.useForm();
|
||||
const [workingHoursList, setWorkingHoursList] = useState<{start: dayjs.Dayjs | null, end: dayjs.Dayjs | null}[]>([]);
|
||||
|
||||
const [isDrawing, setIsDrawing] = useState(false);
|
||||
const [tempPoints, setTempPoints] = useState<number[][]>([]);
|
||||
const [backgroundImage, setBackgroundImage] = useState<HTMLImageElement | null>(null);
|
||||
const stageRef = useRef<any>(null);
|
||||
|
||||
const addWorkingHours = () => {
|
||||
setWorkingHoursList([...workingHoursList, { start: null, end: null }]);
|
||||
};
|
||||
|
||||
const removeWorkingHours = (index: number) => {
|
||||
const newList = workingHoursList.filter((_, i) => i !== index);
|
||||
setWorkingHoursList(newList);
|
||||
};
|
||||
|
||||
const updateWorkingHours = (index: number, field: 'start' | 'end', value: dayjs.Dayjs | null) => {
|
||||
const newList = [...workingHoursList];
|
||||
newList[index] = { ...newList[index], [field]: value };
|
||||
setWorkingHoursList(newList);
|
||||
};
|
||||
|
||||
const updateWorkingHoursRange = (index: number, start: dayjs.Dayjs | null, end: dayjs.Dayjs | null) => {
|
||||
setWorkingHoursList(prev => {
|
||||
const newList = [...prev];
|
||||
if (newList[index]) {
|
||||
newList[index] = { start, end };
|
||||
}
|
||||
return newList;
|
||||
});
|
||||
};
|
||||
|
||||
const fetchCameras = async () => {
|
||||
try {
|
||||
const res = await axios.get('/api/cameras?enabled_only=true');
|
||||
@@ -95,16 +129,25 @@ const ROIEditor: React.FC = () => {
|
||||
const handleSaveROI = async (values: any) => {
|
||||
if (!selectedCamera || !selectedROI) return;
|
||||
try {
|
||||
const workingHours = workingHoursList
|
||||
.filter(item => item.start && item.end)
|
||||
.map(item => ({
|
||||
start: [item.start!.hour(), item.start!.minute()],
|
||||
end: [item.end!.hour(), item.end!.minute()],
|
||||
}));
|
||||
|
||||
await axios.put(`/api/camera/${selectedCamera}/roi/${selectedROI.id}`, {
|
||||
name: values.name,
|
||||
roi_type: values.roi_type,
|
||||
rule_type: values.rule_type,
|
||||
threshold_sec: values.threshold_sec,
|
||||
confirm_sec: values.confirm_sec,
|
||||
working_hours: workingHours,
|
||||
enabled: values.enabled,
|
||||
});
|
||||
message.success('保存成功');
|
||||
setDrawerVisible(false);
|
||||
setWorkingHoursList([]);
|
||||
fetchROIs();
|
||||
} catch (err: any) {
|
||||
message.error(`保存失败: ${err.response?.data?.detail || '未知错误'}`);
|
||||
@@ -150,6 +193,7 @@ const ROIEditor: React.FC = () => {
|
||||
threshold_sec: 60,
|
||||
confirm_sec: 5,
|
||||
return_sec: 5,
|
||||
working_hours: [],
|
||||
})
|
||||
.then(() => {
|
||||
message.success('ROI添加成功');
|
||||
@@ -212,6 +256,10 @@ const ROIEditor: React.FC = () => {
|
||||
confirm_sec: roi.confirm_sec,
|
||||
enabled: roi.enabled,
|
||||
});
|
||||
setWorkingHoursList(roi.working_hours?.map((wh: WorkingHours) => ({
|
||||
start: wh.start ? dayjs().hour(wh.start[0]).minute(wh.start[1]) : null,
|
||||
end: wh.end ? dayjs().hour(wh.end[0]).minute(wh.end[1]) : null,
|
||||
})) || []);
|
||||
setDrawerVisible(true);
|
||||
}}
|
||||
onMouseEnter={(e) => {
|
||||
@@ -369,6 +417,10 @@ const ROIEditor: React.FC = () => {
|
||||
confirm_sec: roi.confirm_sec,
|
||||
enabled: roi.enabled,
|
||||
});
|
||||
setWorkingHoursList(roi.working_hours?.map((wh: WorkingHours) => ({
|
||||
start: wh.start ? dayjs().hour(wh.start[0]).minute(wh.start[1]) : null,
|
||||
end: wh.end ? dayjs().hour(wh.end[0]).minute(wh.end[1]) : null,
|
||||
})) || []);
|
||||
setDrawerVisible(true);
|
||||
}}
|
||||
>
|
||||
@@ -403,6 +455,7 @@ const ROIEditor: React.FC = () => {
|
||||
onClose={() => {
|
||||
setDrawerVisible(false);
|
||||
setSelectedROI(null);
|
||||
setWorkingHoursList([]);
|
||||
}}
|
||||
width={400}
|
||||
>
|
||||
@@ -434,6 +487,39 @@ const ROIEditor: React.FC = () => {
|
||||
<Form.Item name="confirm_sec" label="确认时间(秒)" rules={[{ required: true }]}>
|
||||
<InputNumber min={5} style={{ width: '100%' }} />
|
||||
</Form.Item>
|
||||
<Divider>工作时间配置(可选)</Divider>
|
||||
<div>
|
||||
{workingHoursList.map((item, index) => (
|
||||
<Space key={index} align="baseline" style={{ display: 'flex', marginBottom: 8 }}>
|
||||
<Form.Item label={index === 0 ? '时间段' : ''} style={{ marginBottom: 0 }}>
|
||||
<TimePicker.RangePicker
|
||||
format="HH:mm"
|
||||
value={item.start && item.end ? [item.start, item.end] : null}
|
||||
onChange={(dates) => {
|
||||
if (dates && Array.isArray(dates) && dates.length >= 2 && dates[0] && dates[1]) {
|
||||
updateWorkingHoursRange(index, dates[0], dates[1]);
|
||||
} else {
|
||||
updateWorkingHoursRange(index, null, null);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Button
|
||||
type="link"
|
||||
danger
|
||||
onClick={() => removeWorkingHours(index)}
|
||||
>
|
||||
删除
|
||||
</Button>
|
||||
</Space>
|
||||
))}
|
||||
<Button type="dashed" onClick={addWorkingHours} block>
|
||||
添加时间段
|
||||
</Button>
|
||||
</div>
|
||||
<Form.Item style={{ fontSize: 12, color: '#999' }}>
|
||||
不配置工作时间则使用系统全局设置
|
||||
</Form.Item>
|
||||
</>
|
||||
)}
|
||||
<Form.Item name="enabled" label="启用状态" valuePropName="checked">
|
||||
@@ -447,6 +533,7 @@ const ROIEditor: React.FC = () => {
|
||||
<Button onClick={() => {
|
||||
setDrawerVisible(false);
|
||||
setSelectedROI(null);
|
||||
setWorkingHoursList([]);
|
||||
}}>
|
||||
取消
|
||||
</Button>
|
||||
|
||||
@@ -49,6 +49,10 @@ class ONNXEngine:
|
||||
return img
|
||||
|
||||
def postprocess(self, output: np.ndarray, orig_img: np.ndarray) -> List[Results]:
|
||||
import torch
|
||||
import numpy as np
|
||||
from ultralytics.engine.results import Boxes as BoxesObj, Results
|
||||
|
||||
c, n = output.shape
|
||||
output = output.T
|
||||
|
||||
@@ -74,6 +78,9 @@ class ONNXEngine:
|
||||
orig_h, orig_w = orig_img.shape[:2]
|
||||
scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
|
||||
|
||||
if len(indices) == 0:
|
||||
return [Results(orig_img=orig_img, path="", names={0: "person"})]
|
||||
|
||||
filtered_boxes = []
|
||||
for idx in indices:
|
||||
if idx >= len(boxes):
|
||||
@@ -82,30 +89,30 @@ class ONNXEngine:
|
||||
x1, y1, x2, y2 = box
|
||||
w, h = x2 - x1, y2 - y1
|
||||
filtered_boxes.append([
|
||||
int(x1 * scale_x),
|
||||
int(y1 * scale_y),
|
||||
int(w * scale_x),
|
||||
int(h * scale_y),
|
||||
float(x1 * scale_x),
|
||||
float(y1 * scale_y),
|
||||
float(w * scale_x),
|
||||
float(h * scale_y),
|
||||
float(scores[idx]),
|
||||
int(classes[idx])
|
||||
])
|
||||
|
||||
from ultralytics.engine.results import Boxes as BoxesObj
|
||||
if filtered_boxes:
|
||||
box_tensor = torch.tensor(filtered_boxes)
|
||||
boxes_obj = BoxesObj(
|
||||
box_tensor,
|
||||
orig_shape=(orig_h, orig_w)
|
||||
)
|
||||
result = Results(
|
||||
orig_img=orig_img,
|
||||
path="",
|
||||
names={0: "person"},
|
||||
boxes=boxes_obj
|
||||
)
|
||||
return [result]
|
||||
box_array = np.array(filtered_boxes, dtype=np.float32)
|
||||
else:
|
||||
box_array = np.zeros((0, 6), dtype=np.float32)
|
||||
|
||||
return [Results(orig_img=orig_img, path="", names={0: "person"})]
|
||||
boxes_obj = BoxesObj(
|
||||
torch.from_numpy(box_array),
|
||||
orig_shape=(orig_h, orig_w)
|
||||
)
|
||||
result = Results(
|
||||
orig_img=orig_img,
|
||||
path="",
|
||||
names={0: "person"},
|
||||
boxes=boxes_obj
|
||||
)
|
||||
return [result]
|
||||
|
||||
def inference(self, images: List[np.ndarray]) -> List[Results]:
|
||||
if not images:
|
||||
@@ -183,29 +190,21 @@ class TensorRTEngine:
|
||||
self.context = self.engine.create_execution_context()
|
||||
|
||||
self.stream = torch.cuda.Stream(device=self.device)
|
||||
self.batch_size = 1
|
||||
|
||||
for i in range(self.engine.num_io_tensors):
|
||||
name = self.engine.get_tensor_name(i)
|
||||
dtype = self.engine.get_tensor_dtype(name)
|
||||
shape = list(self.engine.get_tensor_shape(name))
|
||||
|
||||
if dtype == trt.float16:
|
||||
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
|
||||
else:
|
||||
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
|
||||
|
||||
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
|
||||
if -1 in shape:
|
||||
shape = [self.batch_size if d == -1 else d for d in shape]
|
||||
if dtype == trt.float16:
|
||||
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
|
||||
else:
|
||||
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
|
||||
self.input_buffer = buffer
|
||||
self.input_name = name
|
||||
else:
|
||||
if -1 in shape:
|
||||
shape = [self.batch_size if d == -1 else d for d in shape]
|
||||
if dtype == trt.float16:
|
||||
buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
|
||||
else:
|
||||
buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
|
||||
self.output_buffers.append(buffer)
|
||||
if self.output_name is None:
|
||||
self.output_name = name
|
||||
@@ -215,8 +214,6 @@ class TensorRTEngine:
|
||||
stream_handle = torch.cuda.current_stream(self.device).cuda_stream
|
||||
self.context.set_optimization_profile_async(0, stream_handle)
|
||||
|
||||
self.batch_size = 1
|
||||
|
||||
def preprocess(self, frame: np.ndarray) -> torch.Tensor:
|
||||
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, self.imgsz)
|
||||
@@ -247,9 +244,6 @@ class TensorRTEngine:
|
||||
self.input_name, input_tensor.contiguous().data_ptr()
|
||||
)
|
||||
|
||||
input_shape = list(input_tensor.shape)
|
||||
self.context.set_input_shape(self.input_name, input_shape)
|
||||
|
||||
torch.cuda.synchronize(self.stream)
|
||||
self.context.execute_async_v3(self.stream.cuda_stream)
|
||||
torch.cuda.synchronize(self.stream)
|
||||
@@ -336,6 +330,10 @@ class Boxes:
|
||||
self.orig_shape = orig_shape
|
||||
self.is_track = is_track
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
return self.data.ndim
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
if self.is_track:
|
||||
@@ -369,35 +367,15 @@ class YOLOEngine:
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
device: int = 0,
|
||||
use_trt: bool = True,
|
||||
use_trt: bool = False,
|
||||
):
|
||||
self.use_trt = False
|
||||
self.onnx_engine = None
|
||||
self.trt_engine = None
|
||||
self.model = None
|
||||
self.device = device
|
||||
config = get_config()
|
||||
|
||||
if use_trt:
|
||||
try:
|
||||
self.trt_engine = TensorRTEngine(device=device)
|
||||
self.trt_engine.warmup()
|
||||
self.use_trt = True
|
||||
print("TensorRT引擎加载成功")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"TensorRT加载失败: {e}")
|
||||
|
||||
try:
|
||||
onnx_path = config.model.onnx_path
|
||||
if os.path.exists(onnx_path):
|
||||
self.onnx_engine = ONNXEngine(device=device)
|
||||
self.onnx_engine.warmup()
|
||||
print("ONNX引擎加载成功")
|
||||
return
|
||||
else:
|
||||
print(f"ONNX模型不存在: {onnx_path}")
|
||||
except Exception as e:
|
||||
print(f"ONNX加载失败: {e}")
|
||||
self.config = config
|
||||
|
||||
try:
|
||||
pt_path = model_path or config.model.pt_model_path
|
||||
@@ -409,26 +387,17 @@ class YOLOEngine:
|
||||
raise FileNotFoundError(f"PT文件无效或不存在: {pt_path}")
|
||||
except Exception as e:
|
||||
print(f"PyTorch加载失败: {e}")
|
||||
raise RuntimeError("所有模型加载方式均失败")
|
||||
raise RuntimeError("无法加载模型")
|
||||
|
||||
def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
|
||||
if self.use_trt and self.trt_engine:
|
||||
if self.model is not None:
|
||||
try:
|
||||
return self.trt_engine.inference_single(frame)
|
||||
return self.model(frame, imgsz=self.config.model.imgsz, conf=self.config.model.conf_threshold, iou=self.config.model.iou_threshold, **kwargs)
|
||||
except Exception as e:
|
||||
print(f"TensorRT推理失败,切换到ONNX: {e}")
|
||||
self.use_trt = False
|
||||
if self.onnx_engine:
|
||||
return self.onnx_engine.inference_single(frame)
|
||||
elif self.model:
|
||||
return self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
|
||||
else:
|
||||
return []
|
||||
elif self.onnx_engine:
|
||||
return self.onnx_engine.inference_single(frame)
|
||||
else:
|
||||
results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
|
||||
return results
|
||||
print(f"PyTorch推理失败: {e}")
|
||||
|
||||
print("警告: 模型不可用,返回空结果")
|
||||
return []
|
||||
|
||||
def __del__(self):
|
||||
if self.trt_engine:
|
||||
|
||||
@@ -7,6 +7,7 @@ from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from config import get_config
|
||||
@@ -186,9 +187,10 @@ class InferencePipeline:
|
||||
roi_id,
|
||||
rule_type,
|
||||
{
|
||||
"threshold_sec": roi_config.get("threshold_sec", 360),
|
||||
"confirm_sec": roi_config.get("confirm_sec", 30),
|
||||
"return_sec": roi_config.get("return_sec", 5),
|
||||
"threshold_sec": roi_config.get("threshold_sec", 300),
|
||||
"confirm_sec": roi_config.get("confirm_sec", 10),
|
||||
"return_sec": roi_config.get("return_sec", 30),
|
||||
"working_hours": roi_config.get("working_hours"),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -216,22 +218,30 @@ class InferencePipeline:
|
||||
else:
|
||||
filtered_detections = detections
|
||||
|
||||
roi_detections: Dict[str, List[Dict]] = {}
|
||||
for detection in filtered_detections:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
for roi_conf in matched_rois:
|
||||
roi_id = roi_conf["roi_id"]
|
||||
rule_type = roi_conf["rule"]
|
||||
if roi_id not in roi_detections:
|
||||
roi_detections[roi_id] = []
|
||||
roi_detections[roi_id].append(detection)
|
||||
|
||||
alerts = self.algo_manager.process(
|
||||
roi_id,
|
||||
str(camera_id),
|
||||
rule_type,
|
||||
[detection],
|
||||
datetime.now(),
|
||||
)
|
||||
for roi_config in roi_configs:
|
||||
roi_id = roi_config["roi_id"]
|
||||
rule_type = roi_config["rule"]
|
||||
roi_dets = roi_detections.get(roi_id, [])
|
||||
|
||||
for alert in alerts:
|
||||
self._handle_alert(camera_id, alert, frame, roi_conf)
|
||||
alerts = self.algo_manager.process(
|
||||
roi_id,
|
||||
str(camera_id),
|
||||
rule_type,
|
||||
roi_dets,
|
||||
datetime.now(),
|
||||
)
|
||||
|
||||
for alert in alerts:
|
||||
self._handle_alert(camera_id, alert, frame, roi_config)
|
||||
|
||||
def _handle_alert(
|
||||
self,
|
||||
@@ -322,20 +332,23 @@ class InferencePipeline:
|
||||
print("推理pipeline已停止")
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
result = {
|
||||
"running": self.running,
|
||||
"camera_count": len(self.camera_threads),
|
||||
"cameras": {
|
||||
cid: {
|
||||
"running": self.camera_stop_events[cid] is not None and not self.camera_stop_events[cid].is_set(),
|
||||
"fps": self.get_camera_fps(cid),
|
||||
"frame_time": self.camera_frame_times.get(cid).isoformat() if self.camera_frame_times.get(cid) else None,
|
||||
}
|
||||
for cid in self.camera_threads
|
||||
},
|
||||
"cameras": {},
|
||||
"event_queue_size": len(self.event_queue),
|
||||
}
|
||||
|
||||
for cid in self.camera_threads:
|
||||
frame_time = self.camera_frame_times.get(cid)
|
||||
result["cameras"][str(cid)] = {
|
||||
"is_running": self.camera_stop_events[cid] is not None and not self.camera_stop_events[cid].is_set(),
|
||||
"fps": self.get_camera_fps(cid),
|
||||
"last_check_time": frame_time.isoformat() if frame_time else None,
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
_pipeline: Optional[InferencePipeline] = None
|
||||
|
||||
|
||||
167
inference/roi/cache_manager.py
Normal file
167
inference/roi/cache_manager.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, List, Optional, Callable
|
||||
from collections import deque
|
||||
|
||||
|
||||
class ROICacheManager:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
self._cache: Dict[int, List[Dict]] = {}
|
||||
self._cache_timestamps: Dict[int, float] = {}
|
||||
self._refresh_interval = 10.0
|
||||
self._db_session_factory = None
|
||||
self._refresh_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
self._last_refresh_time = 0
|
||||
self._on_cache_update: Optional[Callable[[int], None]] = None
|
||||
self._update_callbacks: Dict[int, List[Callable]] = {}
|
||||
|
||||
def initialize(self, session_factory, refresh_interval: float = 10.0):
|
||||
self._db_session_factory = session_factory
|
||||
self._refresh_interval = refresh_interval
|
||||
|
||||
def start_background_refresh(self):
|
||||
if self._refresh_thread is not None and self._refresh_thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._refresh_thread = threading.Thread(target=self._background_refresh_loop, daemon=True)
|
||||
self._refresh_thread.start()
|
||||
|
||||
def stop_background_refresh(self):
|
||||
self._stop_event.set()
|
||||
if self._refresh_thread is not None:
|
||||
self._refresh_thread.join(timeout=2)
|
||||
self._refresh_thread = None
|
||||
|
||||
def _background_refresh_loop(self):
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
self.refresh_all()
|
||||
except Exception:
|
||||
pass
|
||||
self._stop_event.wait(self._refresh_interval)
|
||||
|
||||
def _load_rois_from_db(self, camera_id: int) -> List[Dict]:
|
||||
if self._db_session_factory is None:
|
||||
return []
|
||||
|
||||
session = self._db_session_factory()
|
||||
try:
|
||||
from db.crud import get_all_rois
|
||||
rois = get_all_rois(session, camera_id)
|
||||
roi_configs = []
|
||||
for roi in rois:
|
||||
try:
|
||||
points = json.loads(roi.points) if isinstance(roi.points, str) else roi.points
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
points = []
|
||||
|
||||
roi_config = {
|
||||
"id": roi.id,
|
||||
"roi_id": roi.roi_id,
|
||||
"name": roi.name,
|
||||
"type": roi.roi_type,
|
||||
"points": points,
|
||||
"rule": roi.rule_type,
|
||||
"direction": roi.direction,
|
||||
"enabled": roi.enabled,
|
||||
"threshold_sec": roi.threshold_sec,
|
||||
"confirm_sec": roi.confirm_sec,
|
||||
"return_sec": roi.return_sec,
|
||||
"working_hours": json.loads(roi.working_hours) if roi.working_hours else None,
|
||||
}
|
||||
roi_configs.append(roi_config)
|
||||
return roi_configs
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def refresh_all(self):
|
||||
if self._db_session_factory is None:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
if current_time - self._last_refresh_time < 1.0:
|
||||
return
|
||||
|
||||
self._last_refresh_time = current_time
|
||||
camera_ids = list(self._cache.keys())
|
||||
|
||||
for camera_id in camera_ids:
|
||||
try:
|
||||
new_rois = self._load_rois_from_db(camera_id)
|
||||
old_rois_str = str(self._cache.get(camera_id, []))
|
||||
new_rois_str = str(new_rois)
|
||||
|
||||
if old_rois_str != new_rois_str:
|
||||
self._cache[camera_id] = new_rois
|
||||
self._cache_timestamps[camera_id] = current_time
|
||||
self._notify_update(camera_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_rois(self, camera_id: int, force_refresh: bool = False) -> List[Dict]:
|
||||
if force_refresh or camera_id not in self._cache:
|
||||
self._cache[camera_id] = self._load_rois_from_db(camera_id)
|
||||
self._cache_timestamps[camera_id] = time.time()
|
||||
|
||||
return self._cache.get(camera_id, [])
|
||||
|
||||
def get_rois_by_rule(self, camera_id: int, rule_type: str) -> List[Dict]:
|
||||
rois = self.get_rois(camera_id)
|
||||
return [roi for roi in rois if roi.get("rule") == rule_type and roi.get("enabled", True)]
|
||||
|
||||
def invalidate(self, camera_id: Optional[int] = None):
|
||||
if camera_id is None:
|
||||
self._cache.clear()
|
||||
self._cache_timestamps.clear()
|
||||
elif camera_id in self._cache:
|
||||
del self._cache[camera_id]
|
||||
if camera_id in self._cache_timestamps:
|
||||
del self._cache_timestamps[camera_id]
|
||||
|
||||
def register_update_callback(self, camera_id: int, callback: Callable):
|
||||
if camera_id not in self._update_callbacks:
|
||||
self._update_callbacks[camera_id] = []
|
||||
self._update_callbacks[camera_id].append(callback)
|
||||
|
||||
def _notify_update(self, camera_id: int):
|
||||
if camera_id in self._update_callbacks:
|
||||
for callback in self._update_callbacks[camera_id]:
|
||||
try:
|
||||
callback(camera_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_cache_info(self) -> Dict:
|
||||
return {
|
||||
"camera_count": len(self._cache),
|
||||
"refresh_interval": self._refresh_interval,
|
||||
"cameras": {
|
||||
cam_id: {
|
||||
"roi_count": len(rois),
|
||||
"last_update": self._cache_timestamps.get(cam_id, 0),
|
||||
}
|
||||
for cam_id, rois in self._cache.items()
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_roi_cache() -> ROICacheManager:
|
||||
return ROICacheManager()
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -8,15 +10,18 @@ import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sort import Sort
|
||||
|
||||
|
||||
class LeavePostAlgorithm:
|
||||
STATE_ON_DUTY = "ON_DUTY"
|
||||
STATE_OFF_DUTY_COUNTDOWN = "OFF_DUTY_COUNTDOWN"
|
||||
STATE_NON_WORK_TIME = "NON_WORK_TIME"
|
||||
STATE_INIT = "INIT"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold_sec: int = 360,
|
||||
confirm_sec: int = 30,
|
||||
return_sec: int = 5,
|
||||
threshold_sec: int = 300,
|
||||
confirm_sec: int = 10,
|
||||
return_sec: int = 30,
|
||||
working_hours: Optional[List[Dict]] = None,
|
||||
):
|
||||
self.threshold_sec = threshold_sec
|
||||
@@ -24,12 +29,17 @@ class LeavePostAlgorithm:
|
||||
self.return_sec = return_sec
|
||||
self.working_hours = working_hours or []
|
||||
|
||||
self.track_states: Dict[str, Dict[str, Any]] = {}
|
||||
self.tracker = Sort(max_age=10, min_hits=2, iou_threshold=0.3)
|
||||
|
||||
self.alert_cooldowns: Dict[str, datetime] = {}
|
||||
self.cooldown_seconds = 300
|
||||
|
||||
self.state: str = self.STATE_INIT
|
||||
self.state_start_time: Optional[datetime] = None
|
||||
self.on_duty_window = deque()
|
||||
self.alarm_sent: bool = False
|
||||
self.last_person_seen_time: Optional[datetime] = None
|
||||
self.last_detection_time: Optional[datetime] = None
|
||||
self.init_start_time: Optional[datetime] = None
|
||||
|
||||
def is_in_working_hours(self, dt: Optional[datetime] = None) -> bool:
|
||||
if not self.working_hours:
|
||||
return True
|
||||
@@ -45,159 +55,199 @@ class LeavePostAlgorithm:
|
||||
|
||||
return False
|
||||
|
||||
def check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
for roi in matched_rois:
|
||||
if roi.get("roi_id") == roi_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
camera_id: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
if not self.is_in_working_hours(current_time):
|
||||
return []
|
||||
|
||||
if not tracks:
|
||||
return []
|
||||
|
||||
detections = []
|
||||
for track in tracks:
|
||||
bbox = track.get("bbox", [])
|
||||
if len(bbox) >= 4:
|
||||
detections.append(bbox + [track.get("conf", 0.0)])
|
||||
|
||||
if not detections:
|
||||
return []
|
||||
|
||||
detections = np.array(detections)
|
||||
tracked = self.tracker.update(detections)
|
||||
|
||||
alerts = []
|
||||
current_time = current_time or datetime.now()
|
||||
|
||||
for track_data in tracked:
|
||||
x1, y1, x2, y2, track_id = track_data
|
||||
track_id = str(int(track_id))
|
||||
roi_has_person = False
|
||||
for det in tracks:
|
||||
if self.check_detection_in_roi(det, roi_id):
|
||||
roi_has_person = True
|
||||
break
|
||||
|
||||
if track_id not in self.track_states:
|
||||
self.track_states[track_id] = {
|
||||
"first_seen": current_time,
|
||||
"last_seen": current_time,
|
||||
"off_duty_start": None,
|
||||
"alerted": False,
|
||||
"last_position": (x1, y1, x2, y2),
|
||||
}
|
||||
in_work = self.is_in_working_hours(current_time)
|
||||
alerts = []
|
||||
|
||||
state = self.track_states[track_id]
|
||||
state["last_seen"] = current_time
|
||||
state["last_position"] = (x1, y1, x2, y2)
|
||||
if not in_work:
|
||||
self.state = self.STATE_NON_WORK_TIME
|
||||
self.last_person_seen_time = None
|
||||
self.last_detection_time = None
|
||||
self.on_duty_window.clear()
|
||||
self.alarm_sent = False
|
||||
self.init_start_time = None
|
||||
else:
|
||||
if self.state == self.STATE_NON_WORK_TIME:
|
||||
self.state = self.STATE_INIT
|
||||
self.init_start_time = current_time
|
||||
self.on_duty_window.clear()
|
||||
self.alarm_sent = False
|
||||
|
||||
if state["off_duty_start"] is None:
|
||||
off_duty_duration = (current_time - state["first_seen"]).total_seconds()
|
||||
if off_duty_duration > self.confirm_sec:
|
||||
state["off_duty_start"] = current_time
|
||||
else:
|
||||
elapsed = (current_time - state["off_duty_start"]).total_seconds()
|
||||
if elapsed > self.threshold_sec:
|
||||
if not state["alerted"]:
|
||||
cooldown_key = f"{camera_id}_{track_id}"
|
||||
now = datetime.now()
|
||||
if self.state == self.STATE_INIT:
|
||||
if roi_has_person:
|
||||
self.state = self.STATE_ON_DUTY
|
||||
self.state_start_time = current_time
|
||||
self.on_duty_window.clear()
|
||||
self.on_duty_window.append((current_time, True))
|
||||
self.last_person_seen_time = current_time
|
||||
self.last_detection_time = current_time
|
||||
self.init_start_time = None
|
||||
else:
|
||||
if self.init_start_time is None:
|
||||
self.init_start_time = current_time
|
||||
|
||||
elapsed_since_init = (current_time - self.init_start_time).total_seconds()
|
||||
if elapsed_since_init >= self.threshold_sec:
|
||||
self.state = self.STATE_OFF_DUTY_COUNTDOWN
|
||||
self.state_start_time = current_time
|
||||
self.alarm_sent = False
|
||||
|
||||
elif self.state == self.STATE_ON_DUTY:
|
||||
if roi_has_person:
|
||||
self.last_person_seen_time = current_time
|
||||
self.last_detection_time = current_time
|
||||
|
||||
self.on_duty_window.append((current_time, True))
|
||||
while self.on_duty_window and (current_time - self.on_duty_window[0][0]).total_seconds() > self.confirm_sec:
|
||||
self.on_duty_window.popleft()
|
||||
else:
|
||||
self.on_duty_window.append((current_time, False))
|
||||
while self.on_duty_window and (current_time - self.on_duty_window[0][0]).total_seconds() > self.confirm_sec:
|
||||
self.on_duty_window.popleft()
|
||||
|
||||
hit_ratio = sum(1 for t, detected in self.on_duty_window if detected) / max(len(self.on_duty_window), 1)
|
||||
|
||||
if hit_ratio == 0:
|
||||
self.state = self.STATE_OFF_DUTY_COUNTDOWN
|
||||
self.state_start_time = current_time
|
||||
self.alarm_sent = False
|
||||
|
||||
elif self.state == self.STATE_OFF_DUTY_COUNTDOWN:
|
||||
elapsed = (current_time - self.state_start_time).total_seconds()
|
||||
|
||||
if roi_has_person:
|
||||
self.state = self.STATE_ON_DUTY
|
||||
self.state_start_time = current_time
|
||||
self.on_duty_window.clear()
|
||||
self.on_duty_window.append((current_time, True))
|
||||
self.last_person_seen_time = current_time
|
||||
self.alarm_sent = False
|
||||
elif elapsed >= self.threshold_sec:
|
||||
if not self.alarm_sent:
|
||||
cooldown_key = f"{roi_id}"
|
||||
if cooldown_key not in self.alert_cooldowns or (
|
||||
now - self.alert_cooldowns[cooldown_key]
|
||||
current_time - self.alert_cooldowns[cooldown_key]
|
||||
).total_seconds() > self.cooldown_seconds:
|
||||
bbox = self.get_latest_bbox_in_roi(tracks, roi_id)
|
||||
elapsed_minutes = int(elapsed / 60)
|
||||
alerts.append({
|
||||
"track_id": track_id,
|
||||
"bbox": [x1, y1, x2, y2],
|
||||
"track_id": roi_id,
|
||||
"bbox": bbox,
|
||||
"off_duty_duration": elapsed,
|
||||
"alert_type": "leave_post",
|
||||
"message": f"离岗超过 {int(elapsed / 60)} 分钟",
|
||||
"message": f"离岗超过 {elapsed_minutes} 分钟",
|
||||
})
|
||||
state["alerted"] = True
|
||||
self.alert_cooldowns[cooldown_key] = now
|
||||
else:
|
||||
if elapsed < self.return_sec:
|
||||
state["off_duty_start"] = None
|
||||
state["alerted"] = False
|
||||
|
||||
cleanup_time = current_time - timedelta(minutes=5)
|
||||
for track_id, state in list(self.track_states.items()):
|
||||
if state["last_seen"] < cleanup_time:
|
||||
del self.track_states[track_id]
|
||||
self.alarm_sent = True
|
||||
self.alert_cooldowns[cooldown_key] = current_time
|
||||
|
||||
return alerts
|
||||
|
||||
def get_latest_bbox_in_roi(self, tracks: List[Dict], roi_id: str) -> List[float]:
|
||||
for det in tracks:
|
||||
if self.check_detection_in_roi(det, roi_id):
|
||||
return det.get("bbox", [])
|
||||
return []
|
||||
|
||||
def reset(self):
|
||||
self.track_states.clear()
|
||||
self.state = self.STATE_INIT
|
||||
self.state_start_time = None
|
||||
self.on_duty_window.clear()
|
||||
self.alarm_sent = False
|
||||
self.last_person_seen_time = None
|
||||
self.last_detection_time = None
|
||||
self.init_start_time = None
|
||||
self.alert_cooldowns.clear()
|
||||
|
||||
def get_state(self, track_id: str) -> Optional[Dict[str, Any]]:
|
||||
return self.track_states.get(track_id)
|
||||
return {
|
||||
"state": self.state,
|
||||
"alarm_sent": self.alarm_sent,
|
||||
"last_person_seen_time": self.last_person_seen_time,
|
||||
}
|
||||
|
||||
|
||||
class IntrusionAlgorithm:
|
||||
def __init__(
|
||||
self,
|
||||
check_interval_sec: float = 1.0,
|
||||
direction_sensitive: bool = False,
|
||||
):
|
||||
self.check_interval_sec = check_interval_sec
|
||||
self.direction_sensitive = direction_sensitive
|
||||
def __init__(self, cooldown_seconds: int = 300):
|
||||
self.cooldown_seconds = cooldown_seconds
|
||||
self.last_alert_time: Dict[str, float] = {}
|
||||
self.alert_triggered: Dict[str, bool] = {}
|
||||
|
||||
self.last_check_times: Dict[str, float] = {}
|
||||
self.tracker = Sort(max_age=5, min_hits=1, iou_threshold=0.3)
|
||||
def is_roi_has_person(self, tracks: List[Dict], roi_id: str) -> bool:
|
||||
for det in tracks:
|
||||
matched_rois = det.get("matched_rois", [])
|
||||
for roi in matched_rois:
|
||||
if roi.get("roi_id") == roi_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
self.alert_cooldowns: Dict[str, datetime] = {}
|
||||
self.cooldown_seconds = 300
|
||||
def get_latest_bbox_in_roi(self, tracks: List[Dict], roi_id: str) -> List[float]:
|
||||
for det in tracks:
|
||||
matched_rois = det.get("matched_rois", [])
|
||||
for roi in matched_rois:
|
||||
if roi.get("roi_id") == roi_id:
|
||||
return det.get("bbox", [])
|
||||
return []
|
||||
|
||||
def process(
|
||||
self,
|
||||
roi_id: str,
|
||||
camera_id: str,
|
||||
tracks: List[Dict],
|
||||
current_time: Optional[datetime] = None,
|
||||
) -> List[Dict]:
|
||||
if not tracks:
|
||||
roi_has_person = self.is_roi_has_person(tracks, roi_id)
|
||||
|
||||
if not roi_has_person:
|
||||
return []
|
||||
|
||||
detections = []
|
||||
for track in tracks:
|
||||
bbox = track.get("bbox", [])
|
||||
if len(bbox) >= 4:
|
||||
detections.append(bbox + [track.get("conf", 0.0)])
|
||||
now = time.monotonic()
|
||||
key = f"{camera_id}_{roi_id}"
|
||||
|
||||
if not detections:
|
||||
if key not in self.last_alert_time:
|
||||
self.last_alert_time[key] = 0
|
||||
self.alert_triggered[key] = False
|
||||
|
||||
if now - self.last_alert_time[key] >= self.cooldown_seconds:
|
||||
self.last_alert_time[key] = now
|
||||
self.alert_triggered[key] = False
|
||||
|
||||
if self.alert_triggered[key]:
|
||||
return []
|
||||
|
||||
current_ts = current_time.timestamp() if current_time else datetime.now().timestamp()
|
||||
bbox = self.get_latest_bbox_in_roi(tracks, roi_id)
|
||||
self.alert_triggered[key] = True
|
||||
|
||||
if camera_id in self.last_check_times:
|
||||
if current_ts - self.last_check_times[camera_id] < self.check_interval_sec:
|
||||
return []
|
||||
self.last_check_times[camera_id] = current_ts
|
||||
|
||||
detections = np.array(detections)
|
||||
tracked = self.tracker.update(detections)
|
||||
|
||||
alerts = []
|
||||
now = datetime.now()
|
||||
|
||||
for track_data in tracked:
|
||||
x1, y1, x2, y2, track_id = track_data
|
||||
cooldown_key = f"{camera_id}_{int(track_id)}"
|
||||
|
||||
if cooldown_key not in self.alert_cooldowns or (
|
||||
now - self.alert_cooldowns[cooldown_key]
|
||||
).total_seconds() > self.cooldown_seconds:
|
||||
alerts.append({
|
||||
"track_id": str(int(track_id)),
|
||||
"bbox": [x1, y1, x2, y2],
|
||||
"alert_type": "intrusion",
|
||||
"confidence": track_data[4] if len(track_data) > 4 else 0.0,
|
||||
"message": "检测到周界入侵",
|
||||
})
|
||||
self.alert_cooldowns[cooldown_key] = now
|
||||
|
||||
return alerts
|
||||
return [{
|
||||
"roi_id": roi_id,
|
||||
"bbox": bbox,
|
||||
"alert_type": "intrusion",
|
||||
"message": "检测到周界入侵",
|
||||
}]
|
||||
|
||||
def reset(self):
|
||||
self.last_check_times.clear()
|
||||
self.alert_cooldowns.clear()
|
||||
self.last_alert_time.clear()
|
||||
self.alert_triggered.clear()
|
||||
|
||||
|
||||
class AlgorithmManager:
|
||||
@@ -207,13 +257,12 @@ class AlgorithmManager:
|
||||
|
||||
self.default_params = {
|
||||
"leave_post": {
|
||||
"threshold_sec": 360,
|
||||
"confirm_sec": 30,
|
||||
"return_sec": 5,
|
||||
"threshold_sec": 300,
|
||||
"confirm_sec": 10,
|
||||
"return_sec": 30,
|
||||
},
|
||||
"intrusion": {
|
||||
"check_interval_sec": 1.0,
|
||||
"direction_sensitive": False,
|
||||
"cooldown_seconds": 300,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -235,16 +284,16 @@ class AlgorithmManager:
|
||||
algo_params.update(params)
|
||||
|
||||
if algorithm_type == "leave_post":
|
||||
roi_working_hours = algo_params.get("working_hours") or self.working_hours
|
||||
self.algorithms[roi_id]["leave_post"] = LeavePostAlgorithm(
|
||||
threshold_sec=algo_params.get("threshold_sec", 360),
|
||||
confirm_sec=algo_params.get("confirm_sec", 30),
|
||||
return_sec=algo_params.get("return_sec", 5),
|
||||
working_hours=self.working_hours,
|
||||
threshold_sec=algo_params.get("threshold_sec", 300),
|
||||
confirm_sec=algo_params.get("confirm_sec", 10),
|
||||
return_sec=algo_params.get("return_sec", 30),
|
||||
working_hours=roi_working_hours,
|
||||
)
|
||||
elif algorithm_type == "intrusion":
|
||||
self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm(
|
||||
check_interval_sec=algo_params.get("check_interval_sec", 1.0),
|
||||
direction_sensitive=algo_params.get("direction_sensitive", False),
|
||||
cooldown_seconds=algo_params.get("cooldown_seconds", 300),
|
||||
)
|
||||
|
||||
def process(
|
||||
@@ -258,7 +307,7 @@ class AlgorithmManager:
|
||||
algo = self.algorithms.get(roi_id, {}).get(algorithm_type)
|
||||
if algo is None:
|
||||
return []
|
||||
return algo.process(camera_id, tracks, current_time)
|
||||
return algo.process(roi_id, camera_id, tracks, current_time)
|
||||
|
||||
def update_roi_params(
|
||||
self,
|
||||
@@ -297,7 +346,13 @@ class AlgorithmManager:
|
||||
status = {}
|
||||
if roi_id in self.algorithms:
|
||||
for algo_type, algo in self.algorithms[roi_id].items():
|
||||
status[algo_type] = {
|
||||
"track_count": len(getattr(algo, "track_states", {})),
|
||||
}
|
||||
if algo_type == "leave_post":
|
||||
status[algo_type] = {
|
||||
"state": getattr(algo, "state", "INIT_STATE"),
|
||||
"alarm_sent": getattr(algo, "alarm_sent", False),
|
||||
}
|
||||
else:
|
||||
status[algo_type] = {
|
||||
"track_count": len(getattr(algo, "track_states", {})),
|
||||
}
|
||||
return status
|
||||
|
||||
106
logs/app.log
106
logs/app.log
@@ -133,3 +133,109 @@
|
||||
2026-01-21 13:18:55,795 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 13:18:55,809 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 13:19:08,492 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:01:21,015 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:01:21,257 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:03:48,547 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:03:48,563 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:04:01,197 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:04:20,191 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:04:20,414 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:05:48,342 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:05:48,355 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:06:00,984 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:07:24,065 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:07:24,222 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:08:10,073 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:08:10,088 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:08:22,715 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:09:05,249 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:09:05,480 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:11:29,491 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:11:29,513 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:11:42,900 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:14:04,974 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:14:05,161 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:14:41,203 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:14:41,220 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:14:54,380 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:15:30,975 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:15:31,180 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:16:24,472 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:16:24,485 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:16:37,611 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:17:01,178 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:17:01,420 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:18:00,008 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:18:00,022 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:18:13,126 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:18:13,128 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:18:21,683 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:20:04,985 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:20:04,999 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:20:18,151 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:21:24,782 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:21:24,927 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:22:48,064 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:22:48,078 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:23:01,270 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:23:13,509 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:23:13,628 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:24:16,374 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:24:16,386 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:24:29,425 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:24:42,751 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:24:42,846 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:25:25,549 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:25:25,562 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:25:38,636 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:26:02,871 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:26:03,124 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:26:45,885 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:26:45,899 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:26:59,042 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:27:26,873 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:27:26,980 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:31:38,376 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:31:38,390 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:31:51,594 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:32:17,471 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:32:17,536 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:32:53,841 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:32:53,855 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:33:06,946 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:34:30,645 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:34:30,818 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:38:24,673 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:38:24,685 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:38:37,183 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:39:04,359 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:39:04,486 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:40:07,246 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:40:07,259 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:40:19,745 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 14:40:33,742 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 14:40:33,863 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 14:41:27,191 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 14:41:27,205 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 14:41:39,701 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 15:03:14,674 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 15:03:14,688 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 15:03:27,230 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 15:06:28,976 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 15:06:28,990 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 15:06:41,537 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 15:06:41,539 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 15:06:49,686 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 15:07:27,870 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 15:07:27,884 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 15:07:40,380 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 15:07:58,160 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 15:07:58,299 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 15:08:28,521 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 15:08:28,533 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 15:08:41,019 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-21 15:09:16,894 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-21 15:09:17,139 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-21 15:09:41,042 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-21 15:09:41,055 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-21 15:09:53,555 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
|
||||
35
main.py
35
main.py
@@ -11,15 +11,28 @@ os.environ["TENSORRT_DISABLE_MYELIN"] = "1"
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from db.models import get_db
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from ultralytics.engine.results import Boxes as UltralyticsBoxes
|
||||
|
||||
def _patch_boxes_ndim():
|
||||
if not hasattr(UltralyticsBoxes, 'ndim'):
|
||||
@property
|
||||
def ndim(self):
|
||||
return self.data.ndim
|
||||
UltralyticsBoxes.ndim = ndim
|
||||
_patch_boxes_ndim()
|
||||
|
||||
from api.alarm import router as alarm_router
|
||||
from api.camera import router as camera_router
|
||||
from api.camera import router as camera_router, router2 as camera_status_router
|
||||
from api.roi import router as roi_router
|
||||
from api.sync import router as sync_router
|
||||
from config import get_config, load_config
|
||||
@@ -82,6 +95,7 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
app.include_router(camera_router)
|
||||
app.include_router(camera_status_router)
|
||||
app.include_router(roi_router)
|
||||
app.include_router(alarm_router)
|
||||
app.include_router(sync_router)
|
||||
@@ -134,6 +148,23 @@ async def get_snapshot_base64(camera_id: int):
|
||||
return {"image": base64.b64encode(buffer).decode("utf-8")}
|
||||
|
||||
|
||||
@app.get("/api/alarms/{alarm_id}/snapshot")
|
||||
async def get_alarm_snapshot(alarm_id: int, db: Session = Depends(get_db)):
|
||||
from db.models import Alarm
|
||||
|
||||
alarm = db.query(Alarm).filter(Alarm.id == alarm_id).first()
|
||||
if not alarm:
|
||||
raise HTTPException(status_code=404, detail="告警不存在")
|
||||
|
||||
if not alarm.snapshot_path:
|
||||
raise HTTPException(status_code=404, detail="该告警没有截图")
|
||||
|
||||
if not os.path.exists(alarm.snapshot_path):
|
||||
raise HTTPException(status_code=404, detail="截图文件不存在")
|
||||
|
||||
return FileResponse(alarm.snapshot_path, media_type="image/jpeg")
|
||||
|
||||
|
||||
@app.get("/api/camera/{camera_id}/detect")
|
||||
async def get_detection_with_overlay(camera_id: int):
|
||||
pipeline = get_pipeline()
|
||||
|
||||
71
migrate_db.py
Normal file
71
migrate_db.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import sqlite3
|
||||
|
||||
db_path = 'security_monitor.db'
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
def add_column(table_name, col_name, col_type, default_value=None):
|
||||
try:
|
||||
if default_value:
|
||||
cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {col_name} {col_type} DEFAULT {default_value}')
|
||||
else:
|
||||
cursor.execute(f'ALTER TABLE {table_name} ADD COLUMN {col_name} {col_type}')
|
||||
print(f'添加列 {table_name}.{col_name} 成功')
|
||||
return True
|
||||
except sqlite3.OperationalError as e:
|
||||
if 'duplicate column name' in str(e):
|
||||
print(f'列 {table_name}.{col_name} 已存在')
|
||||
return True
|
||||
else:
|
||||
print(f'添加列 {table_name}.{col_name} 失败: {e}')
|
||||
return False
|
||||
|
||||
print('=== 数据库迁移脚本 ===')
|
||||
print()
|
||||
|
||||
# cameras 表
|
||||
print('更新 cameras 表:')
|
||||
add_column('cameras', 'cloud_id', 'INTEGER')
|
||||
add_column('cameras', 'pending_sync', 'BOOLEAN', '0')
|
||||
add_column('cameras', 'sync_failed_at', 'TIMESTAMP')
|
||||
add_column('cameras', 'sync_retry_count', 'INTEGER', '0')
|
||||
print()
|
||||
|
||||
# rois 表
|
||||
print('更新 rois 表:')
|
||||
add_column('rois', 'cloud_id', 'INTEGER')
|
||||
add_column('rois', 'pending_sync', 'BOOLEAN', '0')
|
||||
add_column('rois', 'sync_failed_at', 'TIMESTAMP')
|
||||
add_column('rois', 'sync_retry_count', 'INTEGER', '0')
|
||||
add_column('rois', 'sync_version', 'INTEGER', '0')
|
||||
print()
|
||||
|
||||
# alarms 表
|
||||
print('更新 alarms 表:')
|
||||
add_column('alarms', 'cloud_id', 'INTEGER')
|
||||
add_column('alarms', 'upload_status', "TEXT", "'pending_upload'")
|
||||
add_column('alarms', 'upload_retry_count', 'INTEGER', '0')
|
||||
add_column('alarms', 'error_message', 'TEXT')
|
||||
add_column('alarms', 'region_data', 'TEXT')
|
||||
add_column('alarms', 'llm_checked', 'BOOLEAN', '0')
|
||||
add_column('alarms', 'llm_result', 'TEXT')
|
||||
add_column('alarms', 'processed', 'BOOLEAN', '0')
|
||||
print()
|
||||
|
||||
# camera_status 表
|
||||
print('更新 camera_status 表:')
|
||||
add_column('camera_status', 'last_frame_time', 'TIMESTAMP')
|
||||
print()
|
||||
|
||||
conn.commit()
|
||||
|
||||
# 验证表结构
|
||||
print('=== 验证表结构 ===')
|
||||
for table in ['cameras', 'rois', 'alarms', 'camera_status']:
|
||||
cursor.execute(f'PRAGMA table_info({table})')
|
||||
cols = [col[1] for col in cursor.fetchall()]
|
||||
print(f'{table}: {len(cols)} 列')
|
||||
|
||||
conn.close()
|
||||
print()
|
||||
print('数据库迁移完成!')
|
||||
Binary file not shown.
1137
monitor.py
1137
monitor.py
File diff suppressed because it is too large
Load Diff
Binary file not shown.
461
services/sync_service.py
Normal file
461
services/sync_service.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
云端同步服务
|
||||
|
||||
实现"云端为主、本地为辅"的双层数据存储架构:
|
||||
- 配置双向同步
|
||||
- 报警单向上报
|
||||
- 设备状态上报
|
||||
- 断网容错机制
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from queue import Queue, Empty
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncStatus(Enum):
|
||||
"""同步状态"""
|
||||
PENDING = "pending"
|
||||
SYNCING = "syncing"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class EntityType(Enum):
|
||||
"""实体类型"""
|
||||
CAMERA = "camera"
|
||||
ROI = "roi"
|
||||
ALARM = "alarm"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncTask:
|
||||
"""同步任务"""
|
||||
entity_type: EntityType
|
||||
entity_id: int
|
||||
operation: str # create, update, delete
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
status: SyncStatus = SyncStatus.PENDING
|
||||
retry_count: int = 0
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class CloudAPIClient:
|
||||
"""云端 API 客户端"""
|
||||
|
||||
def __init__(self, base_url: str, api_key: str, device_id: str):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.device_id = device_id
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
'X-Device-ID': device_id
|
||||
})
|
||||
|
||||
def request(self, method: str, path: str, **kwargs) -> requests.Response:
|
||||
"""发送 API 请求"""
|
||||
url = f"{self.base_url}{path}"
|
||||
response = self.session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def get(self, path: str, **kwargs):
|
||||
return self.request('GET', path, **kwargs)
|
||||
|
||||
def post(self, path: str, **kwargs):
|
||||
return self.request('POST', path, **kwargs)
|
||||
|
||||
def put(self, path: str, **kwargs):
|
||||
return self.request('PUT', path, **kwargs)
|
||||
|
||||
def delete(self, path: str, **kwargs):
|
||||
return self.request('DELETE', path, **kwargs)
|
||||
|
||||
|
||||
class SyncService:
|
||||
"""云端同步服务"""
|
||||
|
||||
def __init__(self):
|
||||
config = get_config()
|
||||
self.config = config
|
||||
|
||||
# 云端配置
|
||||
self.cloud_enabled = config.cloud.enabled
|
||||
self.cloud_url = config.cloud.api_url
|
||||
self.api_key = config.cloud.api_key
|
||||
self.device_id = config.cloud.device_id
|
||||
|
||||
# 同步配置
|
||||
self.sync_interval = config.cloud.sync_interval
|
||||
self.alarm_retry_interval = config.cloud.alarm_retry_interval
|
||||
self.status_report_interval = config.cloud.status_report_interval
|
||||
self.max_retries = config.cloud.max_retries
|
||||
|
||||
# 客户端
|
||||
self.client: Optional[CloudAPIClient] = None
|
||||
if self.cloud_enabled:
|
||||
self.client = CloudAPIClient(
|
||||
self.cloud_url,
|
||||
self.api_key,
|
||||
self.device_id
|
||||
)
|
||||
|
||||
# 任务队列
|
||||
self.sync_queue: Queue = Queue()
|
||||
self.alarm_queue: Queue = Queue()
|
||||
|
||||
# 状态
|
||||
self.running = False
|
||||
self.threads: List[threading.Thread] = []
|
||||
self.network_status = "disconnected"
|
||||
|
||||
# 重试配置
|
||||
self.retry_delays = [60, 300, 900, 3600] # 1分钟, 5分钟, 15分钟, 1小时
|
||||
|
||||
def start(self):
|
||||
"""启动同步服务"""
|
||||
if self.running:
|
||||
logger.warning("同步服务已在运行")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
|
||||
if self.cloud_enabled:
|
||||
logger.info(f"启动云端同步服务,设备ID: {self.device_id}")
|
||||
else:
|
||||
logger.info("云端同步已禁用,使用本地模式")
|
||||
|
||||
# 启动工作线程
|
||||
self.threads.append(threading.Thread(target=self._sync_worker, daemon=True))
|
||||
self.threads.append(threading.Thread(target=self._alarm_worker, daemon=True))
|
||||
self.threads.append(threading.Thread(target=self._status_worker, daemon=True))
|
||||
|
||||
for thread in self.threads:
|
||||
thread.start()
|
||||
|
||||
logger.info("同步服务已启动")
|
||||
|
||||
def stop(self):
|
||||
"""停止同步服务"""
|
||||
self.running = False
|
||||
|
||||
for thread in self.threads:
|
||||
if thread.is_alive():
|
||||
thread.join(timeout=5)
|
||||
|
||||
logger.info("同步服务已停止")
|
||||
|
||||
def _sync_worker(self):
|
||||
"""配置同步工作线程"""
|
||||
while self.running:
|
||||
try:
|
||||
task = self.sync_queue.get(timeout=1)
|
||||
self._execute_sync(task)
|
||||
except Empty:
|
||||
self._check_network_status()
|
||||
except Exception as e:
|
||||
logger.error(f"同步工作线程异常: {e}")
|
||||
|
||||
def _alarm_worker(self):
|
||||
"""报警上报工作线程"""
|
||||
while self.running:
|
||||
try:
|
||||
alarm_id = self.alarm_queue.get(timeout=1)
|
||||
self._upload_alarm(alarm_id)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"报警上报工作线程异常: {e}")
|
||||
|
||||
def _status_worker(self):
|
||||
"""状态上报工作线程"""
|
||||
while self.running:
|
||||
try:
|
||||
if self.network_status == "connected":
|
||||
self._report_status()
|
||||
except Exception as e:
|
||||
logger.error(f"状态上报失败: {e}")
|
||||
time.sleep(self.status_report_interval)
|
||||
|
||||
def _check_network_status(self):
|
||||
"""检查网络状态"""
|
||||
if not self.cloud_enabled:
|
||||
self.network_status = "disabled"
|
||||
return
|
||||
|
||||
try:
|
||||
self.client.get('/health')
|
||||
self.network_status = "connected"
|
||||
except:
|
||||
self.network_status = "disconnected"
|
||||
|
||||
def _execute_sync(self, task: SyncTask):
|
||||
"""执行同步任务"""
|
||||
logger.info(f"执行同步任务: {task.entity_type.value}/{task.entity_id} ({task.operation})")
|
||||
|
||||
task.status = SyncStatus.SYNCING
|
||||
|
||||
try:
|
||||
if task.entity_type == EntityType.CAMERA:
|
||||
self._sync_camera(task)
|
||||
elif task.entity_type == EntityType.ROI:
|
||||
self._sync_roi(task)
|
||||
|
||||
task.status = SyncStatus.SUCCESS
|
||||
logger.info(f"同步成功: {task.entity_type.value}/{task.entity_id}")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._handle_sync_error(task, str(e))
|
||||
except Exception as e:
|
||||
task.status = SyncStatus.FAILED
|
||||
task.error_message = str(e)
|
||||
logger.error(f"同步失败: {task.entity_type.value}/{task.entity_id}: {e}")
|
||||
|
||||
def _handle_sync_error(self, task: SyncTask, error: str):
|
||||
"""处理同步错误"""
|
||||
task.retry_count += 1
|
||||
|
||||
if task.retry_count < self.max_retries:
|
||||
task.status = SyncStatus.RETRY
|
||||
delay = self.retry_delays[task.retry_count - 1]
|
||||
task.error_message = f"第{task.retry_count}次失败: {error}"
|
||||
logger.warning(f"同步重试 ({task.retry_count}/{self.max_retries}): {task.entity_type.value}/{task.entity_id}")
|
||||
# 重新入队
|
||||
time.sleep(delay)
|
||||
self.sync_queue.put(task)
|
||||
else:
|
||||
task.status = SyncStatus.FAILED
|
||||
task.error_message = f"已超过最大重试次数: {error}"
|
||||
logger.error(f"同步失败,已达最大重试次数: {task.entity_type.value}/{task.entity_id}")
|
||||
|
||||
def _sync_camera(self, task: SyncTask):
|
||||
"""同步摄像头配置"""
|
||||
if task.operation == 'update':
|
||||
self.client.put(f"/api/v1/cameras/{task.entity_id}", json=task.data)
|
||||
elif task.operation == 'delete':
|
||||
self.client.delete(f"/api/v1/cameras/{task.entity_id}")
|
||||
|
||||
def _sync_roi(self, task: SyncTask):
|
||||
"""同步 ROI 配置"""
|
||||
if task.operation == 'update':
|
||||
self.client.put(f"/api/v1/rois/{task.entity_id}", json=task.data)
|
||||
elif task.operation == 'delete':
|
||||
self.client.delete(f"/api/v1/rois/{task.entity_id}")
|
||||
|
||||
def _upload_alarm(self, alarm_id: int):
|
||||
"""上传报警记录"""
|
||||
from db.crud import get_alarm_by_id, update_alarm_status
|
||||
from db.models import get_session_factory
|
||||
|
||||
SessionLocal = get_session_factory()
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
alarm = get_alarm_by_id(db, alarm_id)
|
||||
if not alarm:
|
||||
logger.warning(f"报警记录不存在: {alarm_id}")
|
||||
return
|
||||
|
||||
# 准备数据
|
||||
alarm_data = {
|
||||
'device_id': self.device_id,
|
||||
'camera_id': alarm.camera_id,
|
||||
'alarm_type': alarm.event_type,
|
||||
'confidence': alarm.confidence,
|
||||
'timestamp': alarm.created_at.isoformat() if alarm.created_at else None,
|
||||
'region': alarm.region_data
|
||||
}
|
||||
|
||||
# 上传图片
|
||||
if alarm.snapshot_path and os.path.exists(alarm.snapshot_path):
|
||||
with open(alarm.snapshot_path, 'rb') as f:
|
||||
files = {'file': f}
|
||||
response = self.client.post('/api/v1/alarms/images', files=files)
|
||||
alarm_data['image_url'] = response.json().get('data', {}).get('url')
|
||||
|
||||
# 上报报警
|
||||
response = self.client.post('/api/v1/alarms/report', json=alarm_data)
|
||||
cloud_id = response.json().get('data', {}).get('alarm_id')
|
||||
|
||||
# 更新本地状态
|
||||
update_alarm_status(db, alarm_id, status='uploaded', cloud_id=cloud_id)
|
||||
logger.info(f"报警上报成功: {alarm_id} -> 云端ID: {cloud_id}")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
update_alarm_status(db, alarm_id, status='retry', error_message=str(e))
|
||||
self.alarm_queue.put(alarm_id) # 重试
|
||||
except Exception as e:
|
||||
update_alarm_status(db, alarm_id, status='failed', error_message=str(e))
|
||||
logger.error(f"报警处理失败: {alarm_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _report_status(self):
|
||||
"""上报设备状态"""
|
||||
import psutil
|
||||
from db.crud import get_active_camera_count
|
||||
|
||||
try:
|
||||
metrics = {
|
||||
'device_id': self.device_id,
|
||||
'cpu_percent': psutil.cpu_percent(),
|
||||
'memory_percent': psutil.virtual_memory().percent,
|
||||
'disk_usage': psutil.disk_usage('/').percent,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
self.client.post('/api/v1/devices/status', json=metrics)
|
||||
logger.debug(f"设备状态上报成功: CPU={metrics['cpu_percent']}%")
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"设备状态上报失败: {e}")
|
||||
|
||||
# 公共接口
|
||||
|
||||
def queue_camera_sync(self, camera_id: int, operation: str = 'update', data: Dict[str, Any] = None):
|
||||
"""将摄像头同步加入队列"""
|
||||
task = SyncTask(
|
||||
entity_type=EntityType.CAMERA,
|
||||
entity_id=camera_id,
|
||||
operation=operation,
|
||||
data=data
|
||||
)
|
||||
self.sync_queue.put(task)
|
||||
|
||||
def queue_roi_sync(self, roi_id: int, operation: str = 'update', data: Dict[str, Any] = None):
|
||||
"""将 ROI 同步加入队列"""
|
||||
task = SyncTask(
|
||||
entity_type=EntityType.ROI,
|
||||
entity_id=roi_id,
|
||||
operation=operation,
|
||||
data=data
|
||||
)
|
||||
self.sync_queue.put(task)
|
||||
|
||||
def queue_alarm_upload(self, alarm_id: int):
|
||||
"""将报警上传加入队列"""
|
||||
self.alarm_queue.put(alarm_id)
|
||||
|
||||
def sync_config_from_cloud(self, db: Session) -> Dict[str, int]:
|
||||
"""从云端拉取配置"""
|
||||
result = {'cameras': 0, 'rois': 0}
|
||||
|
||||
if not self.cloud_enabled:
|
||||
logger.info("云端同步已禁用,跳过配置拉取")
|
||||
return result
|
||||
|
||||
try:
|
||||
logger.info("从云端拉取配置...")
|
||||
|
||||
# 拉取设备配置
|
||||
response = self.client.get(f"/api/v1/devices/{self.device_id}/config")
|
||||
config = response.json().get('data', {})
|
||||
|
||||
# 处理摄像头
|
||||
cameras = config.get('cameras', [])
|
||||
for cloud_cam in cameras:
|
||||
self._merge_camera(db, cloud_cam)
|
||||
result['cameras'] += 1
|
||||
|
||||
logger.info(f"从云端拉取配置完成: {result['cameras']} 个摄像头")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"从云端拉取配置失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理云端配置时出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def _merge_camera(self, db: Session, cloud_data: Dict[str, Any]):
|
||||
"""合并摄像头配置"""
|
||||
from db.crud import get_camera_by_cloud_id, create_camera, update_camera
|
||||
from db.models import Camera
|
||||
|
||||
cloud_id = cloud_data.get('id')
|
||||
existing = get_camera_by_cloud_id(db, cloud_id)
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
if not existing.pending_sync:
|
||||
update_camera(db, existing.id, {
|
||||
'name': cloud_data.get('name'),
|
||||
'rtsp_url': cloud_data.get('rtsp_url'),
|
||||
'enabled': cloud_data.get('enabled', True),
|
||||
'fps_limit': cloud_data.get('fps_limit', 30),
|
||||
'process_every_n_frames': cloud_data.get('process_every_n_frames', 3),
|
||||
})
|
||||
else:
|
||||
# 创建新记录
|
||||
camera = create_camera(db, {
|
||||
'name': cloud_data.get('name'),
|
||||
'rtsp_url': cloud_data.get('rtsp_url'),
|
||||
'enabled': cloud_data.get('enabled', True),
|
||||
'fps_limit': cloud_data.get('fps_limit', 30),
|
||||
'process_every_n_frames': cloud_data.get('process_every_n_frames', 3),
|
||||
})
|
||||
# 更新 cloud_id
|
||||
camera.cloud_id = cloud_id
|
||||
db.commit()
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""获取同步服务状态"""
|
||||
return {
|
||||
'running': self.running,
|
||||
'cloud_enabled': self.cloud_enabled,
|
||||
'network_status': self.network_status,
|
||||
'device_id': self.device_id,
|
||||
'pending_sync': self.sync_queue.qsize(),
|
||||
'pending_alarms': self.alarm_queue.qsize(),
|
||||
}
|
||||
|
||||
|
||||
# 单例
|
||||
_sync_service: Optional[SyncService] = None
|
||||
|
||||
|
||||
def get_sync_service() -> SyncService:
|
||||
"""获取同步服务单例"""
|
||||
global _sync_service
|
||||
if _sync_service is None:
|
||||
_sync_service = SyncService()
|
||||
return _sync_service
|
||||
|
||||
|
||||
def start_sync_service():
|
||||
"""启动同步服务"""
|
||||
service = get_sync_service()
|
||||
service.start()
|
||||
|
||||
|
||||
def stop_sync_service():
|
||||
"""停止同步服务"""
|
||||
global _sync_service
|
||||
if _sync_service:
|
||||
_sync_service.stop()
|
||||
_sync_service = None
|
||||
214
sort.py
214
sort.py
@@ -1,214 +0,0 @@
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from filterpy.kalman import KalmanFilter
|
||||
|
||||
def linear_assignment(cost_matrix):
|
||||
x, y = linear_sum_assignment(cost_matrix)
|
||||
return np.array(list(zip(x, y)))
|
||||
|
||||
def iou_batch(bb_test, bb_gt):
|
||||
"""
|
||||
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
|
||||
"""
|
||||
bb_test = np.expand_dims(bb_test, 1)
|
||||
bb_gt = np.expand_dims(bb_gt, 0)
|
||||
xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
|
||||
yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
|
||||
xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
|
||||
yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
|
||||
w = np.maximum(0., xx2 - xx1)
|
||||
h = np.maximum(0., yy2 - yy1)
|
||||
wh = w * h
|
||||
o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
|
||||
+ (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
|
||||
return o
|
||||
|
||||
def convert_bbox_to_z(bbox):
|
||||
"""
|
||||
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
|
||||
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
|
||||
the aspect ratio
|
||||
"""
|
||||
w = bbox[2] - bbox[0]
|
||||
h = bbox[3] - bbox[1]
|
||||
x = bbox[0] + w / 2.
|
||||
y = bbox[1] + h / 2.
|
||||
s = w * h # scale is just area
|
||||
r = w / float(h)
|
||||
return np.array([x, y, s, r]).reshape((4, 1))
|
||||
|
||||
def convert_x_to_bbox(x, score=None):
|
||||
"""
|
||||
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
|
||||
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
|
||||
"""
|
||||
w = np.sqrt(x[2] * x[3])
|
||||
h = x[2] / w
|
||||
if score is None:
|
||||
return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))
|
||||
else:
|
||||
return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5))
|
||||
|
||||
class KalmanBoxTracker(object):
|
||||
"""
|
||||
This class represents the internal state of individual tracked objects observed as bbox.
|
||||
"""
|
||||
count = 0
|
||||
|
||||
def __init__(self, bbox):
|
||||
"""
|
||||
Initialises a tracker using initial bounding box.
|
||||
"""
|
||||
self.kf = KalmanFilter(dim_x=7, dim_z=4)
|
||||
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
|
||||
[0, 1, 0, 0, 0, 1, 0],
|
||||
[0, 0, 1, 0, 0, 0, 1],
|
||||
[0, 0, 0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1]])
|
||||
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 0, 0, 0]])
|
||||
|
||||
self.kf.R[2:, 2:] *= 10.
|
||||
self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities
|
||||
self.kf.P *= 10.
|
||||
self.kf.Q[-1, -1] *= 0.01
|
||||
self.kf.Q[4:, 4:] *= 0.01
|
||||
|
||||
self.kf.x[:4] = convert_bbox_to_z(bbox)
|
||||
self.time_since_update = 0
|
||||
self.id = KalmanBoxTracker.count
|
||||
KalmanBoxTracker.count += 1
|
||||
self.history = []
|
||||
self.hits = 0
|
||||
self.hit_streak = 0
|
||||
self.age = 0
|
||||
|
||||
def update(self, bbox):
|
||||
"""
|
||||
Updates the state vector with observed bbox.
|
||||
"""
|
||||
self.time_since_update = 0
|
||||
self.history = []
|
||||
self.hits += 1
|
||||
self.hit_streak += 1
|
||||
self.kf.update(convert_bbox_to_z(bbox))
|
||||
|
||||
def predict(self):
|
||||
"""
|
||||
Advances the state vector and returns the predicted bounding box estimate.
|
||||
"""
|
||||
if (self.kf.x[6] + self.kf.x[2]) <= 0:
|
||||
self.kf.x[6] *= 0.0
|
||||
self.kf.predict()
|
||||
self.age += 1
|
||||
if self.time_since_update > 0:
|
||||
self.hit_streak = 0
|
||||
self.time_since_update += 1
|
||||
self.history.append(convert_x_to_bbox(self.kf.x))
|
||||
return self.history[-1]
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Returns the current bounding box estimate.
|
||||
"""
|
||||
return convert_x_to_bbox(self.kf.x)
|
||||
|
||||
class Sort(object):
|
||||
def __init__(self, max_age=5, min_hits=2, iou_threshold=0.3):
|
||||
"""
|
||||
Sets key parameters for SORT
|
||||
"""
|
||||
self.max_age = max_age
|
||||
self.min_hits = min_hits
|
||||
self.iou_threshold = iou_threshold
|
||||
self.trackers = []
|
||||
self.frame_count = 0
|
||||
|
||||
def update(self, dets=np.empty((0, 5))):
|
||||
"""
|
||||
Params:
|
||||
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],...]
|
||||
Requires: this method must be called once for each frame even with empty detections.
|
||||
Returns the a similar array, where the last column is the object ID.
|
||||
|
||||
NOTE: The number of objects returned may differ from the number of detections provided.
|
||||
"""
|
||||
self.frame_count += 1
|
||||
trks = np.zeros((len(self.trackers), 5))
|
||||
to_del = []
|
||||
ret = []
|
||||
for t, trk in enumerate(trks):
|
||||
pos = self.trackers[t].predict()[0]
|
||||
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
|
||||
if np.any(np.isnan(pos)):
|
||||
to_del.append(t)
|
||||
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
|
||||
for t in reversed(to_del):
|
||||
self.trackers.pop(t)
|
||||
matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)
|
||||
|
||||
# update matched trackers with assigned detections
|
||||
for m in matched:
|
||||
self.trackers[m[1]].update(dets[m[0], :])
|
||||
|
||||
# create and initialise new trackers for unmatched detections
|
||||
for i in unmatched_dets:
|
||||
trk = KalmanBoxTracker(dets[i, :])
|
||||
self.trackers.append(trk)
|
||||
i = len(self.trackers)
|
||||
for trk in reversed(self.trackers):
|
||||
d = trk.get_state()[0]
|
||||
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
|
||||
ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) # +1 as MOT benchmark requires positive
|
||||
i -= 1
|
||||
# remove dead tracklet
|
||||
if trk.time_since_update > self.max_age:
|
||||
self.trackers.pop(i)
|
||||
if len(ret) > 0:
|
||||
return np.concatenate(ret)
|
||||
return np.empty((0, 5))
|
||||
|
||||
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
|
||||
"""
|
||||
Assigns detections to tracked object (both represented as bounding boxes)
|
||||
Returns 3 lists of matches, unmatched_detections, unmatched_trackers
|
||||
"""
|
||||
if len(trackers) == 0:
|
||||
return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 1), dtype=int)
|
||||
iou_matrix = iou_batch(detections, trackers)
|
||||
|
||||
if min(iou_matrix.shape) > 0:
|
||||
a = (iou_matrix > iou_threshold).astype(np.int32)
|
||||
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
|
||||
matched_indices = np.stack(np.where(a), axis=1)
|
||||
else:
|
||||
matched_indices = linear_assignment(-iou_matrix)
|
||||
else:
|
||||
matched_indices = np.empty(shape=(0, 2))
|
||||
|
||||
unmatched_detections = []
|
||||
for d, det in enumerate(detections):
|
||||
if d not in matched_indices[:, 0]:
|
||||
unmatched_detections.append(d)
|
||||
unmatched_trackers = []
|
||||
for t, trk in enumerate(trackers):
|
||||
if t not in matched_indices[:, 1]:
|
||||
unmatched_trackers.append(t)
|
||||
|
||||
matches = []
|
||||
for m in matched_indices:
|
||||
if iou_matrix[m[0], m[1]] < iou_threshold:
|
||||
unmatched_detections.append(m[0])
|
||||
unmatched_trackers.append(m[1])
|
||||
else:
|
||||
matches.append(m.reshape(1, 2))
|
||||
if len(matches) == 0:
|
||||
matches = np.empty((0, 2), dtype=int)
|
||||
else:
|
||||
matches = np.concatenate(matches, axis=0)
|
||||
|
||||
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
|
||||
@@ -58,7 +58,7 @@ def test_leave_post_algorithm_process():
|
||||
{"bbox": [100, 100, 200, 200], "conf": 0.9, "cls": 0},
|
||||
]
|
||||
|
||||
alerts = algo.process("test_cam", tracks, datetime.now())
|
||||
alerts = algo.process("roi_1", "test_cam", tracks, datetime.now())
|
||||
assert isinstance(alerts, list)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user