ROI选区01
This commit is contained in:
18
.env.example
Normal file
18
.env.example
Normal file
@@ -0,0 +1,18 @@
|
||||
# Redis 配置(可选,用于分布式部署)
|
||||
|
||||
# 连接配置
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_DB=0
|
||||
|
||||
# 事件队列配置
|
||||
REDIS_EVENT_QUEUE=security_events
|
||||
REDIS_ALARM_QUEUE=security_alarms
|
||||
|
||||
# 大模型 API 配置
|
||||
LLM_API_KEY=your-api-key
|
||||
LLM_BASE_URL=https://api.openai.com/v1
|
||||
LLM_MODEL=gpt-4-vision-preview
|
||||
|
||||
# 其他配置
|
||||
LOG_LEVEL=INFO
|
||||
39
Dockerfile
Normal file
39
Dockerfile
Normal file
@@ -0,0 +1,39 @@
|
||||
FROM nvidia/cuda:12.2.0-devel-ubuntu22.04
|
||||
|
||||
LABEL maintainer="security-monitor"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
python3-venv \
|
||||
git \
|
||||
wget \
|
||||
libgl1-mesa-glx \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxext6 \
|
||||
libxrender1 \
|
||||
ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN python3.10 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir --upgrade pip wheel && \
|
||||
pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN mkdir -p /app/models /app/data /app/logs
|
||||
|
||||
EXPOSE 8000 9090
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
CMD ["python", "main.py"]
|
||||
136
README.md
Normal file
136
README.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# 安保异常行为识别系统
|
||||
|
||||
## 项目概述
|
||||
|
||||
基于 YOLO + TensorRT 的安保异常行为识别系统,支持离岗检测和周界入侵检测。
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. 生成 TensorRT 引擎
|
||||
|
||||
```bash
|
||||
python scripts/build_engine.py --model models/yolo11n.pt --fp16
|
||||
```
|
||||
|
||||
### 3. 配置数据库
|
||||
|
||||
编辑 `config.yaml` 配置数据库连接。
|
||||
|
||||
### 4. 启动服务
|
||||
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
|
||||
### 5. 访问前端
|
||||
|
||||
打开浏览器访问 `http://localhost:3000`
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
project_root/
|
||||
├── main.py # FastAPI 入口
|
||||
├── config.yaml # 配置文件
|
||||
├── requirements.txt # Python 依赖
|
||||
├── inference/
|
||||
│ ├── engine.py # TensorRT 引擎封装
|
||||
│ ├── stream.py # RTSP 流处理
|
||||
│ ├── pipeline.py # 推理主流程
|
||||
│ ├── roi/
|
||||
│ │ └── roi_filter.py # ROI 过滤
|
||||
│ └── rules/
|
||||
│ └── algorithms.py # 规则算法
|
||||
├── db/
|
||||
│ ├── models.py # SQLAlchemy ORM 模型
|
||||
│ └── crud.py # 数据库操作
|
||||
├── api/
|
||||
│ ├── camera.py # 摄像头管理接口
|
||||
│ ├── roi.py # ROI 管理接口
|
||||
│ └── alarm.py # 告警管理接口
|
||||
├── utils/
|
||||
│ ├── logger.py # 日志工具
|
||||
│ ├── helpers.py # 辅助函数
|
||||
│ └── metrics.py # Prometheus 监控
|
||||
├── frontend/ # React 前端
|
||||
├── scripts/
|
||||
│ └── build_engine.py # TensorRT 引擎构建脚本
|
||||
├── tests/ # 单元测试
|
||||
└── Dockerfile # Docker 配置
|
||||
```
|
||||
|
||||
## API 接口
|
||||
|
||||
### 摄像头管理
|
||||
|
||||
- `GET /api/cameras` - 获取摄像头列表
|
||||
- `POST /api/cameras` - 添加摄像头
|
||||
- `PUT /api/cameras/{id}` - 更新摄像头
|
||||
- `DELETE /api/cameras/{id}` - 删除摄像头
|
||||
|
||||
### ROI 管理
|
||||
|
||||
- `GET /api/camera/{id}/rois` - 获取摄像头 ROI 列表
|
||||
- `POST /api/camera/{id}/roi` - 添加 ROI
|
||||
- `PUT /api/camera/{id}/roi/{roi_id}` - 更新 ROI
|
||||
- `DELETE /api/camera/{id}/roi/{roi_id}` - 删除 ROI
|
||||
|
||||
### 告警管理
|
||||
|
||||
- `GET /api/alarms` - 获取告警列表
|
||||
- `GET /api/alarms/stats` - 获取告警统计
|
||||
- `PUT /api/alarms/{id}` - 更新告警状态
|
||||
- `POST /api/alarms/{id}/llm-check` - 触发大模型检查
|
||||
|
||||
### 其他接口
|
||||
|
||||
- `GET /api/camera/{id}/snapshot` - 获取实时截图
|
||||
- `GET /api/camera/{id}/detect` - 获取带检测框的截图
|
||||
- `GET /api/pipeline/status` - 获取 Pipeline 状态
|
||||
- `GET /health` - 健康检查
|
||||
|
||||
## 配置说明
|
||||
|
||||
### config.yaml
|
||||
|
||||
```yaml
|
||||
database:
|
||||
dialect: "sqlite" # sqlite 或 mysql
|
||||
name: "security_monitor"
|
||||
|
||||
model:
|
||||
engine_path: "models/yolo11n_fp16_480.engine"
|
||||
imgsz: [480, 480]
|
||||
batch_size: 8
|
||||
half: true
|
||||
|
||||
stream:
|
||||
buffer_size: 2
|
||||
reconnect_delay: 3.0
|
||||
|
||||
alert:
|
||||
snapshot_path: "data/alerts"
|
||||
cooldown_sec: 300
|
||||
```
|
||||
|
||||
## Docker 部署
|
||||
|
||||
```bash
|
||||
docker-compose up -d
|
||||
```
|
||||
|
||||
## 监控指标
|
||||
|
||||
系统暴露 Prometheus 格式的监控指标:
|
||||
|
||||
- `camera_count` - 活跃摄像头数量
|
||||
- `camera_fps{camera_id="*"}` - 各摄像头 FPS
|
||||
- `inference_latency_seconds{camera_id="*"}` - 推理延迟
|
||||
- `alert_total{camera_id="*", event_type="*"}` - 告警总数
|
||||
- `event_queue_size` - 事件队列大小
|
||||
38
TRT_BUILD.md
Normal file
38
TRT_BUILD.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# TensorRT Engine 生成
|
||||
# 使用 trtexec 命令行工具
|
||||
|
||||
# 1. 导出 ONNX 模型
|
||||
trtexec \
|
||||
--onnx=yolo11n_480.onnx \
|
||||
--saveEngine=yolo11n_fp16_480.engine \
|
||||
--fp16 \
|
||||
--minShapes=input:1x3x480x480 \
|
||||
--optShapes=input:4x3x480x480 \
|
||||
--maxShapes=input:8x3x480x480 \
|
||||
--workspace=4096 \
|
||||
--verbose
|
||||
|
||||
# 或者使用优化器设置
|
||||
trtexec \
|
||||
--onnx=yolo11n_480.onnx \
|
||||
--saveEngine=yolo11n_fp16_480.engine \
|
||||
--fp16 \
|
||||
--optShapes=input:4x3x480x480 \
|
||||
--workspace=4096 \
|
||||
--builderOptimizationLevel=5 \
|
||||
--refit=False \
|
||||
--sparsity=False
|
||||
|
||||
# INT8 量化(需要校准数据)
|
||||
trtexec \
|
||||
--onnx=yolo11n_480.onnx \
|
||||
--saveEngine=yolo11n_int8_480.engine \
|
||||
--int8 \
|
||||
--calib=calibration.bin \
|
||||
--minShapes=input:1x3x480x480 \
|
||||
--optShapes=input:4x3x480x480 \
|
||||
--maxShapes=input:8x3x480x480 \
|
||||
--workspace=4096
|
||||
|
||||
# 验证引擎
|
||||
trtexec --loadEngine=yolo11n_fp16_480.engine --dumpOutput
|
||||
142
api/alarm.py
Normal file
142
api/alarm.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.crud import (
|
||||
create_alarm,
|
||||
get_alarm_stats,
|
||||
get_alarms,
|
||||
update_alarm,
|
||||
)
|
||||
from db.models import get_db
|
||||
from inference.pipeline import get_pipeline
|
||||
|
||||
router = APIRouter(prefix="/api/alarms", tags=["告警管理"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[dict])
def list_alarms(
    camera_id: Optional[int] = None,
    event_type: Optional[str] = None,
    limit: int = Query(default=100, le=1000),
    offset: int = Query(default=0, ge=0),
    db: Session = Depends(get_db),
):
    """List alarms as plain dicts, optionally filtered by camera and event type."""

    def _serialize(alarm) -> dict:
        # Flatten the ORM row into the JSON shape the frontend consumes.
        return {
            "id": alarm.id,
            "camera_id": alarm.camera_id,
            "roi_id": alarm.roi_id,
            "event_type": alarm.event_type,
            "confidence": alarm.confidence,
            "snapshot_path": alarm.snapshot_path,
            "llm_checked": alarm.llm_checked,
            "llm_result": alarm.llm_result,
            "processed": alarm.processed,
            "created_at": alarm.created_at.isoformat() if alarm.created_at else None,
        }

    rows = get_alarms(
        db, camera_id=camera_id, event_type=event_type, limit=limit, offset=offset
    )
    return [_serialize(row) for row in rows]
|
||||
|
||||
|
||||
@router.get("/stats")
def get_alarm_statistics(db: Session = Depends(get_db)):
    """Return aggregate alarm statistics computed by the CRUD layer."""
    return get_alarm_stats(db)
|
||||
|
||||
|
||||
@router.get("/{alarm_id}", response_model=dict)
def get_alarm(alarm_id: int, db: Session = Depends(get_db)):
    """Return a single alarm by primary key.

    Raises:
        HTTPException 404: when no alarm with this id exists.
    """
    # BUG FIX: the previous implementation called get_alarms(db, limit=1) and
    # searched that one-element list, so every id except the most recent
    # alarm's returned 404. Query the row by primary key instead.
    from db.models import Alarm

    alarm = db.query(Alarm).filter(Alarm.id == alarm_id).first()
    if not alarm:
        raise HTTPException(status_code=404, detail="告警不存在")
    return {
        "id": alarm.id,
        "camera_id": alarm.camera_id,
        "roi_id": alarm.roi_id,
        "event_type": alarm.event_type,
        "confidence": alarm.confidence,
        "snapshot_path": alarm.snapshot_path,
        "llm_checked": alarm.llm_checked,
        "llm_result": alarm.llm_result,
        "processed": alarm.processed,
        "created_at": alarm.created_at.isoformat() if alarm.created_at else None,
    }
|
||||
|
||||
|
||||
@router.put("/{alarm_id}")
def update_alarm_status(
    alarm_id: int,
    llm_checked: Optional[bool] = None,
    llm_result: Optional[str] = None,
    processed: Optional[bool] = None,
    db: Session = Depends(get_db),
):
    """Partially update an alarm's review fields; only non-None values are applied."""
    updated = update_alarm(
        db,
        alarm_id,
        llm_checked=llm_checked,
        llm_result=llm_result,
        processed=processed,
    )
    if updated is None:
        raise HTTPException(status_code=404, detail="告警不存在")
    return {"message": "更新成功"}
|
||||
|
||||
|
||||
@router.post("/{alarm_id}/llm-check")
async def trigger_llm_check(alarm_id: int, db: Session = Depends(get_db)):
    """Run a vision-LLM review of an alarm's snapshot and persist the verdict.

    Raises:
        HTTPException 404: alarm not found.
        HTTPException 400: snapshot missing or LLM feature disabled.
        HTTPException 500: the LLM call itself failed.
    """
    # BUG FIX: `os` was used below without being imported (NameError at runtime).
    import base64
    import os

    # BUG FIX: previously only the single most recent alarm was fetched
    # (get_alarms(db, limit=1)), so any other id always produced a 404.
    from db.models import Alarm

    alarm = db.query(Alarm).filter(Alarm.id == alarm_id).first()
    if not alarm:
        raise HTTPException(status_code=404, detail="告警不存在")

    if not alarm.snapshot_path or not os.path.exists(alarm.snapshot_path):
        raise HTTPException(status_code=400, detail="截图不存在")

    try:
        from config import get_config

        config = get_config()
        if not config.llm.enabled:
            raise HTTPException(status_code=400, detail="大模型功能未启用")

        # Inline the snapshot as a base64 data URL for the multimodal request.
        with open(alarm.snapshot_path, "rb") as f:
            img_base64 = base64.b64encode(f.read()).decode("utf-8")

        from openai import OpenAI

        client = OpenAI(
            api_key=config.llm.api_key,
            base_url=config.llm.base_url,
        )

        prompt = """分析这张监控截图,判断是否存在异常行为。请简要说明:
1. 画面中是否有人
2. 人员位置和行为
3. 是否存在异常"""

        response = client.chat.completions.create(
            model=config.llm.model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}},
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
        )

        result = response.choices[0].message.content
        update_alarm(db, alarm_id, llm_checked=True, llm_result=result)

        return {"message": "大模型分析完成", "result": result}
    except HTTPException:
        # BUG FIX: previously the "feature disabled" 400 above was caught by the
        # generic handler and re-surfaced as a 500; let our own errors through.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"大模型调用失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/queue/size")
def get_event_queue_size():
    """Report the current and maximum size of the in-process event queue."""
    pipeline = get_pipeline()
    return {
        "size": len(pipeline.event_queue),
        "max_size": pipeline.config.inference.event_queue_maxlen,
    }
|
||||
160
api/camera.py
Normal file
160
api/camera.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.crud import (
|
||||
create_camera,
|
||||
delete_camera,
|
||||
get_all_cameras,
|
||||
get_camera_by_id,
|
||||
update_camera,
|
||||
)
|
||||
from db.models import get_db
|
||||
from inference.pipeline import get_pipeline
|
||||
|
||||
router = APIRouter(prefix="/api/cameras", tags=["摄像头管理"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[dict])
def list_cameras(
    enabled_only: bool = True,
    db: Session = Depends(get_db),
):
    """List cameras as plain dicts; by default only enabled ones are returned."""

    def _serialize(cam) -> dict:
        return {
            "id": cam.id,
            "name": cam.name,
            "rtsp_url": cam.rtsp_url,
            "enabled": cam.enabled,
            "fps_limit": cam.fps_limit,
            "process_every_n_frames": cam.process_every_n_frames,
            "created_at": cam.created_at.isoformat() if cam.created_at else None,
        }

    return [_serialize(cam) for cam in get_all_cameras(db, enabled_only=enabled_only)]
|
||||
|
||||
|
||||
@router.get("/{camera_id}", response_model=dict)
def get_camera(camera_id: int, db: Session = Depends(get_db)):
    """Fetch a single camera by id; 404 when unknown."""
    camera = get_camera_by_id(db, camera_id)
    if camera is None:
        raise HTTPException(status_code=404, detail="摄像头不存在")
    created = camera.created_at.isoformat() if camera.created_at else None
    return {
        "id": camera.id,
        "name": camera.name,
        "rtsp_url": camera.rtsp_url,
        "enabled": camera.enabled,
        "fps_limit": camera.fps_limit,
        "process_every_n_frames": camera.process_every_n_frames,
        "created_at": created,
    }
|
||||
|
||||
|
||||
@router.post("", response_model=dict)
def add_camera(
    name: str,
    rtsp_url: str,
    fps_limit: int = 30,
    process_every_n_frames: int = 3,
    db: Session = Depends(get_db),
):
    """Create a camera row and, when it is enabled, register it with the pipeline."""
    camera = create_camera(
        db,
        name=name,
        rtsp_url=rtsp_url,
        fps_limit=fps_limit,
        process_every_n_frames=process_every_n_frames,
    )

    # Start processing immediately if the new row came back enabled.
    if camera.enabled:
        get_pipeline().add_camera(camera)

    return {
        "id": camera.id,
        "name": camera.name,
        "rtsp_url": camera.rtsp_url,
        "enabled": camera.enabled,
    }
|
||||
|
||||
|
||||
@router.put("/{camera_id}", response_model=dict)
def modify_camera(
    camera_id: int,
    name: Optional[str] = None,
    rtsp_url: Optional[str] = None,
    fps_limit: Optional[int] = None,
    process_every_n_frames: Optional[int] = None,
    enabled: Optional[bool] = None,
    db: Session = Depends(get_db),
):
    """Partially update a camera; syncs the pipeline when `enabled` changes."""
    camera = update_camera(
        db,
        camera_id=camera_id,
        name=name,
        rtsp_url=rtsp_url,
        fps_limit=fps_limit,
        process_every_n_frames=process_every_n_frames,
        enabled=enabled,
    )
    if camera is None:
        raise HTTPException(status_code=404, detail="摄像头不存在")

    # Only an explicit True/False toggles the live pipeline; None means "unchanged".
    pipeline = get_pipeline()
    if enabled is True:
        pipeline.add_camera(camera)
    elif enabled is False:
        pipeline.remove_camera(camera_id)

    return {
        "id": camera.id,
        "name": camera.name,
        "rtsp_url": camera.rtsp_url,
        "enabled": camera.enabled,
    }
|
||||
|
||||
|
||||
@router.delete("/{camera_id}")
def remove_camera(camera_id: int, db: Session = Depends(get_db)):
    """Detach a camera from the pipeline, then delete its row; 404 when unknown."""
    # Stop the stream first so no worker touches the row mid-delete.
    get_pipeline().remove_camera(camera_id)

    if not delete_camera(db, camera_id):
        raise HTTPException(status_code=404, detail="摄像头不存在")
    return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.get("/{camera_id}/status")
def get_camera_status(camera_id: int, db: Session = Depends(get_db)):
    """Return the DB-recorded runtime status plus live stream info for a camera."""
    from db.crud import get_camera_status as get_status

    db_status = get_status(db, camera_id)
    stream = get_pipeline().stream_manager.get_stream(str(camera_id))
    stream_info = stream.get_info() if stream else None

    if db_status is None:
        # No status row recorded yet: report a stopped camera with neutral values.
        return {
            "camera_id": camera_id,
            "is_running": False,
            "fps": 0.0,
            "error_message": None,
            "last_check_time": None,
            "stream": stream_info,
        }

    last_check = (
        db_status.last_check_time.isoformat() if db_status.last_check_time else None
    )
    return {
        "camera_id": camera_id,
        "is_running": db_status.is_running,
        "fps": db_status.fps,
        "error_message": db_status.error_message,
        "last_check_time": last_check,
        "stream": stream_info,
    }
|
||||
192
api/roi.py
Normal file
192
api/roi.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.crud import (
|
||||
create_roi,
|
||||
delete_roi,
|
||||
get_all_rois,
|
||||
get_roi_by_id,
|
||||
get_roi_points,
|
||||
update_roi,
|
||||
)
|
||||
from db.models import get_db
|
||||
from inference.pipeline import get_pipeline
|
||||
from inference.roi.roi_filter import ROIFilter
|
||||
|
||||
router = APIRouter(prefix="/api/camera", tags=["ROI管理"])
|
||||
|
||||
|
||||
@router.get("/{camera_id}/rois", response_model=List[dict])
def list_rois(camera_id: int, db: Session = Depends(get_db)):
    """List every ROI configured for the given camera."""

    def _serialize(roi) -> dict:
        return {
            "id": roi.id,
            "roi_id": roi.roi_id,
            "name": roi.name,
            "type": roi.roi_type,
            "points": json.loads(roi.points),  # stored in the DB as a JSON string
            "rule": roi.rule_type,
            "direction": roi.direction,
            "stay_time": roi.stay_time,
            "enabled": roi.enabled,
            "threshold_sec": roi.threshold_sec,
            "confirm_sec": roi.confirm_sec,
            "return_sec": roi.return_sec,
        }

    return [_serialize(roi) for roi in get_all_rois(db, camera_id)]
|
||||
|
||||
|
||||
@router.get("/{camera_id}/roi/{roi_id}", response_model=dict)
def get_roi(camera_id: int, roi_id: int, db: Session = Depends(get_db)):
    """Fetch one ROI by id; 404 when unknown.

    NOTE(review): `camera_id` is accepted but not used for filtering — the
    lookup is by `roi_id` alone; confirm whether cross-camera access is intended.
    """
    roi = get_roi_by_id(db, roi_id)
    if roi is None:
        raise HTTPException(status_code=404, detail="ROI不存在")
    return {
        "id": roi.id,
        "roi_id": roi.roi_id,
        "name": roi.name,
        "type": roi.roi_type,
        "points": json.loads(roi.points),
        "rule": roi.rule_type,
        "direction": roi.direction,
        "stay_time": roi.stay_time,
        "enabled": roi.enabled,
        "threshold_sec": roi.threshold_sec,
        "confirm_sec": roi.confirm_sec,
        "return_sec": roi.return_sec,
    }
|
||||
|
||||
|
||||
@router.post("/{camera_id}/roi", response_model=dict)
def add_roi(
    camera_id: int,
    roi_id: str,
    name: str,
    roi_type: str,
    points: List[List[float]],
    rule_type: str,
    direction: Optional[str] = None,
    stay_time: Optional[int] = None,
    threshold_sec: int = 360,
    confirm_sec: int = 30,
    return_sec: int = 5,
    db: Session = Depends(get_db),
):
    """Create an ROI for a camera and refresh the pipeline's ROI cache."""
    roi = create_roi(
        db,
        camera_id=camera_id,
        roi_id=roi_id,
        name=name,
        roi_type=roi_type,
        points=points,
        rule_type=rule_type,
        direction=direction,
        stay_time=stay_time,
        threshold_sec=threshold_sec,
        confirm_sec=confirm_sec,
        return_sec=return_sec,
    )

    # Keep the in-memory ROI filter in sync with the DB after the insert.
    get_pipeline().roi_filter.update_cache(camera_id, get_roi_points(db, camera_id))

    return {
        "id": roi.id,
        "roi_id": roi.roi_id,
        "name": roi.name,
        "type": roi.roi_type,
        "points": points,
        "rule": roi.rule_type,
        "enabled": roi.enabled,
    }
|
||||
|
||||
|
||||
@router.put("/{camera_id}/roi/{roi_id}", response_model=dict)
def modify_roi(
    camera_id: int,
    roi_id: int,
    name: Optional[str] = None,
    points: Optional[List[List[float]]] = None,
    rule_type: Optional[str] = None,
    direction: Optional[str] = None,
    stay_time: Optional[int] = None,
    enabled: Optional[bool] = None,
    threshold_sec: Optional[int] = None,
    confirm_sec: Optional[int] = None,
    return_sec: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Partially update an ROI and refresh the pipeline's ROI cache."""
    roi = update_roi(
        db,
        roi_id=roi_id,
        name=name,
        points=points,
        rule_type=rule_type,
        direction=direction,
        stay_time=stay_time,
        enabled=enabled,
        threshold_sec=threshold_sec,
        confirm_sec=confirm_sec,
        return_sec=return_sec,
    )
    if roi is None:
        raise HTTPException(status_code=404, detail="ROI不存在")

    # Keep the in-memory ROI filter in sync with the DB after the update.
    get_pipeline().roi_filter.update_cache(camera_id, get_roi_points(db, camera_id))

    return {
        "id": roi.id,
        "roi_id": roi.roi_id,
        "name": roi.name,
        "type": roi.roi_type,
        "points": json.loads(roi.points),
        "rule": roi.rule_type,
        "enabled": roi.enabled,
    }
|
||||
|
||||
|
||||
@router.delete("/{camera_id}/roi/{roi_id}")
def remove_roi(camera_id: int, roi_id: int, db: Session = Depends(get_db)):
    """Delete an ROI and refresh the pipeline's ROI cache; 404 when unknown."""
    if not delete_roi(db, roi_id):
        raise HTTPException(status_code=404, detail="ROI不存在")

    # Rebuild the cached ROI set for this camera after the delete.
    get_pipeline().roi_filter.update_cache(camera_id, get_roi_points(db, camera_id))

    return {"message": "删除成功"}
|
||||
|
||||
|
||||
@router.get("/{camera_id}/roi/validate")
def validate_roi(
    camera_id: int,
    roi_type: str,
    points: List[List[float]],
    db: Session = Depends(get_db),
):
    """Validate ROI geometry (polygon self-intersection, line well-formedness).

    NOTE(review): this route is registered AFTER `GET /{camera_id}/roi/{roi_id}`,
    so FastAPI matches "validate" against `roi_id: int` first and returns 422 —
    register this route before the parameterized one (or move it to a
    non-conflicting path). Also, `points: List[List[float]]` as a GET parameter
    needs a request body or serialized form — TODO confirm how the frontend
    sends it.
    """
    try:
        if roi_type == "polygon":
            from shapely.geometry import Polygon

            poly = Polygon(points)
            is_valid = poly.is_valid
            area = poly.area
        elif roi_type == "line":
            from shapely.geometry import LineString

            line = LineString(points)
            # BUG FIX: was hard-coded True; ask shapely whether the line is valid.
            is_valid = line.is_valid
            area = 0.0
        else:
            raise HTTPException(status_code=400, detail="不支持的ROI类型")

        return {
            "valid": is_valid,
            "area": area,
            "message": "有效" if is_valid else "无效:自交或自重叠",
        }
    except HTTPException:
        # BUG FIX: don't let the generic handler re-wrap our own 400 response.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
|
||||
149
config.py
Normal file
149
config.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DatabaseConfig(BaseModel):
    """Database connection settings; `url` builds the SQLAlchemy DSN."""

    dialect: str = "sqlite"  # "sqlite" or "mysql"
    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    password: str = "password"
    name: str = "security_monitor"  # database name (also the SQLite file stem)
    echo: bool = False  # when True, SQLAlchemy echoes emitted SQL

    @property
    def url(self) -> str:
        """Return a SQLAlchemy connection URL for the configured dialect."""
        if self.dialect == "sqlite":
            return f"sqlite:///{self.name}.db"
        return f"mysql+pymysql://{self.username}:{self.password}@{self.host}:{self.port}/{self.name}"
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
    """Detector model settings: TensorRT engine path, input size, thresholds."""

    engine_path: str = "models/yolo11n_fp16_480.engine"  # serialized TRT engine
    pt_model_path: str = "models/yolo11n.pt"  # source PyTorch weights
    imgsz: List[int] = [480, 480]  # inference input size (pydantic copies mutable defaults per instance)
    conf_threshold: float = 0.5  # detection confidence cutoff
    iou_threshold: float = 0.45  # NMS IoU threshold
    device: int = 0  # GPU device index
    batch_size: int = 8  # maximum inference batch size
    half: bool = True  # FP16 inference
|
||||
|
||||
|
||||
class StreamConfig(BaseModel):
    """RTSP stream handling settings."""

    buffer_size: int = 2  # per-camera frame buffer size
    reconnect_delay: float = 3.0  # seconds to wait before reconnecting
    timeout: float = 10.0  # connection timeout in seconds
    fps_limit: int = 30  # maximum processing FPS
|
||||
|
||||
|
||||
class InferenceConfig(BaseModel):
    """Inference queue and worker sizing."""

    queue_maxlen: int = 100  # max detection-result queue length
    event_queue_maxlen: int = 1000  # max anomaly-event queue length
    worker_threads: int = 4  # number of processing threads
|
||||
|
||||
|
||||
class AlertConfig(BaseModel):
    """Alert snapshot storage and throttling settings."""

    snapshot_path: str = "data/alerts"  # directory for alarm snapshots
    cooldown_sec: int = 300  # cooldown between same-type alerts
    image_quality: int = 85  # JPEG quality for saved snapshots
|
||||
|
||||
|
||||
class ROIConfig(BaseModel):
    """ROI geometry constraints."""

    default_types: List[str] = ["polygon", "line"]  # supported ROI shapes
    max_points: int = 50  # maximum polygon vertex count
|
||||
|
||||
|
||||
class WorkingHours(BaseModel):
    """A single working-hours interval as [hour, minute] pairs."""

    start: List[int] = Field(default_factory=lambda: [8, 30])  # e.g. 08:30
    end: List[int] = Field(default_factory=lambda: [17, 30])  # e.g. 17:30
|
||||
|
||||
|
||||
class AlgorithmsConfig(BaseModel):
    """Default rule-algorithm parameters (leave-post and intrusion detection)."""

    leave_post_threshold_sec: int = 360  # off-duty duration before alarming (6 min)
    leave_post_confirm_sec: int = 30  # time to confirm the post was left
    leave_post_return_sec: int = 5  # time to confirm return to post
    intrusion_check_interval_sec: float = 1.0  # perimeter-check cadence
    intrusion_direction_sensitive: bool = False  # whether crossing direction matters
|
||||
|
||||
|
||||
class LoggingConfig(BaseModel):
    """Application logging settings (level, format, rotating file)."""

    level: str = "INFO"
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    file: str = "logs/app.log"
    max_size: str = "100MB"  # rotation size — presumably parsed by the logger setup; verify
    backup_count: int = 5  # number of rotated files to keep
|
||||
|
||||
|
||||
class MonitoringConfig(BaseModel):
    """Prometheus metrics endpoint settings."""

    enabled: bool = True
    port: int = 9090  # metrics HTTP port
    path: str = "/metrics"
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
    """Vision-LLM review settings (OpenAI-compatible endpoint)."""

    enabled: bool = False  # feature flag: LLM re-check disabled by default
    api_key: str = ""  # keep empty here; supply via config/environment, never commit
    base_url: str = ""  # OpenAI-compatible API base URL
    model: str = "qwen-vl-max"  # NOTE(review): config.yaml uses key "model_name" — mismatch, verify
    timeout: int = 30  # request timeout in seconds
|
||||
|
||||
|
||||
class Config(BaseModel):
    """Top-level application configuration aggregating all sub-sections."""

    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    model: ModelConfig = Field(default_factory=ModelConfig)
    stream: StreamConfig = Field(default_factory=StreamConfig)
    inference: InferenceConfig = Field(default_factory=InferenceConfig)
    alert: AlertConfig = Field(default_factory=AlertConfig)
    roi: ROIConfig = Field(default_factory=ROIConfig)
    # Default schedule: 08:30–11:00 and 12:00–17:30.
    working_hours: List[WorkingHours] = Field(default_factory=lambda: [
        WorkingHours(start=[8, 30], end=[11, 0]),
        WorkingHours(start=[12, 0], end=[17, 30])
    ])
    algorithms: AlgorithmsConfig = Field(default_factory=AlgorithmsConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)
|
||||
|
||||
|
||||
_config: Optional[Config] = None
|
||||
|
||||
|
||||
def load_config(config_path: Optional[str] = None) -> Config:
    """Load the YAML configuration, caching the result as a module singleton.

    Resolution order for the path: explicit argument, then the CONFIG_PATH
    environment variable, then config.yaml next to this module. A missing
    file yields a Config built entirely from defaults.
    """
    global _config

    if _config is not None:
        return _config

    if config_path is None:
        default_path = str(Path(__file__).parent / "config.yaml")
        config_path = os.environ.get("CONFIG_PATH", default_path)

    yaml_config: Dict[str, Any] = {}
    if os.path.exists(config_path):
        with open(config_path, "r", encoding="utf-8") as f:
            # An empty YAML file parses to None; treat it as "no overrides".
            yaml_config = yaml.safe_load(f) or {}

    _config = Config(**yaml_config)
    return _config
|
||||
|
||||
|
||||
def get_config() -> Config:
    """Return the configuration singleton, loading it lazily on first use."""
    return _config if _config is not None else load_config()
|
||||
|
||||
|
||||
def reload_config(config_path: Optional[str] = None) -> Config:
    """Discard the cached configuration and load it again from disk."""
    global _config
    _config = None
    return load_config(config_path)
|
||||
186
config.yaml
186
config.yaml
@@ -1,107 +1,87 @@
|
||||
# 安保异常行为识别系统 - 核心配置
|
||||
|
||||
# 数据库配置
|
||||
database:
|
||||
dialect: "sqlite" # sqlite 或 mysql
|
||||
host: "localhost"
|
||||
port: 3306
|
||||
username: "root"
|
||||
password: "password"
|
||||
name: "security_monitor"
|
||||
echo: true # SQL日志输出
|
||||
|
||||
# TensorRT模型配置
|
||||
model:
|
||||
path: "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
|
||||
imgsz: 480
|
||||
engine_path: "models/yolo11n_fp16_480.engine"
|
||||
pt_model_path: "models/yolo11n.pt"
|
||||
imgsz: [480, 480]
|
||||
conf_threshold: 0.5
|
||||
device: "cuda" # cuda, cpu
|
||||
iou_threshold: 0.45
|
||||
device: 0 # GPU设备号
|
||||
batch_size: 8 # 最大batch size
|
||||
half: true # FP16推理
|
||||
|
||||
# 大模型配置
|
||||
# RTSP流配置
|
||||
stream:
|
||||
buffer_size: 2 # 每路摄像头帧缓冲大小
|
||||
reconnect_delay: 3.0 # 重连延迟(秒)
|
||||
timeout: 10.0 # 连接超时(秒)
|
||||
fps_limit: 30 # 最大处理FPS
|
||||
|
||||
# 推理队列配置
|
||||
inference:
|
||||
queue_maxlen: 100 # 检测结果队列最大长度
|
||||
event_queue_maxlen: 1000 # 异常事件队列最大长度
|
||||
worker_threads: 4 # 处理线程数
|
||||
|
||||
# 告警配置
|
||||
alert:
|
||||
snapshot_path: "data/alerts"
|
||||
cooldown_sec: 300 # 同类型告警冷却时间
|
||||
image_quality: 85 # JPEG质量
|
||||
|
||||
# ROI配置
|
||||
roi:
|
||||
default_types:
|
||||
- "polygon"
|
||||
- "line"
|
||||
max_points: 50 # 多边形最大顶点数
|
||||
|
||||
# 工作时间配置(全局默认)
|
||||
working_hours:
|
||||
- start: [8, 30] # 8:30
|
||||
end: [11, 0] # 11:00
|
||||
- start: [12, 0] # 12:00
|
||||
end: [17, 30] # 17:30
|
||||
|
||||
# 算法默认参数
|
||||
algorithms:
|
||||
leave_post:
|
||||
default_threshold_sec: 360 # 离岗超时(6分钟)
|
||||
confirm_sec: 30 # 离岗确认时间
|
||||
return_sec: 5 # 上岗确认时间
|
||||
intrusion:
|
||||
check_interval_sec: 1.0 # 检测间隔
|
||||
direction_sensitive: false # 方向敏感
|
||||
|
||||
# 日志配置
|
||||
logging:
|
||||
level: "INFO"
|
||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
file: "logs/app.log"
|
||||
max_size: "100MB"
|
||||
backup_count: 5
|
||||
|
||||
# 监控配置
|
||||
monitoring:
|
||||
enabled: true
|
||||
port: 9090 # Prometheus metrics端口
|
||||
path: "/metrics"
|
||||
|
||||
# 大模型配置(预留)
|
||||
llm:
|
||||
  api_key: ""  # set via environment/secret store (DashScope API key) — never commit real keys
|
||||
base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1/"
|
||||
  model: "qwen3-vl-flash"  # renamed from "model_name" to match LLMConfig.model (the old key was silently ignored); options: qwen-vl-max, qwen-vl-plus, qwen3-vl-flash
|
||||
|
||||
common:
|
||||
# 工作时间段:支持多个时间段,格式为 [开始小时, 开始分钟, 结束小时, 结束分钟]
|
||||
# 8:30-11:00, 12:00-17:30
|
||||
working_hours:
|
||||
- [8, 30, 11, 0] # 8:30-11:00
|
||||
- [12, 0, 17, 30] # 12:00-17:30
|
||||
process_every_n_frames: 3 # 每3帧处理1帧(用于人员离岗)
|
||||
alert_cooldown_sec: 300 # 离岗告警冷却(秒)
|
||||
off_duty_alert_threshold_sec: 360 # 离岗超过6分钟(360秒)触发告警
|
||||
|
||||
cameras:
|
||||
- id: "cam_01"
|
||||
rtsp_url: "rtsp://admin:admin@172.16.8.19:554/cam/realmonitor?channel=16&subtype=1"
|
||||
process_every_n_frames: 5
|
||||
rois:
|
||||
- name: "离岗检测区域"
|
||||
points: [[380, 50], [530, 100], [550, 550], [140, 420]]
|
||||
algorithms:
|
||||
- name: "人员离岗"
|
||||
enabled: true
|
||||
off_duty_threshold_sec: 300 # 离岗超时告警(秒)
|
||||
on_duty_confirm_sec: 5 # 上岗确认时间(秒)
|
||||
off_duty_confirm_sec: 30 # 离岗确认时间(秒)
|
||||
- name: "周界入侵"
|
||||
enabled: false
|
||||
# - name: "周界入侵区域1"
|
||||
# points: [[100, 100], [200, 100], [200, 300], [100, 300]]
|
||||
# algorithms:
|
||||
# - name: "人员离岗"
|
||||
# enabled: false
|
||||
# - name: "周界入侵"
|
||||
# enabled: false
|
||||
|
||||
- id: "cam_02"
|
||||
rtsp_url: "rtsp://admin:admin@172.16.8.13:554/cam/realmonitor?channel=7&subtype=1"
|
||||
rois:
|
||||
- name: "离岗检测区域"
|
||||
points: [[380, 50], [530, 100], [550, 550], [140, 420]]
|
||||
algorithms:
|
||||
- name: "人员离岗"
|
||||
enabled: true
|
||||
off_duty_threshold_sec: 600
|
||||
on_duty_confirm_sec: 10
|
||||
off_duty_confirm_sec: 20
|
||||
- name: "周界入侵"
|
||||
enabled: false
|
||||
|
||||
- id: "cam_03"
|
||||
rtsp_url: "rtsp://admin:admin@172.16.8.26:554/cam/realmonitor?channel=3&subtype=1"
|
||||
rois:
|
||||
- name: "离岗检测区域"
|
||||
points: [[380, 50], [530, 100], [550, 550], [140, 420]]
|
||||
algorithms:
|
||||
- name: "人员离岗"
|
||||
enabled: true
|
||||
off_duty_threshold_sec: 600
|
||||
on_duty_confirm_sec: 10
|
||||
off_duty_confirm_sec: 20
|
||||
- name: "周界入侵"
|
||||
enabled: false
|
||||
|
||||
- id: "cam_04"
|
||||
rtsp_url: "rtsp://admin:admin@172.16.8.20:554/cam/realmonitor?channel=14&subtype=1"
|
||||
process_every_n_frames: 5
|
||||
rois:
|
||||
- name: "离岗检测区域"
|
||||
# points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ]
|
||||
algorithms:
|
||||
- name: "人员离岗"
|
||||
enabled: false
|
||||
- name: "周界入侵"
|
||||
enabled: true
|
||||
- id: "cam_05"
|
||||
rtsp_url: "rtsp://admin:admin@172.16.8.31:554/cam/realmonitor?channel=15&subtype=1"
|
||||
process_every_n_frames: 5
|
||||
rois:
|
||||
- name: "离岗检测区域"
|
||||
# points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ]
|
||||
algorithms:
|
||||
- name: "人员离岗"
|
||||
enabled: false # 离岗确认时间(秒)
|
||||
- name: "周界入侵"
|
||||
enabled: true
|
||||
|
||||
- id: "cam_06"
|
||||
rtsp_url: "rtsp://admin:admin@172.16.8.35:554/cam/realmonitor?channel=13&subtype=1"
|
||||
process_every_n_frames: 5
|
||||
rois:
|
||||
- name: "离岗检测区域"
|
||||
points: [ [ 150, 100 ], [ 600, 100 ], [ 600, 500 ], [ 150, 500 ] ]
|
||||
algorithms:
|
||||
- name: "人员离岗"
|
||||
enabled: false
|
||||
- name: "周界入侵"
|
||||
enabled: true
|
||||
enabled: false
|
||||
api_key: ""
|
||||
base_url: ""
|
||||
model: "qwen3-vl-max"
|
||||
timeout: 30
|
||||
|
||||
314
db/crud.py
Normal file
314
db/crud.py
Normal file
@@ -0,0 +1,314 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from db.models import Camera, CameraStatus, ROI, Alarm
|
||||
|
||||
|
||||
def get_all_cameras(db: Session, enabled_only: bool = True) -> List[Camera]:
    """Return cameras, restricted to enabled ones unless `enabled_only` is False."""
    query = db.query(Camera)
    if enabled_only:
        # SQLAlchemy needs the explicit comparison to build the SQL predicate.
        query = query.filter(Camera.enabled == True)  # noqa: E712
    return query.all()
|
||||
|
||||
|
||||
def get_camera_by_id(db: Session, camera_id: int) -> Optional[Camera]:
    """Look up a single camera by primary key; None when absent."""
    matches = db.query(Camera).filter(Camera.id == camera_id)
    return matches.first()
|
||||
|
||||
|
||||
def create_camera(
    db: Session,
    name: str,
    rtsp_url: str,
    fps_limit: int = 30,
    process_every_n_frames: int = 3,
) -> Camera:
    """Insert a new Camera row and return the refreshed instance."""
    new_camera = Camera(
        name=name,
        rtsp_url=rtsp_url,
        fps_limit=fps_limit,
        process_every_n_frames=process_every_n_frames,
    )
    db.add(new_camera)
    db.commit()
    db.refresh(new_camera)  # pick up DB-generated id / timestamps
    return new_camera
|
||||
|
||||
|
||||
def update_camera(
    db: Session,
    camera_id: int,
    name: Optional[str] = None,
    rtsp_url: Optional[str] = None,
    fps_limit: Optional[int] = None,
    process_every_n_frames: Optional[int] = None,
    enabled: Optional[bool] = None,
) -> Optional[Camera]:
    """Partially update a camera; only non-None fields are written.

    Returns the refreshed Camera, or None when the id does not exist.
    """
    camera = get_camera_by_id(db, camera_id)
    if camera is None:
        return None

    updates = {
        "name": name,
        "rtsp_url": rtsp_url,
        "fps_limit": fps_limit,
        "process_every_n_frames": process_every_n_frames,
        "enabled": enabled,
    }
    for field, value in updates.items():
        if value is not None:
            setattr(camera, field, value)

    db.commit()
    db.refresh(camera)
    return camera
|
||||
|
||||
|
||||
def delete_camera(db: Session, camera_id: int) -> bool:
    """Delete a camera; returns False when the id does not exist."""
    camera = get_camera_by_id(db, camera_id)
    if camera is None:
        return False
    db.delete(camera)
    db.commit()
    return True
|
||||
|
||||
|
||||
def get_camera_status(db: Session, camera_id: int) -> Optional[CameraStatus]:
    """Fetch the (unique) runtime-status row for a camera, if any."""
    status_query = db.query(CameraStatus).filter(
        CameraStatus.camera_id == camera_id
    )
    return status_query.first()
|
||||
|
||||
|
||||
def update_camera_status(
    db: Session,
    camera_id: int,
    is_running: Optional[bool] = None,
    fps: Optional[float] = None,
    error_message: Optional[str] = None,
) -> Optional[CameraStatus]:
    """Upsert the runtime-status row for *camera_id*.

    Creates the row lazily on first report; only non-None arguments are
    written. Always stamps last_check_time with the current UTC time.
    """
    status = get_camera_status(db, camera_id)
    if not status:
        # First report for this camera: create its status row.
        status = CameraStatus(camera_id=camera_id)
        db.add(status)

    if is_running is not None:
        status.is_running = is_running
    if fps is not None:
        status.fps = fps
    if error_message is not None:
        # NOTE(review): a previous error can never be cleared back to None
        # through this API (None means "leave unchanged") — confirm intended.
        status.error_message = error_message

    status.last_check_time = datetime.utcnow()

    db.commit()
    db.refresh(status)
    return status
|
||||
|
||||
|
||||
def get_all_rois(db: Session, camera_id: Optional[int] = None) -> List[ROI]:
    """Return enabled ROIs, optionally restricted to one camera.

    The `enabled == True` filter is always applied, regardless of camera_id.
    """
    query = db.query(ROI)
    if camera_id is not None:
        query = query.filter(ROI.camera_id == camera_id)
    return query.filter(ROI.enabled == True).all()
|
||||
|
||||
|
||||
def get_roi_by_id(db: Session, roi_id: int) -> Optional[ROI]:
    """Fetch one ROI by its database primary key; None when absent."""
    roi_query = db.query(ROI).filter(ROI.id == roi_id)
    return roi_query.first()
|
||||
|
||||
|
||||
def get_roi_by_roi_id(db: Session, camera_id: int, roi_id: str) -> Optional[ROI]:
    """Fetch an ROI by its business key (camera id + string roi_id)."""
    matches = db.query(ROI).filter(
        ROI.camera_id == camera_id,
        ROI.roi_id == roi_id,
    )
    return matches.first()
|
||||
|
||||
|
||||
def create_roi(
    db: Session,
    camera_id: int,
    roi_id: str,
    name: str,
    roi_type: str,
    points: List[List[float]],
    rule_type: str,
    direction: Optional[str] = None,
    stay_time: Optional[int] = None,
    threshold_sec: int = 360,
    confirm_sec: int = 30,
    return_sec: int = 5,
) -> ROI:
    """Insert a new ROI row; polygon *points* are stored as JSON text."""
    import json

    new_roi = ROI(
        camera_id=camera_id,
        roi_id=roi_id,
        name=name,
        roi_type=roi_type,
        points=json.dumps(points),
        rule_type=rule_type,
        direction=direction,
        stay_time=stay_time,
        threshold_sec=threshold_sec,
        confirm_sec=confirm_sec,
        return_sec=return_sec,
    )
    db.add(new_roi)
    db.commit()
    db.refresh(new_roi)
    return new_roi
|
||||
|
||||
|
||||
def update_roi(
    db: Session,
    roi_id: int,
    name: Optional[str] = None,
    points: Optional[List[List[float]]] = None,
    rule_type: Optional[str] = None,
    direction: Optional[str] = None,
    stay_time: Optional[int] = None,
    enabled: Optional[bool] = None,
    threshold_sec: Optional[int] = None,
    confirm_sec: Optional[int] = None,
    return_sec: Optional[int] = None,
) -> Optional[ROI]:
    """Partially update an ROI; only non-None fields are written.

    Returns the refreshed ROI, or None when the id does not exist.
    """
    import json

    roi = get_roi_by_id(db, roi_id)
    if roi is None:
        return None

    # Polygon points are serialized to JSON text before storage.
    if points is not None:
        roi.points = json.dumps(points)

    scalar_updates = {
        "name": name,
        "rule_type": rule_type,
        "direction": direction,
        "stay_time": stay_time,
        "enabled": enabled,
        "threshold_sec": threshold_sec,
        "confirm_sec": confirm_sec,
        "return_sec": return_sec,
    }
    for field, value in scalar_updates.items():
        if value is not None:
            setattr(roi, field, value)

    db.commit()
    db.refresh(roi)
    return roi
|
||||
|
||||
|
||||
def delete_roi(db: Session, roi_id: int) -> bool:
    """Delete an ROI by primary key; False when it does not exist."""
    roi = get_roi_by_id(db, roi_id)
    if roi is None:
        return False
    db.delete(roi)
    db.commit()
    return True
|
||||
|
||||
|
||||
def get_roi_points(db: Session, camera_id: int) -> List[dict]:
    """Return the enabled ROIs of a camera as plain dicts (points decoded)."""
    import json

    def _as_dict(roi: ROI) -> dict:
        # One flat dict per ROI; JSON-decoded polygon under "points".
        return {
            "id": roi.id,
            "roi_id": roi.roi_id,
            "name": roi.name,
            "type": roi.roi_type,
            "points": json.loads(roi.points),
            "rule": roi.rule_type,
            "direction": roi.direction,
            "stay_time": roi.stay_time,
            "enabled": roi.enabled,
            "threshold_sec": roi.threshold_sec,
            "confirm_sec": roi.confirm_sec,
            "return_sec": roi.return_sec,
        }

    return [_as_dict(roi) for roi in get_all_rois(db, camera_id)]
|
||||
|
||||
|
||||
def create_alarm(
    db: Session,
    camera_id: int,
    event_type: str,
    confidence: float = 0.0,
    snapshot_path: Optional[str] = None,
    roi_id: Optional[str] = None,
    llm_checked: bool = False,
    llm_result: Optional[str] = None,
) -> Alarm:
    """Insert a new alarm record and return the refreshed instance."""
    new_alarm = Alarm(
        camera_id=camera_id,
        roi_id=roi_id,
        event_type=event_type,
        confidence=confidence,
        snapshot_path=snapshot_path,
        llm_checked=llm_checked,
        llm_result=llm_result,
    )
    db.add(new_alarm)
    db.commit()
    db.refresh(new_alarm)
    return new_alarm
|
||||
|
||||
|
||||
def get_alarms(
    db: Session,
    camera_id: Optional[int] = None,
    event_type: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
) -> List[Alarm]:
    """Page through alarms, newest first, with optional filters."""
    query = db.query(Alarm)
    if camera_id is not None:
        query = query.filter(Alarm.camera_id == camera_id)
    if event_type is not None:
        query = query.filter(Alarm.event_type == event_type)
    ordered = query.order_by(Alarm.created_at.desc())
    return ordered.offset(offset).limit(limit).all()
|
||||
|
||||
|
||||
def update_alarm(
    db: Session,
    alarm_id: int,
    llm_checked: Optional[bool] = None,
    llm_result: Optional[str] = None,
    processed: Optional[bool] = None,
) -> Optional[Alarm]:
    """Partially update an alarm's review fields; None means unchanged."""
    alarm = db.query(Alarm).filter(Alarm.id == alarm_id).first()
    if alarm is None:
        return None

    for field, value in (
        ("llm_checked", llm_checked),
        ("llm_result", llm_result),
        ("processed", processed),
    ):
        if value is not None:
            setattr(alarm, field, value)

    db.commit()
    db.refresh(alarm)
    return alarm
|
||||
|
||||
|
||||
def get_alarm_stats(db: Session) -> dict:
    """Aggregate alarm counters (total / unprocessed / awaiting LLM check)."""
    total = db.query(Alarm).count()
    unprocessed = db.query(Alarm).filter(Alarm.processed == False).count()
    llm_pending = (
        db.query(Alarm)
        .filter(Alarm.llm_checked == False, Alarm.processed == False)
        .count()
    )
    return {
        "total": total,
        "unprocessed": unprocessed,
        "llm_pending": llm_pending,
    }
|
||||
167
db/models.py
Normal file
167
db/models.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
DateTime,
|
||||
Float,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
create_engine,
|
||||
event,
|
||||
)
|
||||
from sqlalchemy.orm import (
|
||||
DeclarativeBase,
|
||||
Mapped,
|
||||
mapped_column,
|
||||
relationship,
|
||||
sessionmaker,
|
||||
)
|
||||
|
||||
from config import get_config
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Shared declarative base for all ORM models in this module."""
    pass
|
||||
|
||||
|
||||
class Camera(Base):
    """A configured video source (RTSP stream) and its processing settings."""

    __tablename__ = "cameras"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(64), nullable=False)
    rtsp_url: Mapped[str] = mapped_column(Text, nullable=False)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    fps_limit: Mapped[int] = mapped_column(Integer, default=30)
    # Inference stride: only every n-th frame is processed.
    process_every_n_frames: Mapped[int] = mapped_column(Integer, default=3)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
    )

    # ROIs and alarms are deleted together with the camera (delete-orphan);
    # the status row relationship has no cascade.
    rois: Mapped[List["ROI"]] = relationship(
        "ROI", back_populates="camera", cascade="all, delete-orphan"
    )
    status: Mapped[Optional["CameraStatus"]] = relationship(
        "CameraStatus", back_populates="camera", uselist=False
    )
    alarms: Mapped[List["Alarm"]] = relationship(
        "Alarm", back_populates="camera", cascade="all, delete-orphan"
    )
|
||||
|
||||
|
||||
class CameraStatus(Base):
    """Runtime health of a camera stream (one row per camera)."""

    __tablename__ = "camera_status"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Unique FK: at most one status row per camera.
    camera_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("cameras.id"), unique=True, nullable=False
    )
    is_running: Mapped[bool] = mapped_column(Boolean, default=False)
    last_frame_time: Mapped[Optional[datetime]] = mapped_column(DateTime)
    fps: Mapped[float] = mapped_column(Float, default=0.0)
    # Last reported stream error, if any (see crud.update_camera_status).
    error_message: Mapped[Optional[str]] = mapped_column(Text)
    last_check_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)

    camera: Mapped["Camera"] = relationship("Camera", back_populates="status")
|
||||
|
||||
|
||||
class ROI(Base):
    """A polygon region on a camera image plus its detection-rule settings."""

    __tablename__ = "rois"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    camera_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("cameras.id"), nullable=False
    )
    # Business key (string) within a camera, distinct from the numeric PK.
    roi_id: Mapped[str] = mapped_column(String(64), nullable=False)
    name: Mapped[str] = mapped_column(String(128), nullable=False)
    roi_type: Mapped[str] = mapped_column(String(32), nullable=False)
    # Polygon vertices, JSON-encoded text (see crud.create_roi / get_roi_points).
    points: Mapped[str] = mapped_column(Text, nullable=False)
    rule_type: Mapped[str] = mapped_column(String(32), nullable=False)
    direction: Mapped[Optional[str]] = mapped_column(String(32))
    stay_time: Mapped[Optional[int]] = mapped_column(Integer)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    # Timing parameters in seconds; semantics of each are rule-specific —
    # presumably off-duty threshold / confirm / return windows (verify).
    threshold_sec: Mapped[int] = mapped_column(Integer, default=360)
    confirm_sec: Mapped[int] = mapped_column(Integer, default=30)
    return_sec: Mapped[int] = mapped_column(Integer, default=5)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, default=datetime.utcnow, onupdate=datetime.utcnow
    )

    camera: Mapped["Camera"] = relationship("Camera", back_populates="rois")
|
||||
|
||||
|
||||
class Alarm(Base):
    """One detected event raised for a camera (optionally tied to an ROI)."""

    __tablename__ = "alarms"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    camera_id: Mapped[int] = mapped_column(
        Integer, ForeignKey("cameras.id"), nullable=False
    )
    # Business roi_id string, not a foreign key to rois.id.
    roi_id: Mapped[Optional[str]] = mapped_column(String(64))
    event_type: Mapped[str] = mapped_column(String(32), nullable=False)
    confidence: Mapped[float] = mapped_column(Float, default=0.0)
    snapshot_path: Mapped[Optional[str]] = mapped_column(Text)
    # Whether the LLM double-check has run, and its textual result.
    llm_checked: Mapped[bool] = mapped_column(Boolean, default=False)
    llm_result: Mapped[Optional[str]] = mapped_column(Text)
    processed: Mapped[bool] = mapped_column(Boolean, default=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)

    camera: Mapped["Camera"] = relationship("Camera", back_populates="alarms")
|
||||
|
||||
|
||||
# Process-wide lazily-created singletons (engine + session factory).
_engine = None
_SessionLocal = None


def get_engine():
    """Return the process-wide SQLAlchemy engine, creating it on first use."""
    global _engine
    if _engine is None:
        config = get_config()
        _engine = create_engine(
            config.database.url,
            echo=config.database.echo,
            pool_pre_ping=True,   # validate pooled connections before use
            pool_recycle=3600,    # recycle pooled connections after 1h
        )
    return _engine
|
||||
|
||||
|
||||
def get_session_factory():
    """Return the cached sessionmaker, building it lazily on first use."""
    global _SessionLocal
    if _SessionLocal is None:
        _SessionLocal = sessionmaker(
            bind=get_engine(),
            autocommit=False,
            autoflush=False,
        )
    return _SessionLocal
|
||||
|
||||
|
||||
def get_db():
    """Yield a database session and guarantee it is closed afterwards.

    Generator form so it can be used as a dependency/context; callers must
    exhaust or close the generator for the finally-block to run.
    """
    SessionLocal = get_session_factory()
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
|
||||
|
||||
|
||||
def init_db():
    """Create all tables for the configured database.

    NOTE(review): the SQLite-only loop below re-issues create() for Camera,
    ROI and Alarm with checkfirst=True, but create_all() above already
    created every mapped table (including CameraStatus, which the loop
    omits). Looks redundant — confirm before removing.
    """
    config = get_config()
    engine = get_engine()

    Base.metadata.create_all(bind=engine)

    if config.database.dialect == "sqlite":
        for table in [Camera, ROI, Alarm]:
            table.__table__.create(engine, checkfirst=True)
|
||||
|
||||
|
||||
def reset_engine():
    """Drop the cached engine/session factory so the next use rebuilds them
    (useful after a config change or in tests)."""
    global _engine, _SessionLocal
    _engine = None
    _SessionLocal = None
|
||||
218
detector.py
218
detector.py
@@ -1,218 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from ultralytics import YOLO
|
||||
from sort import Sort
|
||||
import time
|
||||
import datetime
|
||||
import threading
|
||||
import queue
|
||||
import torch
|
||||
from collections import deque
|
||||
|
||||
|
||||
class ThreadedFrameReader:
    """Background video reader that always exposes only the latest frame.

    A daemon thread pulls frames from OpenCV as fast as they arrive and keeps
    just the newest one in a size-1 queue, so consumers of a live stream
    never fall behind.
    """

    def __init__(self, src, maxsize=1):
        self.cap = cv2.VideoCapture(src, cv2.CAP_FFMPEG)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)  # keep OpenCV's own buffer minimal
        self.q = queue.Queue(maxsize=maxsize)
        self.running = True
        self.thread = threading.Thread(target=self._reader)
        self.thread.daemon = True
        self.thread.start()

    def _reader(self):
        """Thread body: read frames, dropping stale ones to keep only the newest."""
        while self.running:
            ret, frame = self.cap.read()
            if not ret:
                time.sleep(0.1)  # brief back-off on read failure / EOF
                continue
            # Fix: the original checked q.empty() then called blocking put(),
            # a check-then-act race that could block the reader forever.
            # EAFP drop-then-put-nowait can never block.
            try:
                self.q.get_nowait()
            except queue.Empty:
                pass
            try:
                self.q.put_nowait(frame)
            except queue.Full:
                pass  # slot was refilled concurrently; skip this frame

    def read(self):
        """Return (True, frame) when a fresh frame is available, else (False, None)."""
        try:
            return True, self.q.get_nowait()
        except queue.Empty:
            return False, None

    def release(self):
        """Stop the reader thread, then release the capture device.

        Fix: join the thread before releasing so it can never call
        cap.read() on an already-released capture (use-after-release race
        in the original, which released while the thread was still running).
        """
        self.running = False
        self.thread.join(timeout=2.0)
        self.cap.release()
|
||||
|
||||
|
||||
def is_point_in_roi(x, y, roi):
    """True when (x, y) lies inside or on the edge of polygon *roi*."""
    location = cv2.pointPolygonTest(roi, (int(x), int(y)), False)
    return location >= 0
|
||||
|
||||
|
||||
class OffDutyCrowdDetector:
    """Per-camera detector combining off-duty (post unmanned) and
    crowd-gathering rules on top of YOLO person detections.

    State machine summary: a person inside `roi` confirms "on duty" after
    on_duty_confirm seconds; absence confirms "off duty" after
    off_duty_confirm seconds, which starts the off-duty timer that raises a
    (cooldown-limited) alarm once off_duty_threshold is exceeded.
    """

    def __init__(self, config, model, device, use_half):
        self.config = config
        self.model = model
        self.device = device
        self.use_half = use_half

        # Parse ROIs (polygon vertex arrays for cv2.pointPolygonTest).
        self.roi = np.array(config["roi_points"], dtype=np.int32)
        self.crowd_roi = np.array(config["crowd_roi_points"], dtype=np.int32)

        # Tracker (SORT) — updated each frame, IDs currently unused below.
        self.tracker = Sort(
            max_age=30,
            min_hits=2,
            iou_threshold=0.3
        )

        # Duty state machine variables.
        self.is_on_duty = False
        self.on_duty_start_time = None
        self.is_off_duty = True
        self.last_no_person_time = None
        self.off_duty_timer_start = None
        self.last_alert_time = 0

        # Crowd-gathering state.
        self.last_crowd_alert_time = 0
        self.crowd_history = deque(maxlen=1500)  # auto-capped at ~5 min (assuming 5 fps)

        self.last_person_seen_time = None
        self.frame_count = 0

        # Cache config-derived values.
        self.working_start_min = config.get("working_hours", [9, 17])[0] * 60
        self.working_end_min = config.get("working_hours", [9, 17])[1] * 60
        self.process_every = config.get("process_every_n_frames", 3)

    def in_working_hours(self):
        """True when the current wall-clock time falls inside working hours."""
        now = datetime.datetime.now()
        total_min = now.hour * 60 + now.minute
        return self.working_start_min <= total_min <= self.working_end_min

    def count_people_in_roi(self, boxes, roi):
        """Count detections whose box center lies inside polygon *roi*."""
        count = 0
        for box in boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            if is_point_in_roi(cx, cy, roi):
                count += 1
        return count

    def run(self):
        """Main loop; intended to be executed in a dedicated thread."""
        frame_reader = ThreadedFrameReader(self.config["rtsp_url"])
        try:
            while True:
                ret, frame = frame_reader.read()
                if not ret:
                    time.sleep(0.01)
                    continue

                # Frame-skipping: only process every n-th frame.
                self.frame_count += 1
                if self.frame_count % self.process_every != 0:
                    continue

                current_time = time.time()
                now = datetime.datetime.now()  # NOTE(review): currently unused

                # YOLO inference (person class only)
                results = self.model(
                    frame,
                    imgsz=self.config.get("imgsz", 480),
                    conf=self.config.get("conf_thresh", 0.4),
                    verbose=False,
                    device=self.device,
                    half=self.use_half,
                    classes=[0]  # person class
                )
                boxes = results[0].boxes

                # Update tracker (optional, for ID tracking)
                dets = []
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf)
                    dets.append([x1, y1, x2, y2, conf])
                dets = np.array(dets) if dets else np.empty((0, 5))
                self.tracker.update(dets)

                # === Off-duty detection ===
                if self.in_working_hours():
                    roi_has_person = self.count_people_in_roi(boxes, self.roi) > 0
                    if roi_has_person:
                        self.last_person_seen_time = current_time

                    # On-duty grace window: treat the post as manned for 1s
                    # after the last sighting to smooth detector flicker.
                    effective_on_duty = (
                        self.last_person_seen_time is not None and
                        (current_time - self.last_person_seen_time) < 1.0
                    )

                    if effective_on_duty:
                        self.last_no_person_time = None
                        if self.is_off_duty:
                            # Presence must persist on_duty_confirm seconds
                            # before the state flips to on-duty.
                            if self.on_duty_start_time is None:
                                self.on_duty_start_time = current_time
                            elif current_time - self.on_duty_start_time >= self.config.get("on_duty_confirm", 5):
                                self.is_on_duty = True
                                self.is_off_duty = False
                                self.on_duty_start_time = None
                                print(f"[{self.config['id']}] ✅ 上岗确认")
                    else:
                        self.on_duty_start_time = None
                        self.last_person_seen_time = None
                        if not self.is_off_duty:
                            # Absence must persist off_duty_confirm seconds
                            # before the off-duty timer starts.
                            if self.last_no_person_time is None:
                                self.last_no_person_time = current_time
                            elif current_time - self.last_no_person_time >= self.config.get("off_duty_confirm", 30):
                                self.is_off_duty = True
                                self.is_on_duty = False
                                self.off_duty_timer_start = current_time
                                print(f"[{self.config['id']}] ⏳ 开始离岗计时")

                # Off-duty alarm (rate-limited by alert_cooldown)
                if self.is_off_duty and self.off_duty_timer_start:
                    elapsed = current_time - self.off_duty_timer_start
                    if elapsed >= self.config.get("off_duty_threshold", 300):
                        if current_time - self.last_alert_time >= self.config.get("alert_cooldown", 300):
                            print(f"[{self.config['id']}] 🚨 离岗告警!已离岗 {elapsed/60:.1f} 分钟")
                            self.last_alert_time = current_time

                # === Crowd-gathering detection ===
                crowd_count = self.count_people_in_roi(boxes, self.crowd_roi)
                self.crowd_history.append((current_time, crowd_count))

                # Dynamic threshold: larger crowds require shorter persistence
                if crowd_count >= 10:
                    req_dur = 60
                elif crowd_count >= 7:
                    req_dur = 120
                elif crowd_count >= 5:
                    req_dur = 300
                else:
                    req_dur = float('inf')

                if req_dur < float('inf'):
                    # Alarm only when >=4 people were present in >=90% of the
                    # samples inside the required window, with its own cooldown.
                    recent = [(t, c) for t, c in self.crowd_history if current_time - t <= req_dur]
                    if recent:
                        valid = [c for t, c in recent if c >= 4]
                        ratio = len(valid) / len(recent)
                        if ratio >= 0.9 and (current_time - self.last_crowd_alert_time) >= self.config.get("crowd_cooldown", 180):
                            print(f"[{self.config['id']}] 🚨 聚集告警!{crowd_count}人持续{req_dur//60}分钟")
                            self.last_crowd_alert_time = current_time

                # Visualization (optional; can be disabled for deployment)
                if True:  # set to True to display the window
                    vis = results[0].plot()
                    overlay = vis.copy()
                    cv2.fillPoly(overlay, [self.roi], (0,255,0))
                    cv2.fillPoly(overlay, [self.crowd_roi], (0,0,255))
                    cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis)
                    cv2.imshow(f"Monitor - {self.config['id']}", vis)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

        except Exception as e:
            print(f"[{self.config['id']}] Error: {e}")
        finally:
            frame_reader.release()
            cv2.destroyAllWindows()
|
||||
63
docker-compose.yml
Normal file
63
docker-compose.yml
Normal file
@@ -0,0 +1,63 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
security-monitor:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: security-monitor
|
||||
runtime: nvidia
|
||||
ports:
|
||||
- "8000:8000"
|
||||
- "9090:9090"
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
- ./data:/app/data
|
||||
- ./logs:/app/logs
|
||||
- ./config.yaml:/app/config.yaml:ro
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
- CONFIG_PATH=/app/config.yaml
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus:v2.45.0
|
||||
container_name: prometheus
|
||||
ports:
|
||||
- "9091:9090"
|
||||
volumes:
|
||||
- ./prometheus.yml:/etc/prometheus/prometheus.yml:ro
|
||||
- prometheus_data:/prometheus
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
- '--storage.tsdb.path=/prometheus'
|
||||
- '--web.enable-lifecycle'
|
||||
restart: unless-stopped
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:10.0.0
|
||||
container_name: grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- grafana_data:/var/lib/grafana
|
||||
- ./grafana/provisioning:/etc/grafana/provisioning:ro
|
||||
environment:
|
||||
- GF_SECURITY_ADMIN_PASSWORD=admin
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
prometheus_data:
|
||||
grafana_data:
|
||||
244
inference/engine.py
Normal file
244
inference/engine.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
from config import get_config
|
||||
|
||||
|
||||
class TensorRTEngine:
    """Thin TensorRT wrapper: loads a serialized engine and runs person detection.

    Pre/post-processing: BGR->RGB, resize, normalize to [0, 1], confidence
    filter + NMS on CPU, then boxes rescaled to the original image size.
    """

    def __init__(self, engine_path: Optional[str] = None, device: int = 0):
        config = get_config()
        self.engine_path = engine_path or config.model.engine_path
        self.device = device
        self.imgsz = tuple(config.model.imgsz)          # model input size from config
        self.conf_thresh = config.model.conf_threshold
        self.iou_thresh = config.model.iou_threshold
        self.half = config.model.half                   # FP16 input when True

        self.logger = trt.Logger(trt.Logger.INFO)
        self.engine = None
        self.context = None
        self.stream = None
        self.input_buffer = None
        self.output_buffers = []

        self._load_engine()

    def _load_engine(self):
        """Deserialize the engine and pre-bind device-side output buffers.

        Raises FileNotFoundError when the engine file is missing.
        """
        if not os.path.exists(self.engine_path):
            raise FileNotFoundError(f"TensorRT引擎文件不存在: {self.engine_path}")

        with open(self.engine_path, "rb") as f:
            serialized_engine = f.read()

        runtime = trt.Runtime(self.logger)
        self.engine = runtime.deserialize_cuda_engine(serialized_engine)

        self.context = self.engine.create_execution_context()

        # Dedicated CUDA stream for async execution.
        self.stream = torch.cuda.Stream(device=self.device)

        for i in range(self.engine.num_io_tensors):
            name = self.engine.get_tensor_name(i)
            dtype = self.engine.get_tensor_dtype(name)
            shape = self.engine.get_tensor_shape(name)

            if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
                # NOTE(review): binding the input address to None looks suspect;
                # the real address is set later in inference(). Confirm the TRT
                # binding accepts None here.
                self.context.set_tensor_address(name, None)
            else:
                # Pre-allocate output buffers matching the engine's declared
                # shape/dtype; they stay bound for the engine's lifetime.
                if dtype == trt.float16:
                    buffer = torch.zeros(shape, dtype=torch.float16, device=self.device)
                else:
                    buffer = torch.zeros(shape, dtype=torch.float32, device=self.device)
                self.output_buffers.append(buffer)
                self.context.set_tensor_address(name, buffer.data_ptr())

        # NOTE(review): the TRT API takes a CUDA stream *handle* here, but a
        # torch.cuda.Stream object is passed — confirm this is accepted.
        self.context.set_optimization_profile_async(0, self.stream)

        # Pre-allocated single-image input buffer (currently unused by
        # inference(), which builds its own tensor — TODO confirm intent).
        self.input_buffer = torch.zeros(
            (1, 3, self.imgsz[0], self.imgsz[1]),
            dtype=torch.float16 if self.half else torch.float32,
            device=self.device,
        )

    def preprocess(self, frame: np.ndarray) -> torch.Tensor:
        """Convert one BGR frame into a normalized NCHW tensor on the device."""
        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # NOTE(review): cv2.resize expects (width, height); imgsz comes straight
        # from config — confirm its ordering matches.
        img = cv2.resize(img, self.imgsz)

        # HWC -> CHW, scale to [0, 1].
        img = img.transpose(2, 0, 1).astype(np.float32) / 255.0

        if self.half:
            img = img.astype(np.float16)

        tensor = torch.from_numpy(img).unsqueeze(0).to(self.device)

        return tensor

    def inference(self, images: List[np.ndarray]) -> List[Results]:
        """Run detection on a batch of BGR frames; returns Results objects.

        Note: the returned list is flat over all kept detections, not
        one-entry-per-image.
        """
        batch_size = len(images)
        if batch_size == 0:
            return []

        input_tensor = self.preprocess(images[0])

        if batch_size > 1:
            # Concatenate per-image tensors along the batch dimension.
            for i in range(1, batch_size):
                input_tensor = torch.cat(
                    [input_tensor, self.preprocess(images[i])], dim=0
                )

        # Assumes the engine's input tensor is literally named "input" —
        # TODO confirm against the exported engine.
        self.context.set_tensor_address(
            "input", input_tensor.contiguous().data_ptr()
        )

        torch.cuda.synchronize(self.stream)
        self.context.execute_async_v3(self.stream.handle)
        torch.cuda.synchronize(self.stream)

        results = []
        for i in range(batch_size):
            # Assumes output rows of [x1, y1, x2, y2, score, class] —
            # TODO confirm against the engine's export format.
            pred = self.output_buffers[0][i].cpu().numpy()
            boxes = pred[:, :4]
            scores = pred[:, 4]
            classes = pred[:, 5].astype(np.int32)

            # Confidence filter before NMS.
            mask = scores > self.conf_thresh
            boxes = boxes[mask]
            scores = scores[mask]
            classes = classes[mask]

            indices = cv2.dnn.NMSBoxes(
                boxes.tolist(),
                scores.tolist(),
                self.conf_thresh,
                self.iou_thresh,
            )

            if len(indices) > 0:
                for idx in indices:
                    box = boxes[idx]
                    x1, y1, x2, y2 = box
                    w, h = x2 - x1, y2 - y1
                    conf = scores[idx]
                    cls = classes[idx]

                    # Rescale from model input size to this image's size.
                    orig_h, orig_w = images[i].shape[:2]
                    scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0]
                    box_orig = [
                        int(x1 * scale_x),
                        int(y1 * scale_y),
                        int(w * scale_x),
                        int(h * scale_y),
                    ]

                    # Uses the lightweight Boxes class defined below in this
                    # module (resolved at call time), not ultralytics' Boxes.
                    result = Results(
                        orig_img=images[i],
                        path="",
                        names={0: "person"},
                        boxes=Boxes(
                            torch.tensor([box_orig + [conf, cls]]),
                            orig_shape=(orig_h, orig_w),
                        ),
                    )
                    results.append(result)

        return results

    def inference_single(self, frame: np.ndarray) -> List[Results]:
        """Convenience wrapper: run inference on one frame."""
        return self.inference([frame])

    def warmup(self, num_warmup: int = 10):
        """Run dummy inferences so lazy CUDA/TRT init doesn't hit the first real frame."""
        dummy_frame = np.zeros((480, 640, 3), dtype=np.uint8)
        for _ in range(num_warmup):
            self.inference_single(dummy_frame)

    def __del__(self):
        # NOTE(review): IExecutionContext does not expose synchronize() in the
        # TRT Python API — this likely raises (silently, inside __del__).
        # Confirm and clean up.
        if self.context:
            self.context.synchronize()
        if self.stream:
            self.stream.synchronize()
|
||||
|
||||
|
||||
class Boxes:
|
||||
def __init__(
|
||||
self,
|
||||
data: torch.Tensor,
|
||||
orig_shape: Tuple[int, int],
|
||||
is_track: bool = False,
|
||||
):
|
||||
self.data = data
|
||||
self.orig_shape = orig_shape
|
||||
self.is_track = is_track
|
||||
|
||||
@property
|
||||
def xyxy(self):
|
||||
if self.is_track:
|
||||
return self.data[:, :4]
|
||||
return self.data[:, :4]
|
||||
|
||||
@property
|
||||
def conf(self):
|
||||
if self.is_track:
|
||||
return self.data[:, 4]
|
||||
return self.data[:, 4]
|
||||
|
||||
@property
|
||||
def cls(self):
|
||||
if self.is_track:
|
||||
return self.data[:, 5]
|
||||
return self.data[:, 5]
|
||||
|
||||
|
||||
class YOLOEngine:
    """Unified detector facade: TensorRT when available, PyTorch fallback."""

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: int = 0,
        use_trt: bool = True,
    ):
        self.use_trt = use_trt
        self.device = device
        self.trt_engine = None

        if use_trt:
            try:
                self.trt_engine = TensorRTEngine(device=device)
                self.trt_engine.warmup()
                return
            except Exception as e:
                # Fall back transparently to the PyTorch model.
                print(f"TensorRT加载失败,回退到PyTorch: {e}")
                self.use_trt = False

        self._load_pt_model(model_path, device)

    def _load_pt_model(self, model_path: Optional[str], device: int) -> None:
        """Resolve and load the PyTorch .pt model.

        Fix: this resolution logic was duplicated verbatim in both branches
        of __init__ (explicit use_trt=False and TRT-failure fallback).
        Resolution order: explicit model_path > config pt_model_path >
        engine_path with ".engine" swapped for ".pt".
        """
        if model_path:
            pt_path = model_path
        elif hasattr(get_config().model, 'pt_model_path'):
            pt_path = get_config().model.pt_model_path
        else:
            pt_path = get_config().model.engine_path.replace(".engine", ".pt")
        self.model = YOLO(pt_path)
        self.model.to(device)

    def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]:
        """Run detection on one BGR frame.

        Extra kwargs are forwarded to the PyTorch model only; the TRT path
        ignores them.
        """
        if self.use_trt:
            return self.trt_engine.inference_single(frame)
        else:
            results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs)
            return results

    def __del__(self):
        # Release the TRT engine explicitly (its own __del__ synchronizes).
        if self.trt_engine:
            del self.trt_engine
|
||||
376
inference/pipeline.py
Normal file
376
inference/pipeline.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import get_config
|
||||
from db.crud import (
|
||||
create_alarm,
|
||||
get_all_rois,
|
||||
update_camera_status,
|
||||
)
|
||||
from db.models import init_db
|
||||
from inference.engine import YOLOEngine
|
||||
from inference.roi.roi_filter import ROIFilter
|
||||
from inference.rules.algorithms import AlgorithmManager
|
||||
from inference.stream import StreamManager
|
||||
|
||||
|
||||
class InferencePipeline:
|
||||
    def __init__(self):
        """Wire together the YOLO engine, stream manager, ROI filter and
        rule algorithms; the database itself is initialized lazily via
        _init_database().
        """
        self.config = get_config()

        self.db_initialized = False

        self.yolo_engine = YOLOEngine(use_trt=True)
        self.stream_manager = StreamManager(
            buffer_size=self.config.stream.buffer_size,
            reconnect_delay=self.config.stream.reconnect_delay,
        )
        self.roi_filter = ROIFilter()
        # Convert the config's working-hour entries into the plain-dict
        # {"start": [h, m], "end": [h, m]} form the algorithm manager expects.
        self.algo_manager = AlgorithmManager(working_hours=[
            {
                "start": [wh.start[0], wh.start[1]],
                "end": [wh.end[0], wh.end[1]],
            }
            for wh in self.config.working_hours
        ])

        # Per-camera bookkeeping, keyed by the camera's DB id.
        self.camera_threads: Dict[int, threading.Thread] = {}
        self.camera_stop_events: Dict[int, threading.Event] = {}
        self.camera_latest_frames: Dict[int, Any] = {}
        self.camera_frame_times: Dict[int, datetime] = {}
        self.camera_process_counts: Dict[int, int] = {}

        # Bounded in-memory event queue; oldest events are dropped on overflow.
        self.event_queue: deque = deque(maxlen=self.config.inference.event_queue_maxlen)

        self.running = False
|
||||
|
||||
def _init_database(self):
|
||||
if not self.db_initialized:
|
||||
init_db()
|
||||
self.db_initialized = True
|
||||
|
||||
def _get_db_session(self):
|
||||
from db.models import get_session_factory
|
||||
SessionLocal = get_session_factory()
|
||||
return SessionLocal()
|
||||
|
||||
def _load_cameras(self):
|
||||
db = self._get_db_session()
|
||||
try:
|
||||
from db.crud import get_all_cameras
|
||||
cameras = get_all_cameras(db, enabled_only=True)
|
||||
for camera in cameras:
|
||||
self.add_camera(camera)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def add_camera(self, camera) -> bool:
|
||||
camera_id = camera.id
|
||||
if camera_id in self.camera_threads:
|
||||
return False
|
||||
|
||||
self.camera_stop_events[camera_id] = threading.Event()
|
||||
self.camera_process_counts[camera_id] = 0
|
||||
|
||||
self.stream_manager.add_stream(
|
||||
str(camera_id),
|
||||
camera.rtsp_url,
|
||||
self.config.stream.buffer_size,
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=self._camera_inference_loop,
|
||||
args=(camera,),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
self.camera_threads[camera_id] = thread
|
||||
|
||||
self._update_camera_status(camera_id, is_running=True)
|
||||
return True
|
||||
|
||||
def remove_camera(self, camera_id: int):
|
||||
if camera_id not in self.camera_threads:
|
||||
return
|
||||
|
||||
self.camera_stop_events[camera_id].set()
|
||||
self.camera_threads[camera_id].join(timeout=10.0)
|
||||
|
||||
del self.camera_threads[camera_id]
|
||||
del self.camera_stop_events[camera_id]
|
||||
|
||||
self.stream_manager.remove_stream(str(camera_id))
|
||||
self.roi_filter.clear_cache(camera_id)
|
||||
self.algo_manager.remove_roi(str(camera_id))
|
||||
|
||||
if camera_id in self.camera_latest_frames:
|
||||
del self.camera_latest_frames[camera_id]
|
||||
if camera_id in self.camera_frame_times:
|
||||
del self.camera_frame_times[camera_id]
|
||||
if camera_id in self.camera_process_counts:
|
||||
del self.camera_process_counts[camera_id]
|
||||
|
||||
self._update_camera_status(camera_id, is_running=False)
|
||||
|
||||
def _update_camera_status(
|
||||
self,
|
||||
camera_id: int,
|
||||
is_running: Optional[bool] = None,
|
||||
fps: Optional[float] = None,
|
||||
error_message: Optional[str] = None,
|
||||
):
|
||||
try:
|
||||
db = self._get_db_session()
|
||||
update_camera_status(
|
||||
db,
|
||||
camera_id,
|
||||
is_running=is_running,
|
||||
fps=fps,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{camera_id}] 更新状态失败: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _camera_inference_loop(self, camera):
|
||||
camera_id = camera.id
|
||||
stop_event = self.camera_stop_events[camera_id]
|
||||
|
||||
while not stop_event.is_set():
|
||||
ret, frame = self.stream_manager.read(str(camera_id))
|
||||
if not ret or frame is None:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
self.camera_latest_frames[camera_id] = frame
|
||||
self.camera_frame_times[camera_id] = datetime.now()
|
||||
|
||||
self.camera_process_counts[camera_id] += 1
|
||||
|
||||
if self.camera_process_counts[camera_id] % camera.process_every_n_frames != 0:
|
||||
continue
|
||||
|
||||
try:
|
||||
self._process_frame(camera_id, frame, camera)
|
||||
except Exception as e:
|
||||
print(f"[{camera_id}] 处理帧失败: {e}")
|
||||
|
||||
print(f"[{camera_id}] 推理线程已停止")
|
||||
|
||||
def _process_frame(self, camera_id: int, frame: np.ndarray, camera):
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
db = self._get_db_session()
|
||||
try:
|
||||
rois = get_all_rois(db, camera_id)
|
||||
roi_configs = [
|
||||
{
|
||||
"id": roi.id,
|
||||
"roi_id": roi.roi_id,
|
||||
"type": roi.roi_type,
|
||||
"points": json.loads(roi.points),
|
||||
"rule": roi.rule_type,
|
||||
"direction": roi.direction,
|
||||
"enabled": roi.enabled,
|
||||
"threshold_sec": roi.threshold_sec,
|
||||
"confirm_sec": roi.confirm_sec,
|
||||
"return_sec": roi.return_sec,
|
||||
}
|
||||
for roi in rois
|
||||
]
|
||||
|
||||
if roi_configs:
|
||||
self.roi_filter.update_cache(camera_id, roi_configs)
|
||||
|
||||
for roi_config in roi_configs:
|
||||
roi_id = roi_config["roi_id"]
|
||||
rule_type = roi_config["rule"]
|
||||
|
||||
self.algo_manager.register_algorithm(
|
||||
roi_id,
|
||||
rule_type,
|
||||
{
|
||||
"threshold_sec": roi_config.get("threshold_sec", 360),
|
||||
"confirm_sec": roi_config.get("confirm_sec", 30),
|
||||
"return_sec": roi_config.get("return_sec", 5),
|
||||
},
|
||||
)
|
||||
|
||||
results = self.yolo_engine(frame, verbose=False, classes=[0])
|
||||
|
||||
if not results:
|
||||
return
|
||||
|
||||
result = results[0]
|
||||
detections = []
|
||||
if hasattr(result, "boxes") and result.boxes is not None:
|
||||
boxes = result.boxes.xyxy.cpu().numpy()
|
||||
confs = result.boxes.conf.cpu().numpy()
|
||||
for i, box in enumerate(boxes):
|
||||
detections.append({
|
||||
"bbox": box.tolist(),
|
||||
"conf": float(confs[i]),
|
||||
"cls": 0,
|
||||
})
|
||||
|
||||
if roi_configs:
|
||||
filtered_detections = self.roi_filter.filter_detections(
|
||||
detections, roi_configs
|
||||
)
|
||||
else:
|
||||
filtered_detections = detections
|
||||
|
||||
for detection in filtered_detections:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
for roi_conf in matched_rois:
|
||||
roi_id = roi_conf["roi_id"]
|
||||
rule_type = roi_conf["rule"]
|
||||
|
||||
alerts = self.algo_manager.process(
|
||||
roi_id,
|
||||
str(camera_id),
|
||||
rule_type,
|
||||
[detection],
|
||||
datetime.now(),
|
||||
)
|
||||
|
||||
for alert in alerts:
|
||||
self._handle_alert(camera_id, alert, frame, roi_conf)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _handle_alert(
|
||||
self,
|
||||
camera_id: int,
|
||||
alert: Dict[str, Any],
|
||||
frame: np.ndarray,
|
||||
roi_config: Dict[str, Any],
|
||||
):
|
||||
try:
|
||||
snapshot_path = None
|
||||
bbox = alert.get("bbox", [])
|
||||
if bbox and len(bbox) >= 4:
|
||||
x1, y1, x2, y2 = [int(v) for v in bbox]
|
||||
x1, y1 = max(0, x1), max(0, y1)
|
||||
x2, y2 = min(frame.shape[1], x2), min(frame.shape[0], y2)
|
||||
|
||||
snapshot_dir = self.config.alert.snapshot_path
|
||||
os.makedirs(snapshot_dir, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"cam_{camera_id}_{roi_config['roi_id']}_{alert['alert_type']}_{timestamp}.jpg"
|
||||
snapshot_path = os.path.join(snapshot_dir, filename)
|
||||
|
||||
cv2.imwrite(snapshot_path, frame, [cv2.IMWRITE_JPEG_QUALITY, self.config.alert.image_quality])
|
||||
|
||||
db = self._get_db_session()
|
||||
try:
|
||||
alarm = create_alarm(
|
||||
db,
|
||||
camera_id=camera_id,
|
||||
event_type=alert["alert_type"],
|
||||
confidence=alert.get("confidence", alert.get("conf", 0.0)),
|
||||
snapshot_path=snapshot_path,
|
||||
roi_id=roi_config["roi_id"],
|
||||
)
|
||||
alert["alarm_id"] = alarm.id
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
self.event_queue.append({
|
||||
"camera_id": camera_id,
|
||||
"roi_id": roi_config["roi_id"],
|
||||
"event_type": alert["alert_type"],
|
||||
"confidence": alert.get("confidence", alert.get("conf", 0.0)),
|
||||
"message": alert.get("message", ""),
|
||||
"snapshot_path": snapshot_path,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"llm_checked": False,
|
||||
})
|
||||
|
||||
print(f"[{camera_id}] 🚨 告警: {alert['alert_type']} - {alert.get('message', '')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{camera_id}] 处理告警失败: {e}")
|
||||
|
||||
def get_latest_frame(self, camera_id: int) -> Optional[np.ndarray]:
|
||||
return self.camera_latest_frames.get(camera_id)
|
||||
|
||||
def get_camera_fps(self, camera_id: int) -> float:
|
||||
stream = self.stream_manager.get_stream(str(camera_id))
|
||||
if stream:
|
||||
return stream.fps
|
||||
return 0.0
|
||||
|
||||
def get_event_queue(self) -> List[Dict[str, Any]]:
|
||||
return list(self.event_queue)
|
||||
|
||||
def start(self):
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self._init_database()
|
||||
self._load_cameras()
|
||||
self.running = True
|
||||
|
||||
def stop(self):
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
|
||||
for camera_id in list(self.camera_threads.keys()):
|
||||
self.remove_camera(camera_id)
|
||||
|
||||
self.stream_manager.stop_all()
|
||||
self.algo_manager.reset_all()
|
||||
|
||||
print("推理pipeline已停止")
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"running": self.running,
|
||||
"camera_count": len(self.camera_threads),
|
||||
"cameras": {
|
||||
cid: {
|
||||
"running": self.camera_stop_events[cid] is not None and not self.camera_stop_events[cid].is_set(),
|
||||
"fps": self.get_camera_fps(cid),
|
||||
"frame_time": self.camera_frame_times.get(cid).isoformat() if self.camera_frame_times.get(cid) else None,
|
||||
}
|
||||
for cid in self.camera_threads
|
||||
},
|
||||
"event_queue_size": len(self.event_queue),
|
||||
}
|
||||
|
||||
|
||||
_pipeline: Optional[InferencePipeline] = None
|
||||
|
||||
|
||||
def get_pipeline() -> InferencePipeline:
    """Return the process-wide pipeline singleton, creating it on first use."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline
    _pipeline = InferencePipeline()
    return _pipeline
|
||||
|
||||
|
||||
def start_pipeline():
    """Start (and return) the singleton pipeline; a no-op if already running."""
    instance = get_pipeline()
    instance.start()
    return instance
|
||||
|
||||
|
||||
def stop_pipeline():
    """Stop and discard the singleton pipeline, if one exists."""
    global _pipeline
    if _pipeline is None:
        return
    _pipeline.stop()
    _pipeline = None
|
||||
168
inference/roi/roi_filter.py
Normal file
168
inference/roi/roi_filter.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import json
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from shapely.geometry import LineString, Point, Polygon
|
||||
|
||||
|
||||
class ROIFilter:
    """Geometric filtering of detections against configured ROIs.

    Supports polygon ROIs (bbox-centre containment) and line ROIs
    (trajectory/line intersection) via shapely, plus a per-camera ROI cache.
    """

    def __init__(self):
        # camera_id -> list of ROI config dicts, refreshed by the pipeline.
        self.roi_cache: Dict[int, List[Dict]] = {}

    def parse_points(self, points_json: str) -> List[Tuple[float, float]]:
        """Normalise ROI points to [(x, y), ...] floats.

        Accepts either a JSON string or an already-decoded list despite the
        annotation (callers pass both forms).
        """
        if isinstance(points_json, str):
            points = json.loads(points_json)
        else:
            points = points_json
        return [(float(p[0]), float(p[1])) for p in points]

    def create_polygon(self, points: List[Tuple[float, float]]) -> Polygon:
        """Build a shapely Polygon from ROI points."""
        return Polygon(points)

    def create_line(self, points: List[Tuple[float, float]]) -> LineString:
        """Build a shapely LineString from ROI points."""
        return LineString(points)

    def is_point_in_polygon(self, point: Tuple[float, float], polygon: Polygon) -> bool:
        """True when *point* lies strictly inside *polygon* (boundary excluded)."""
        return polygon.contains(Point(point))

    def is_line_crossed(
        self,
        line_start: Tuple[float, float],
        line_end: Tuple[float, float],
        line_obj: LineString,
    ) -> bool:
        """True when the segment (line_start, line_end) intersects *line_obj*."""
        trajectory = LineString([line_start, line_end])
        return line_obj.intersects(trajectory)

    def get_bbox_center(self, bbox: List[float]) -> Tuple[float, float]:
        """Centre point of an [x1, y1, x2, y2, ...] box."""
        x1, y1, x2, y2 = bbox[:4]
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    def is_bbox_in_roi(self, bbox: List[float], roi_config: Dict) -> bool:
        """True when *bbox* matches the ROI.

        Polygon ROIs test the bbox centre; line ROIs test whether the bbox
        diagonal (top-left to bottom-right) crosses the line.
        """
        center = self.get_bbox_center(bbox)
        roi_type = roi_config.get("type", "polygon")
        points = self.parse_points(roi_config["points"])

        if roi_type == "polygon":
            polygon = self.create_polygon(points)
            return self.is_point_in_polygon(center, polygon)
        elif roi_type == "line":
            line = self.create_line(points)
            # NOTE(review): this uses the box's own diagonal as the
            # "trajectory", not motion between frames — a stationary box
            # straddling the line keeps matching.  Confirm intent.
            trajectory_start = (
                bbox[0],
                bbox[1],
            )
            trajectory_end = (
                bbox[2],
                bbox[3],
            )
            return self.is_line_crossed(trajectory_start, trajectory_end, line)

        # Unknown ROI type: no match.
        return False

    def filter_detections(
        self,
        detections: List[Dict],
        roi_configs: List[Dict],
        require_all_rois: bool = False,
    ) -> List[Dict]:
        """Keep detections that match at least one enabled ROI.

        Matching detections gain a ``"matched_rois"`` key (the dicts are
        mutated in place).  With ``require_all_rois=True`` a detection must
        match every configured ROI.  Detections without a bbox pass through
        untouched; with no ROIs configured, everything passes.
        """
        if not roi_configs:
            return detections

        filtered = []
        for det in detections:
            bbox = det.get("bbox", [])
            if not bbox:
                filtered.append(det)
                continue

            matches = []
            for roi_config in roi_configs:
                if not roi_config.get("enabled", True):
                    continue
                if self.is_bbox_in_roi(bbox, roi_config):
                    matches.append(roi_config)

            if matches:
                if require_all_rois:
                    if len(matches) == len(roi_configs):
                        det["matched_rois"] = matches
                        filtered.append(det)
                else:
                    det["matched_rois"] = matches
                    filtered.append(det)

        return filtered

    def filter_by_rule_type(
        self,
        detections: List[Dict],
        roi_configs: List[Dict],
        rule_type: str,
    ) -> List[Dict]:
        """Filter against only the enabled ROIs carrying *rule_type*."""
        rule_rois = [
            roi for roi in roi_configs
            if roi.get("rule") == rule_type and roi.get("enabled", True)
        ]
        return self.filter_detections(detections, rule_rois)

    def update_cache(self, camera_id: int, rois: List[Dict]):
        """Replace the cached ROI list for a camera."""
        self.roi_cache[camera_id] = rois

    def get_cached_rois(self, camera_id: int) -> List[Dict]:
        """Cached ROI list for a camera ([] when none cached)."""
        return self.roi_cache.get(camera_id, [])

    def clear_cache(self, camera_id: Optional[int] = None):
        """Drop one camera's cached ROIs, or all caches when camera_id is None."""
        if camera_id is None:
            self.roi_cache.clear()
        elif camera_id in self.roi_cache:
            del self.roi_cache[camera_id]

    def check_direction(
        self,
        prev_bbox: List[float],
        curr_bbox: List[float],
        roi_config: Dict,
    ) -> bool:
        """Check movement direction against a 2-point line ROI.

        Returns True (pass) when the ROI has no direction constraint or is
        not a simple 2-point line.  Otherwise the sign of the dot product of
        the person's displacement with the line vector decides A_to_B /
        B_to_A.
        """
        if roi_config.get("direction") is None:
            return True

        direction = roi_config["direction"]
        prev_center = self.get_bbox_center(prev_bbox)
        curr_center = self.get_bbox_center(curr_bbox)

        roi_points = self.parse_points(roi_config["points"])
        if len(roi_points) != 2:
            return True

        # Fix: the original also built an unused shapely LineString here;
        # only the endpoint coordinates are needed.
        line_start = roi_points[0]
        line_end = roi_points[1]

        dx_line = line_end[0] - line_start[0]
        dy_line = line_end[1] - line_start[1]

        dx_person = curr_center[0] - prev_center[0]
        dy_person = curr_center[1] - prev_center[1]

        # NOTE(review): this projects motion onto the line itself (movement
        # *along* the line), not onto its normal (movement *across* it) —
        # confirm which semantics the direction labels intend.
        dot_product = dx_person * dx_line + dy_person * dy_line

        if direction == "A_to_B":
            return dot_product > 0
        elif direction == "B_to_A":
            return dot_product < 0
        else:
            return True

    def get_detection_roi_info(
        self,
        bbox: List[float],
        roi_configs: List[Dict],
    ) -> List[Dict]:
        """All enabled ROI configs that *bbox* matches."""
        matched = []
        for roi_config in roi_configs:
            if not roi_config.get("enabled", True):
                continue
            if self.is_bbox_in_roi(bbox, roi_config):
                matched.append(roi_config)
        return matched
|
||||
303
inference/rules/algorithms.py
Normal file
303
inference/rules/algorithms.py
Normal file
@@ -0,0 +1,303 @@
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sort import Sort
|
||||
|
||||
|
||||
class LeavePostAlgorithm:
    """Leave-post (离岗) rule: tracks persons with SORT and raises an alert
    when a track's timed state exceeds ``threshold_sec`` during working hours.

    Per-alert cooldown (300 s) prevents repeated alarms for the same
    camera/track pair.
    """

    def __init__(
        self,
        threshold_sec: int = 360,
        confirm_sec: int = 30,
        return_sec: int = 5,
        working_hours: Optional[List[Dict]] = None,
    ):
        # threshold_sec: seconds of "off duty" before an alert fires.
        # confirm_sec:   seconds a track must persist before arming the timer.
        # return_sec:    window in which the armed state is cancelled.
        # working_hours: list of {"start": [h, m], "end": [h, m]} periods;
        #                empty means "always active".
        self.threshold_sec = threshold_sec
        self.confirm_sec = confirm_sec
        self.return_sec = return_sec
        self.working_hours = working_hours or []

        # track_id -> per-track timing state (first_seen / last_seen /
        # off_duty_start / alerted / last_position).
        self.track_states: Dict[str, Dict[str, Any]] = {}
        self.tracker = Sort(max_age=10, min_hits=2, iou_threshold=0.3)

        # "{camera_id}_{track_id}" -> time of last alert, for rate limiting.
        self.alert_cooldowns: Dict[str, datetime] = {}
        self.cooldown_seconds = 300

    def is_in_working_hours(self, dt: Optional[datetime] = None) -> bool:
        """True when *dt* (default: now) falls inside any configured working
        period; trivially True when no periods are configured."""
        if not self.working_hours:
            return True

        dt = dt or datetime.now()
        current_minutes = dt.hour * 60 + dt.minute

        for period in self.working_hours:
            # Periods are compared in minutes-since-midnight; end is exclusive.
            start_minutes = period["start"][0] * 60 + period["start"][1]
            end_minutes = period["end"][0] * 60 + period["end"][1]
            if start_minutes <= current_minutes < end_minutes:
                return True

        return False

    def process(
        self,
        camera_id: str,
        tracks: List[Dict],
        current_time: Optional[datetime] = None,
    ) -> List[Dict]:
        """Feed detections ({"bbox": [...], "conf": ...}) through SORT and
        evaluate the leave-post state machine; returns a list of alert dicts
        (empty outside working hours or when nothing fires)."""
        if not self.is_in_working_hours(current_time):
            return []

        if not tracks:
            return []

        # Convert to SORT's input format: rows of [x1, y1, x2, y2, conf].
        detections = []
        for track in tracks:
            bbox = track.get("bbox", [])
            if len(bbox) >= 4:
                detections.append(bbox + [track.get("conf", 0.0)])

        if not detections:
            return []

        detections = np.array(detections)
        tracked = self.tracker.update(detections)

        alerts = []
        current_time = current_time or datetime.now()

        for track_data in tracked:
            # SORT output rows are [x1, y1, x2, y2, track_id].
            x1, y1, x2, y2, track_id = track_data
            track_id = str(int(track_id))

            if track_id not in self.track_states:
                self.track_states[track_id] = {
                    "first_seen": current_time,
                    "last_seen": current_time,
                    "off_duty_start": None,
                    "alerted": False,
                    "last_position": (x1, y1, x2, y2),
                }

            state = self.track_states[track_id]
            state["last_seen"] = current_time
            state["last_position"] = (x1, y1, x2, y2)

            # NOTE(review): the state machine below times how long a track
            # has been *present*, yet labels it "off_duty" — confirm whether
            # leave-post is meant to fire on presence or on absence.
            if state["off_duty_start"] is None:
                # Arm the timer once the track has persisted past confirm_sec.
                off_duty_duration = (current_time - state["first_seen"]).total_seconds()
                if off_duty_duration > self.confirm_sec:
                    state["off_duty_start"] = current_time
            else:
                elapsed = (current_time - state["off_duty_start"]).total_seconds()
                if elapsed > self.threshold_sec:
                    if not state["alerted"]:
                        cooldown_key = f"{camera_id}_{track_id}"
                        now = datetime.now()
                        # Per camera/track cooldown so one track cannot spam.
                        if cooldown_key not in self.alert_cooldowns or (
                            now - self.alert_cooldowns[cooldown_key]
                        ).total_seconds() > self.cooldown_seconds:
                            alerts.append({
                                "track_id": track_id,
                                "bbox": [x1, y1, x2, y2],
                                "off_duty_duration": elapsed,
                                "alert_type": "leave_post",
                                "message": f"离岗超过 {int(elapsed / 60)} 分钟",
                            })
                            state["alerted"] = True
                            self.alert_cooldowns[cooldown_key] = now
                else:
                    # NOTE(review): elapsed starts near 0 right after arming,
                    # so this branch disarms the timer on the very next
                    # processed frame (elapsed < return_sec) — the threshold
                    # can then only be reached across a >return_sec gap in
                    # updates.  Verify against the intended spec.
                    if elapsed < self.return_sec:
                        state["off_duty_start"] = None
                        state["alerted"] = False

        # Drop state for tracks not seen in the last 5 minutes.
        cleanup_time = current_time - timedelta(minutes=5)
        for track_id, state in list(self.track_states.items()):
            if state["last_seen"] < cleanup_time:
                del self.track_states[track_id]

        return alerts

    def reset(self):
        """Clear all per-track state and alert cooldowns."""
        self.track_states.clear()
        self.alert_cooldowns.clear()

    def get_state(self, track_id: str) -> Optional[Dict[str, Any]]:
        """Timing state for one track, or None if unknown."""
        return self.track_states.get(track_id)
|
||||
|
||||
|
||||
class IntrusionAlgorithm:
    """Perimeter-intrusion rule: any tracked person triggers an alert,
    rate-limited per camera by ``check_interval_sec`` and per track by a
    300 s cooldown."""

    def __init__(
        self,
        check_interval_sec: float = 1.0,
        direction_sensitive: bool = False,
    ):
        self.check_interval_sec = check_interval_sec
        # NOTE(review): direction_sensitive is stored but never consulted
        # in process() — confirm whether direction checks were planned here.
        self.direction_sensitive = direction_sensitive

        # camera_id -> unix timestamp of the last evaluation.
        self.last_check_times: Dict[str, float] = {}
        self.tracker = Sort(max_age=5, min_hits=1, iou_threshold=0.3)

        # "{camera_id}_{track_id}" -> time of last alert, for rate limiting.
        self.alert_cooldowns: Dict[str, datetime] = {}
        self.cooldown_seconds = 300

    @staticmethod
    def _best_conf(tracks: List[Dict], bbox) -> float:
        """Detector confidence of the input detection with the highest IoU
        against *bbox* (0.0 when nothing overlaps)."""
        bx1, by1, bx2, by2 = bbox
        best_iou = 0.0
        best_conf = 0.0
        for t in tracks:
            tb = t.get("bbox", [])
            if len(tb) < 4:
                continue
            ix1, iy1 = max(bx1, tb[0]), max(by1, tb[1])
            ix2, iy2 = min(bx2, tb[2]), min(by2, tb[3])
            inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
            area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
            area_t = max(0.0, tb[2] - tb[0]) * max(0.0, tb[3] - tb[1])
            union = area_b + area_t - inter
            iou = inter / union if union > 0 else 0.0
            if iou > best_iou:
                best_iou = iou
                best_conf = float(t.get("conf", 0.0))
        return best_conf

    def process(
        self,
        camera_id: str,
        tracks: List[Dict],
        current_time: Optional[datetime] = None,
    ) -> List[Dict]:
        """Evaluate intrusion for one camera; returns a list of alert dicts."""
        if not tracks:
            return []

        # Throttle per camera first, before doing any conversion work.
        current_ts = current_time.timestamp() if current_time else datetime.now().timestamp()

        if camera_id in self.last_check_times:
            if current_ts - self.last_check_times[camera_id] < self.check_interval_sec:
                return []
        self.last_check_times[camera_id] = current_ts

        # Convert to SORT's input format: rows of [x1, y1, x2, y2, conf].
        detections = []
        for track in tracks:
            bbox = track.get("bbox", [])
            if len(bbox) >= 4:
                detections.append(bbox + [track.get("conf", 0.0)])

        if not detections:
            return []

        detections = np.array(detections)
        tracked = self.tracker.update(detections)

        alerts = []
        now = datetime.now()

        for track_data in tracked:
            # SORT output rows are [x1, y1, x2, y2, track_id].
            x1, y1, x2, y2, track_id = track_data
            cooldown_key = f"{camera_id}_{int(track_id)}"

            if cooldown_key not in self.alert_cooldowns or (
                now - self.alert_cooldowns[cooldown_key]
            ).total_seconds() > self.cooldown_seconds:
                alerts.append({
                    "track_id": str(int(track_id)),
                    "bbox": [x1, y1, x2, y2],
                    "alert_type": "intrusion",
                    # Fix: the original reported track_data[4] as the
                    # confidence, but index 4 is the SORT track id, not a
                    # score.  Recover the detector confidence from the
                    # best-overlapping input detection instead.
                    "confidence": self._best_conf(tracks, (x1, y1, x2, y2)),
                    "message": "检测到周界入侵",
                })
                self.alert_cooldowns[cooldown_key] = now

        return alerts

    def reset(self):
        """Clear throttle timestamps and alert cooldowns."""
        self.last_check_times.clear()
        self.alert_cooldowns.clear()
|
||||
|
||||
|
||||
class AlgorithmManager:
    """Registry of rule-algorithm instances, keyed by ROI id then by
    algorithm type ("leave_post" / "intrusion")."""

    def __init__(self, working_hours: Optional[List[Dict]] = None):
        # roi_id -> {algorithm_type: algorithm instance}
        self.algorithms: Dict[str, Dict[str, Any]] = {}
        self.working_hours = working_hours or []

        # Fallback parameters per algorithm type; treated as read-only.
        self.default_params = {
            "leave_post": {
                "threshold_sec": 360,
                "confirm_sec": 30,
                "return_sec": 5,
            },
            "intrusion": {
                "check_interval_sec": 1.0,
                "direction_sensitive": False,
            },
        }

    def register_algorithm(
        self,
        roi_id: str,
        algorithm_type: str,
        params: Optional[Dict[str, Any]] = None,
    ):
        """Create the algorithm for (roi_id, algorithm_type) if not present.

        Idempotent: an already-registered pair is left untouched, so calling
        this every frame is cheap.  Unknown algorithm types are ignored.
        """
        if roi_id in self.algorithms:
            if algorithm_type in self.algorithms[roi_id]:
                return

        if roi_id not in self.algorithms:
            self.algorithms[roi_id] = {}

        # Fix: copy the defaults before merging.  The original updated the
        # dict returned by default_params.get() in place, so one ROI's
        # custom params silently became the defaults for every later ROI.
        algo_params = dict(self.default_params.get(algorithm_type, {}))
        if params:
            algo_params.update(params)

        if algorithm_type == "leave_post":
            self.algorithms[roi_id]["leave_post"] = LeavePostAlgorithm(
                threshold_sec=algo_params.get("threshold_sec", 360),
                confirm_sec=algo_params.get("confirm_sec", 30),
                return_sec=algo_params.get("return_sec", 5),
                working_hours=self.working_hours,
            )
        elif algorithm_type == "intrusion":
            self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm(
                check_interval_sec=algo_params.get("check_interval_sec", 1.0),
                direction_sensitive=algo_params.get("direction_sensitive", False),
            )

    def process(
        self,
        roi_id: str,
        camera_id: str,
        algorithm_type: str,
        tracks: List[Dict],
        current_time: Optional[datetime] = None,
    ) -> List[Dict]:
        """Run one registered algorithm; [] when the pair is unregistered."""
        algo = self.algorithms.get(roi_id, {}).get(algorithm_type)
        if algo is None:
            return []
        return algo.process(camera_id, tracks, current_time)

    def update_roi_params(
        self,
        roi_id: str,
        algorithm_type: str,
        params: Dict[str, Any],
    ):
        """Patch attributes on a live algorithm instance (unknown keys are
        silently skipped)."""
        if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]:
            algo = self.algorithms[roi_id][algorithm_type]
            for key, value in params.items():
                if hasattr(algo, key):
                    setattr(algo, key, value)

    def reset_algorithm(self, roi_id: str, algorithm_type: Optional[str] = None):
        """Reset one algorithm for the ROI, or all of them when type is None."""
        if roi_id not in self.algorithms:
            return

        if algorithm_type:
            if algorithm_type in self.algorithms[roi_id]:
                self.algorithms[roi_id][algorithm_type].reset()
        else:
            for algo in self.algorithms[roi_id].values():
                algo.reset()

    def reset_all(self):
        """Reset every registered algorithm's internal state."""
        for roi_algorithms in self.algorithms.values():
            for algo in roi_algorithms.values():
                algo.reset()

    def remove_roi(self, roi_id: str):
        """Reset and deregister all algorithms for one ROI."""
        if roi_id in self.algorithms:
            self.reset_algorithm(roi_id)
            del self.algorithms[roi_id]

    def get_status(self, roi_id: str) -> Dict[str, Any]:
        """Per-algorithm track counts for one ROI ({} when unregistered)."""
        status = {}
        if roi_id in self.algorithms:
            for algo_type, algo in self.algorithms[roi_id].items():
                status[algo_type] = {
                    "track_count": len(getattr(algo, "track_states", {})),
                }
        return status
|
||||
192
inference/stream.py
Normal file
192
inference/stream.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class StreamReader:
    """Background RTSP reader for one camera.

    A daemon thread keeps a cv2.VideoCapture open (reconnecting on failure),
    pushes frames into a small bounded deque, and maintains a running FPS
    estimate.  Consumers call read()/get_frame() from other threads.
    """

    def __init__(
        self,
        camera_id: str,
        rtsp_url: str,
        buffer_size: int = 2,
        reconnect_delay: float = 3.0,
        timeout: float = 10.0,
    ):
        # timeout is stored but not consulted in this class — presumably for
        # callers or future use; confirm before removing.
        self.camera_id = camera_id
        self.rtsp_url = rtsp_url
        self.buffer_size = buffer_size
        self.reconnect_delay = reconnect_delay
        self.timeout = timeout

        # cv2.VideoCapture handle; None whenever disconnected.
        self.cap = None
        self.running = False
        self.thread: Optional[threading.Thread] = None
        # Bounded buffer: old frames are dropped once maxlen is reached.
        self.frame_buffer: deque = deque(maxlen=buffer_size)
        # Guards self.cap (open/read/release) across threads.
        self.lock = threading.Lock()
        self.fps = 0.0
        self.last_frame_time = time.time()
        self.frame_count = 0

    def _reader_loop(self):
        """Thread body: (re)connect as needed and pump frames into the buffer."""
        while self.running:
            with self.lock:
                if self.cap is None or not self.cap.isOpened():
                    self._connect()
                    if self.cap is None:
                        # Back off before the next reconnect attempt.
                        time.sleep(self.reconnect_delay)
                        continue

            ret = False
            frame = None
            # NOTE(review): the blocking cap.read() runs while holding the
            # lock, so is_connected()/stop() can stall for up to a network
            # timeout — consider reading outside the lock.
            with self.lock:
                if self.cap is not None and self.cap.isOpened():
                    ret, frame = self.cap.read()

            if ret and frame is not None:
                self.frame_buffer.append(frame)
                self.frame_count += 1
                current_time = time.time()

                # Recompute FPS roughly once per second.
                if current_time - self.last_frame_time >= 1.0:
                    self.fps = self.frame_count / (current_time - self.last_frame_time)
                    self.frame_count = 0
                    self.last_frame_time = current_time
            else:
                # Read failed: brief pause, the next iteration reconnects.
                time.sleep(0.1)

    def _connect(self):
        """(Re)open the RTSP capture; leaves self.cap as None on failure.

        Caller must hold self.lock.
        """
        if self.cap is not None:
            try:
                self.cap.release()
            except Exception:
                pass
            self.cap = None

        try:
            self.cap = cv2.VideoCapture(self.rtsp_url, cv2.CAP_FFMPEG)
            if self.cap.isOpened():
                # Keep FFmpeg's internal queue minimal to reduce latency.
                self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
                print(f"[{self.camera_id}] RTSP连接成功: {self.rtsp_url[:50]}...")
            else:
                print(f"[{self.camera_id}] 无法打开视频流")
                self.cap = None
        except Exception as e:
            print(f"[{self.camera_id}] 连接失败: {e}")
            self.cap = None

    def start(self):
        """Launch the reader thread (no-op if already running)."""
        if self.running:
            return

        self.running = True
        self.thread = threading.Thread(target=self._reader_loop, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop the reader thread, release the capture, and clear the buffer."""
        if not self.running:
            return

        self.running = False

        if self.thread is not None:
            self.thread.join(timeout=5.0)
            if self.thread.is_alive():
                print(f"[{self.camera_id}] 警告: 线程未能在5秒内结束")

        with self.lock:
            if self.cap is not None:
                try:
                    self.cap.release()
                except Exception:
                    pass
                self.cap = None

        self.frame_buffer.clear()

    def read(self) -> Tuple[bool, Optional[np.ndarray]]:
        """Pop the oldest buffered frame; (False, None) when the buffer is empty."""
        if len(self.frame_buffer) > 0:
            return True, self.frame_buffer.popleft()
        return False, None

    def get_frame(self) -> Optional[np.ndarray]:
        """Peek at the newest buffered frame without consuming it."""
        if len(self.frame_buffer) > 0:
            return self.frame_buffer[-1]
        return None

    def is_connected(self) -> bool:
        """True when the capture handle exists and reports open."""
        with self.lock:
            return self.cap is not None and self.cap.isOpened()

    def get_info(self) -> Dict[str, Any]:
        """Diagnostic snapshot of this reader's state."""
        return {
            "camera_id": self.camera_id,
            "running": self.running,
            "connected": self.is_connected(),
            "fps": self.fps,
            "buffer_size": len(self.frame_buffer),
            "buffer_max": self.buffer_size,
        }
|
||||
|
||||
|
||||
class StreamManager:
    """Owns a StreamReader per camera and exposes dict-like access to them."""

    def __init__(self, buffer_size: int = 2, reconnect_delay: float = 3.0):
        # camera_id -> running StreamReader
        self.streams: Dict[str, StreamReader] = {}
        self.buffer_size = buffer_size
        self.reconnect_delay = reconnect_delay

    def add_stream(
        self,
        camera_id: str,
        rtsp_url: str,
        buffer_size: Optional[int] = None,
    ) -> StreamReader:
        """Create and start a reader for *camera_id*, replacing any existing one."""
        if camera_id in self.streams:
            self.remove_stream(camera_id)

        reader = StreamReader(
            camera_id=camera_id,
            rtsp_url=rtsp_url,
            buffer_size=buffer_size or self.buffer_size,
            reconnect_delay=self.reconnect_delay,
        )
        reader.start()
        self.streams[camera_id] = reader
        return reader

    def remove_stream(self, camera_id: str):
        """Stop and forget the reader for *camera_id* (no-op if absent)."""
        if camera_id not in self.streams:
            return
        self.streams[camera_id].stop()
        del self.streams[camera_id]

    def get_stream(self, camera_id: str) -> Optional[StreamReader]:
        """The reader for *camera_id*, or None."""
        return self.streams.get(camera_id)

    def read(self, camera_id: str) -> Tuple[bool, Optional[np.ndarray]]:
        """Pop the oldest frame from the camera's buffer; (False, None) if
        the camera is unknown or has nothing buffered."""
        reader = self.streams.get(camera_id)
        return (False, None) if reader is None else reader.read()

    def get_frame(self, camera_id: str) -> Optional[np.ndarray]:
        """Peek at the camera's newest frame without consuming it."""
        reader = self.streams.get(camera_id)
        return None if reader is None else reader.get_frame()

    def stop_all(self):
        """Stop and remove every managed stream."""
        for cid in list(self.streams):
            self.remove_stream(cid)

    def get_all_info(self) -> List[Dict[str, Any]]:
        """Diagnostic snapshots for all managed streams."""
        return [reader.get_info() for reader in self.streams.values()]

    def __contains__(self, camera_id: str) -> bool:
        return camera_id in self.streams

    def __len__(self) -> int:
        return len(self.streams)
|
||||
62
logs/app.log
Normal file
62
logs/app.log
Normal file
@@ -0,0 +1,62 @@
|
||||
2026-01-20 16:14:08,029 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 16:14:08,042 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 16:35:00,666 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 16:35:00,681 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 16:35:35,087 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 16:35:35,099 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 16:36:43,838 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 16:36:43,854 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 16:36:44,074 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 16:40:28,288 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 16:40:36,045 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 16:50:54,485 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 16:50:54,499 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 16:51:08,066 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 16:56:03,282 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 16:56:03,295 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 16:56:15,787 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 16:57:32,966 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 16:57:38,237 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:01:03,511 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:01:03,525 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:01:16,086 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:02:31,685 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:02:38,875 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:04:14,608 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:04:14,624 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:04:27,113 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:08:05,883 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:08:13,257 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:09:38,432 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:09:38,445 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:09:50,944 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:10:54,647 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:11:01,985 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:11:15,714 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:11:15,732 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:11:28,265 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:14:17,481 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:14:22,406 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:15:56,666 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:15:56,680 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:16:09,183 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:16:36,221 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:16:36,235 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:16:48,752 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:16:53,218 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:17:08,230 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:17:13,977 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:17:13,991 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:17:26,500 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:18:25,049 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:18:31,834 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:18:37,609 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:18:37,623 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:18:50,135 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:19:25,216 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:19:25,384 - security_monitor - INFO - 系统已关闭
|
||||
2026-01-20 17:29:49,295 - security_monitor - INFO - 启动安保异常行为识别系统
|
||||
2026-01-20 17:29:49,307 - security_monitor - INFO - 数据库初始化完成
|
||||
2026-01-20 17:30:01,820 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2
|
||||
2026-01-20 17:31:10,482 - security_monitor - INFO - 正在关闭系统...
|
||||
2026-01-20 17:31:10,612 - security_monitor - INFO - 系统已关闭
|
||||
215
main.py
215
main.py
@@ -1,16 +1,211 @@
|
||||
# 这是一个示例 Python 脚本。
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# 按 Shift+F10 执行或将其替换为您的代码。
|
||||
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
|
||||
import cv2
|
||||
import numpy as np
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from api.alarm import router as alarm_router
|
||||
from api.camera import router as camera_router
|
||||
from api.roi import router as roi_router
|
||||
from config import get_config, load_config
|
||||
from db.models import init_db
|
||||
from inference.pipeline import get_pipeline, start_pipeline, stop_pipeline
|
||||
from utils.logger import setup_logger
|
||||
from utils.metrics import get_metrics_server, start_metrics_server, update_system_info
|
||||
|
||||
|
||||
def print_hi(name):
    """Leftover PyCharm template helper; kept only so existing imports/calls
    do not break. Prints a greeting and returns None."""
    print(f'Hi, {name}')


# Module-level logger, initialised inside the FastAPI lifespan handler
# (which declares `global logger`). None until the app has started.
logger = None

# NOTE(review): the IDE-template guard `if __name__ == '__main__':
# print_hi('PyCharm')` has been removed — the real entry point is the
# `if __name__ == "__main__": main()` guard at the bottom of this file,
# and the stray guard printed template output before the server started.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: bring the whole system up, yield, then tear it down.

    Startup: configure logging, initialise the database, start the
    Prometheus metrics server and the inference pipeline.
    Shutdown: stop the pipeline and log the fact.
    """
    global logger

    config = get_config()
    logger = setup_logger(
        name="security_monitor",
        log_file=config.logging.file,
        level=config.logging.level,
        max_size=config.logging.max_size,
        backup_count=config.logging.backup_count,
    )

    logger.info("启动安保异常行为识别系统")

    init_db()
    logger.info("数据库初始化完成")

    start_metrics_server()
    update_system_info()

    pipeline = start_pipeline()
    logger.info(f"推理Pipeline启动,活跃摄像头数: {len(pipeline.camera_threads)}")

    try:
        yield
    except asyncio.CancelledError:
        # server cancelled during shutdown — proceed to cleanup
        pass
    finally:
        logger.info("正在关闭系统...")
        stop_pipeline()
        logger.info("系统已关闭")
|
||||
# FastAPI application; startup/shutdown is driven by the `lifespan` handler.
app = FastAPI(
    title="安保异常行为识别系统",
    description="基于YOLO和规则算法的智能监控系统",
    version="1.0.0",
    lifespan=lifespan,
)

# Permissive CORS: the dashboard may be served from a different host/port.
# NOTE(review): `allow_origins=["*"]` with credentials is wide open —
# tighten for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Domain routers: camera management, ROI configuration, alarms.
app.include_router(camera_router)
app.include_router(roi_router)
app.include_router(alarm_router)
@app.get("/")
async def root():
    """Service banner: name, version and a static liveness flag."""
    return {
        "name": "安保异常行为识别系统",
        "version": "1.0.0",
        "status": "running",
    }
@app.get("/health")
async def health_check():
    """Health endpoint for probes: pipeline state, camera count, timestamp."""
    pipeline = get_pipeline()
    return {
        "status": "healthy",
        "running": pipeline.running,
        "cameras": len(pipeline.camera_threads),
        "timestamp": datetime.now().isoformat(),
    }
@app.get("/api/camera/{camera_id}/snapshot")
async def get_snapshot(camera_id: int):
    """Return the camera's latest frame as a JPEG image (404 if no frame)."""
    frame = get_pipeline().get_latest_frame(camera_id)
    if frame is None:
        raise HTTPException(status_code=404, detail="无法获取帧")

    _, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
    return StreamingResponse(
        iter([buffer.tobytes()]),
        media_type="image/jpeg",
    )
@app.get("/api/camera/{camera_id}/snapshot/base64")
async def get_snapshot_base64(camera_id: int):
    """Return the camera's latest frame as a base64-encoded JPEG string."""
    frame = get_pipeline().get_latest_frame(camera_id)
    if frame is None:
        raise HTTPException(status_code=404, detail="无法获取帧")

    _, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
    return {"image": base64.b64encode(buffer).decode("utf-8")}
@app.get("/api/camera/{camera_id}/detect")
async def get_detection_with_overlay(camera_id: int):
    """Return the latest frame with the camera's ROI regions drawn on top.

    Loads the camera's ROI definitions from the database, converts them to
    plain dicts, and hands them to `draw_detections` with an empty detection
    list (ROI visualisation only).
    """
    frame = get_pipeline().get_latest_frame(camera_id)
    if frame is None:
        raise HTTPException(status_code=404, detail="无法获取帧")

    # Local imports: keep DB/drawing deps out of the module import path.
    import json
    from db.crud import get_all_rois
    from db.models import get_session_factory

    session = get_session_factory()()
    try:
        roi_configs = [
            {
                "id": roi.id,
                "roi_id": roi.roi_id,
                "name": roi.name,
                "type": roi.roi_type,
                "points": json.loads(roi.points),
                "rule": roi.rule_type,
                "enabled": roi.enabled,
            }
            for roi in get_all_rois(session, camera_id)
        ]
    finally:
        session.close()

    from utils.helpers import draw_detections

    overlay_frame = draw_detections(frame, [], roi_configs if roi_configs else None)

    _, buffer = cv2.imencode(".jpg", overlay_frame, [cv2.IMWRITE_JPEG_QUALITY, 85])
    return StreamingResponse(
        iter([buffer.tobytes()]),
        media_type="image/jpeg",
    )
@app.get("/api/pipeline/status")
async def get_pipeline_status():
    """Expose the inference pipeline's status dict as-is."""
    return get_pipeline().get_status()
@app.get("/api/stream/list")
async def list_streams():
    """List status info for every active RTSP stream."""
    return {"streams": get_pipeline().stream_manager.get_all_info()}
@app.post("/api/pipeline/reload")
async def reload_pipeline():
    """Stop and restart the inference pipeline (e.g. after config changes).

    Bug fix: the previous version called `time.sleep(1)` inside this async
    handler, blocking the entire event loop — every other request stalled
    for a second. Use a non-blocking `asyncio.sleep` instead.
    """
    stop_pipeline()
    # Give worker threads a moment to release cameras/GPU before restart.
    await asyncio.sleep(1)
    start_pipeline()
    return {"message": "Pipeline已重新加载"}
def main():
    """Run the API server with uvicorn, honouring the configured log level."""
    import uvicorn

    config = load_config()

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=False,
        log_level=config.logging.level.lower(),
    )


if __name__ == "__main__":
    main()
BIN
models/yolo11n_fp16_480.engine
Normal file
BIN
models/yolo11n_fp16_480.engine
Normal file
Binary file not shown.
6
package-lock.json
generated
Normal file
6
package-lock.json
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"name": "Detector",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {}
|
||||
}
|
||||
13
prometheus.yml
Normal file
13
prometheus.yml
Normal file
@@ -0,0 +1,13 @@
|
||||
global:
|
||||
scrape_interval: 15s
|
||||
evaluation_interval: 15s
|
||||
|
||||
scrape_configs:
|
||||
- job_name: 'prometheus'
|
||||
static_configs:
|
||||
- targets: ['localhost:9090']
|
||||
|
||||
- job_name: 'security-monitor'
|
||||
static_configs:
|
||||
- targets: ['localhost:9090']
|
||||
metrics_path: /metrics
|
||||
51
requirements.txt
Normal file
51
requirements.txt
Normal file
@@ -0,0 +1,51 @@
|
||||
# 安保异常行为识别系统 - 依赖配置
|
||||
|
||||
# 核心框架
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
click>=8.0.0
|
||||
pydantic==2.5.3
|
||||
pydantic-settings==2.1.0
|
||||
|
||||
# 数据库
|
||||
sqlalchemy==2.0.25
|
||||
alembic==1.13.1
|
||||
pymysql==1.1.0
|
||||
cryptography==42.0.0
|
||||
|
||||
# 推理引擎
|
||||
torch>=2.0.0
|
||||
tensorrt>=10.0.0
|
||||
pycuda>=2023.1
|
||||
ultralytics>=8.1.0
|
||||
numpy>=1.24.0
|
||||
|
||||
# 视频处理
|
||||
opencv-python>=4.8.0
|
||||
|
||||
# 几何处理
|
||||
shapely>=2.0.0
|
||||
|
||||
# 配置管理
|
||||
pyyaml>=6.0.1
|
||||
python-dotenv>=1.0.0
|
||||
|
||||
# 监控与日志
|
||||
prometheus-client==0.19.0
|
||||
python-json-logger==2.0.7
|
||||
|
||||
# 异步支持: asyncio is part of the Python standard library and must NOT be
# installed from PyPI — the "asyncio" package there is an obsolete Python 3.3
# backport that shadows the stdlib module and breaks Python 3 interpreters.
# asyncio>=3.4.3
|
||||
|
||||
# 工具库
|
||||
python-dateutil>=2.8.2
|
||||
pillow>=10.0.0
|
||||
|
||||
# 前端(React打包后静态文件)
|
||||
# 静态文件服务用aiofiles
|
||||
aiofiles>=23.2.1
|
||||
|
||||
# 开发与测试
|
||||
pytest==7.4.4
|
||||
pytest-asyncio==0.23.3
|
||||
httpx==0.26.0
|
||||
@@ -1,52 +0,0 @@
|
||||
import yaml
|
||||
import threading
|
||||
from ultralytics import YOLO
|
||||
import torch
|
||||
from detector import OffDutyCrowdDetector
|
||||
import os
|
||||
|
||||
|
||||
def load_config(config_path="config.yaml"):
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
def main():
|
||||
config = load_config()
|
||||
|
||||
# 全局模型(共享)
|
||||
model_path = config["model"]["path"]
|
||||
device = config["model"].get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
use_half = (device == "cuda")
|
||||
|
||||
print(f"Loading model {model_path} on {device} (FP16: {use_half})")
|
||||
model = YOLO(model_path)
|
||||
model.to(device)
|
||||
if use_half:
|
||||
model.model.half()
|
||||
|
||||
# 启动每个摄像头的检测线程
|
||||
threads = []
|
||||
for cam_cfg in config["cameras"]:
|
||||
# 合并 common 配置
|
||||
full_cfg = {**config.get("common", {}), **cam_cfg}
|
||||
full_cfg["imgsz"] = config["model"]["imgsz"]
|
||||
full_cfg["conf_thresh"] = config["model"]["conf_thresh"]
|
||||
full_cfg["working_hours"] = config["common"]["working_hours"]
|
||||
|
||||
detector = OffDutyCrowdDetector(full_cfg, model, device, use_half)
|
||||
thread = threading.Thread(target=detector.run, daemon=True)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
print(f"Started detector for {cam_cfg['id']}")
|
||||
|
||||
print(f"✅ 已启动 {len(threads)} 路摄像头检测,按 Ctrl+C 退出")
|
||||
try:
|
||||
for t in threads:
|
||||
t.join()
|
||||
except KeyboardInterrupt:
|
||||
print("\n🛑 Shutting down...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
89
scripts/build_engine.py
Normal file
89
scripts/build_engine.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# TensorRT Engine 生成脚本
|
||||
# 使用方法: python scripts/build_engine.py
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
import torch
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True):
    """Build a TensorRT engine from an ONNX file and serialize it to disk.

    Args:
        onnx_path: path to the exported ONNX model.
        engine_path: output path for the serialized ``.engine`` file.
        fp16: enable FP16 kernels (requires FP16-capable GPU).
        dynamic_batch: build with a dynamic batch dimension (1..8).

    Raises:
        RuntimeError: if ONNX parsing or engine building fails.

    Bug fixes vs the previous version: ``trt.*`` was used without importing
    tensorrt (NameError on every call); the long-removed
    ``tensorrt.parsers.onnxparser`` API was referenced; the optimization
    profile was never attached to the builder config; non-existent
    ``set_memory_allocator``/``set_precision_constraints`` calls removed;
    ``max_workspace_size`` (removed in TensorRT 10) replaced by
    ``set_memory_pool_limit``.
    """
    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)

    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)

    parser = trt.OnnxParser(network, logger)
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            errors = "\n".join(str(parser.get_error(i)) for i in range(parser.num_errors))
            raise RuntimeError(f"ONNX parse failed:\n{errors}")

    config = builder.create_builder_config()
    # 4GB workspace (TensorRT >= 8.4 API)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)

    # 动态形状配置
    if dynamic_batch:
        profile = builder.create_optimization_profile()
        # assumes the ONNX input tensor is named "input" and takes 480x480
        # images — TODO confirm against the exported model
        profile.set_shape("input", (1, 3, 480, 480), (4, 3, 480, 480), (8, 3, 480, 480))
        config.add_optimization_profile(profile)

    if fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    serialized_engine = builder.build_serialized_network(network, config)
    if serialized_engine is None:
        raise RuntimeError("TensorRT engine build failed")

    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    print(f"✅ TensorRT引擎已保存: {engine_path}")
def export_onnx(model_path, onnx_path, imgsz=480):
    """Export a YOLO model to ONNX at ``onnx_path``.

    Bug fix: ultralytics writes the ONNX file next to the weights and
    ignores any target path; the previous version only *printed*
    ``onnx_path`` without the file ever being created there, so the
    caller's ``os.path.exists(args.onnx)`` check never saw it. The exported
    file is now moved to ``onnx_path``.
    """
    model = YOLO(model_path)
    exported = model.export(
        format="onnx",
        imgsz=[imgsz, imgsz],
        simplify=True,
        opset=12,
        dynamic=True,
    )
    # model.export returns the path of the produced file; relocate it.
    if exported and os.path.abspath(str(exported)) != os.path.abspath(onnx_path):
        os.replace(str(exported), onnx_path)
    print(f"✅ ONNX模型已导出: {onnx_path}")
if __name__ == "__main__":
    # CLI: export ONNX (if missing) and build the TensorRT engine.
    arg_parser = argparse.ArgumentParser(description="TensorRT Engine Builder")
    arg_parser.add_argument("--model", type=str, default="models/yolo11n.pt",
                            help="YOLO模型路径")
    arg_parser.add_argument("--engine", type=str, default="models/yolo11n_fp16_480.engine",
                            help="输出引擎路径")
    arg_parser.add_argument("--onnx", type=str, default="models/yolo11n_480.onnx",
                            help="临时ONNX路径")
    # NOTE(review): default=True makes --fp16 a no-op (FP16 is always on);
    # kept as-is for backward compatibility with existing invocations.
    arg_parser.add_argument("--fp16", action="store_true", default=True,
                            help="启用FP16")
    arg_parser.add_argument("--no-dynamic", action="store_true",
                            help="禁用动态Batch")

    args = arg_parser.parse_args()

    os.makedirs(os.path.dirname(args.engine), exist_ok=True)

    # Reuse an existing ONNX export; regenerate only when absent.
    if not os.path.exists(args.onnx):
        export_onnx(args.model, args.onnx)

    build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic)
BIN
security_monitor.db
Normal file
BIN
security_monitor.db
Normal file
Binary file not shown.
168
tests/test_core.py
Normal file
168
tests/test_core.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
def test_roi_filter_parse_points():
    """parse_points turns a JSON string into (x, y) float tuples."""
    from inference.roi.roi_filter import ROIFilter

    roi_filter = ROIFilter()  # renamed: don't shadow the builtin `filter`
    points = roi_filter.parse_points("[[100, 200], [300, 400]]")
    assert len(points) == 2
    assert points[0] == (100.0, 200.0)
def test_roi_filter_polygon():
    """Polygon construction and point-in-polygon membership."""
    from inference.roi.roi_filter import ROIFilter

    roi_filter = ROIFilter()  # renamed: don't shadow the builtin `filter`
    points = [(0, 0), (100, 0), (100, 100), (0, 100)]
    polygon = roi_filter.create_polygon(points)

    assert polygon.area == 10000
    # direct truthiness instead of `== True` / `== False`
    assert roi_filter.is_point_in_polygon((50, 50), polygon)
    assert not roi_filter.is_point_in_polygon((150, 150), polygon)
def test_roi_filter_bbox_center():
    """Bbox centre is the midpoint of (x1, y1, x2, y2)."""
    from inference.roi.roi_filter import ROIFilter

    roi_filter = ROIFilter()  # renamed: don't shadow the builtin `filter`
    assert roi_filter.get_bbox_center([10, 20, 100, 200]) == (55.0, 110.0)
def test_leave_post_algorithm_init():
|
||||
from inference.rules.algorithms import LeavePostAlgorithm
|
||||
|
||||
algo = LeavePostAlgorithm(
|
||||
threshold_sec=300,
|
||||
confirm_sec=30,
|
||||
return_sec=5,
|
||||
)
|
||||
|
||||
assert algo.threshold_sec == 300
|
||||
assert algo.confirm_sec == 30
|
||||
assert algo.return_sec == 5
|
||||
|
||||
|
||||
def test_leave_post_algorithm_process():
|
||||
from inference.rules.algorithms import LeavePostAlgorithm
|
||||
from datetime import datetime
|
||||
|
||||
algo = LeavePostAlgorithm(threshold_sec=360, confirm_sec=30, return_sec=5)
|
||||
|
||||
tracks = [
|
||||
{"bbox": [100, 100, 200, 200], "conf": 0.9, "cls": 0},
|
||||
]
|
||||
|
||||
alerts = algo.process("test_cam", tracks, datetime.now())
|
||||
assert isinstance(alerts, list)
|
||||
|
||||
|
||||
def test_intrusion_algorithm_init():
|
||||
from inference.rules.algorithms import IntrusionAlgorithm
|
||||
|
||||
algo = IntrusionAlgorithm(
|
||||
check_interval_sec=1.0,
|
||||
direction_sensitive=False,
|
||||
)
|
||||
|
||||
assert algo.check_interval_sec == 1.0
|
||||
|
||||
|
||||
def test_algorithm_manager():
|
||||
from inference.rules.algorithms import AlgorithmManager
|
||||
|
||||
manager = AlgorithmManager()
|
||||
|
||||
manager.register_algorithm("roi_1", "leave_post", {"threshold_sec": 300})
|
||||
|
||||
assert "roi_1" in manager.algorithms
|
||||
assert "leave_post" in manager.algorithms["roi_1"]
|
||||
|
||||
|
||||
def test_config_load():
|
||||
from config import load_config, get_config
|
||||
|
||||
config = load_config()
|
||||
|
||||
assert config.database.dialect in ["sqlite", "mysql"]
|
||||
assert config.model.imgsz == [480, 480]
|
||||
assert config.model.batch_size == 8
|
||||
|
||||
|
||||
def test_database_models():
|
||||
from db.models import Camera, ROI, Alarm, init_db
|
||||
|
||||
init_db()
|
||||
|
||||
camera = Camera(
|
||||
name="测试摄像头",
|
||||
rtsp_url="rtsp://test.local/cam1",
|
||||
fps_limit=30,
|
||||
process_every_n_frames=3,
|
||||
enabled=True,
|
||||
)
|
||||
|
||||
assert camera.name == "测试摄像头"
|
||||
assert camera.enabled == True
|
||||
|
||||
|
||||
def test_camera_crud():
|
||||
from db.crud import create_camera, get_camera_by_id
|
||||
from db.models import get_session_factory, init_db
|
||||
|
||||
init_db()
|
||||
SessionLocal = get_session_factory()
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
camera = create_camera(
|
||||
db,
|
||||
name="测试摄像头",
|
||||
rtsp_url="rtsp://test.local/cam1",
|
||||
fps_limit=30,
|
||||
)
|
||||
|
||||
assert camera.id is not None
|
||||
|
||||
fetched = get_camera_by_id(db, camera.id)
|
||||
assert fetched is not None
|
||||
assert fetched.name == "测试摄像头"
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def test_stream_reader_init():
|
||||
from inference.stream import StreamReader
|
||||
|
||||
reader = StreamReader(
|
||||
camera_id="test_cam",
|
||||
rtsp_url="rtsp://test.local/stream",
|
||||
buffer_size=2,
|
||||
)
|
||||
|
||||
assert reader.camera_id == "test_cam"
|
||||
assert reader.buffer_size == 2
|
||||
|
||||
|
||||
def test_utils_helpers():
|
||||
from utils.helpers import draw_bbox, draw_roi, format_duration
|
||||
|
||||
import numpy as np
|
||||
|
||||
image = np.zeros((480, 640, 3), dtype=np.uint8)
|
||||
|
||||
result = draw_bbox(image, [100, 100, 200, 200], "Test")
|
||||
|
||||
assert result.shape == image.shape
|
||||
|
||||
duration = format_duration(125.5)
|
||||
assert "2分" in duration
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
195
utils/helpers.py
Normal file
195
utils/helpers.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def draw_bbox(
    image: np.ndarray,
    bbox: List[float],
    label: str = "",
    color: Tuple[int, int, int] = (0, 255, 0),
    thickness: int = 2,
) -> np.ndarray:
    """Draw a rectangle and an optional label banner on *image* (in place).

    Args:
        image: BGR image to draw on (mutated and also returned).
        bbox: [x1, y1, x2, y2] box coordinates (floats are truncated).
        label: text drawn above the box on a filled banner; empty = no text.
        color: BGR box/banner colour.
        thickness: line and text thickness.
    """
    x1, y1, x2, y2 = (int(v) for v in bbox)
    cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)

    if label:
        (text_w, text_h), _ = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, thickness
        )
        # filled banner behind the text so it stays readable on any background
        cv2.rectangle(image, (x1, y1 - text_h - 10), (x1 + text_w + 10, y1), color, -1)
        cv2.putText(
            image,
            label,
            (x1 + 5, y1 - 5),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 0, 0),  # black text on the coloured banner
            thickness,
        )

    return image
def draw_roi(
    image: np.ndarray,
    points: List[List[float]],
    roi_type: str = "polygon",
    color: Tuple[int, int, int] = (255, 0, 0),
    thickness: int = 2,
    label: str = "",
) -> np.ndarray:
    """Draw an ROI (polygon / line / rectangle) with an optional label.

    Bug fix: the previous version passed a 4-tuple ``(b, g, r, 30)`` to
    ``cv2.fillPoly`` on a 3-channel image; OpenCV ignores the extra "alpha"
    component, so the polygon was filled fully opaque and hid the frame.
    A translucent fill is now produced by blending an overlay copy.
    """
    pts = np.array(points, dtype=np.int32)

    if roi_type == "polygon":
        cv2.polylines(image, [pts], True, color, thickness)
        overlay = image.copy()
        cv2.fillPoly(overlay, [pts], color=color)
        # ~12% opacity fill (matches the intent of the old alpha value 30/255)
        alpha = 30 / 255.0
        cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, dst=image)
    elif roi_type == "line":
        cv2.line(image, tuple(pts[0]), tuple(pts[1]), color, thickness)
    elif roi_type == "rectangle":
        x1, y1 = pts[0]
        x2, y2 = pts[1]
        cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)

    if label:
        # label at the ROI centroid
        cx = int(np.mean(pts[:, 0]))
        cy = int(np.mean(pts[:, 1]))
        cv2.putText(
            image,
            label,
            (cx, cy),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.7,
            color,
            thickness,
        )

    return image
def draw_detections(
    image: np.ndarray,
    detections: List[Dict[str, Any]],
    roi_configs: Optional[List[Dict[str, Any]]] = None,
) -> np.ndarray:
    """Render ROI regions and detection boxes on a copy of *image*.

    Bug fix: ROI overlays used to be drawn inside the per-detection loop
    (from each detection's ``matched_rois``), so a frame with zero
    detections showed no ROIs at all — even though the snapshot endpoint
    calls this with ``detections=[]`` precisely to visualise the ROIs.
    Enabled ROIs from ``roi_configs`` are now drawn once, up front.
    """
    result = image.copy()

    # Draw every enabled ROI regardless of whether anything was detected.
    for roi_conf in roi_configs or []:
        if not roi_conf.get("enabled", True):
            continue
        # intrusion zones in red, everything else (e.g. leave-post) in orange
        roi_color = (255, 0, 0) if roi_conf.get("rule") == "intrusion" else (0, 165, 255)
        result = draw_roi(
            result,
            roi_conf.get("points", []),
            roi_conf.get("type", "polygon"),
            roi_color,
            2,
            roi_conf.get("name", ""),
        )

    for detection in detections:
        bbox = detection.get("bbox", [])
        conf = detection.get("conf", 0.0)
        result = draw_bbox(result, bbox, f"Person: {conf:.2f}", (0, 255, 0), 2)

    return result
def resize_image(
    image: np.ndarray,
    max_size: int = 480,
    maintain_aspect: bool = True,
) -> Tuple[np.ndarray, float]:
    """Resize *image* so no side exceeds *max_size*.

    Returns:
        (resized_image, scale): the resized image and the scale factor
        applied — useful for mapping detections back to original coords.

    Bug fix: the return annotation previously claimed ``np.ndarray`` but
    the function has always returned a ``(image, scale)`` tuple.
    """
    h, w = image.shape[:2]

    if maintain_aspect:
        scale = min(max_size / h, max_size / w)
        new_h, new_w = int(h * scale), int(w * scale)
    else:
        new_h, new_w = max_size, max_size
        # NOTE(review): with aspect ignored, this single scale factor is
        # only an approximation (width and height scale differently).
        scale = max_size / max(h, w)

    result = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return result, scale
def encode_image_base64(image: np.ndarray, quality: int = 85) -> str:
    """JPEG-encode *image* and return it as a base64 string.

    Bug fix: ``base64`` was referenced without being imported anywhere in
    this module (only ``decode_image_base64`` imported it locally), so this
    function raised ``NameError`` on every call.
    """
    import base64

    _, buffer = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality])
    return base64.b64encode(buffer).decode("utf-8")
def decode_image_base64(data: str) -> np.ndarray:
    """Decode a base64-encoded JPEG/PNG string back into a BGR image array."""
    import base64

    raw = np.frombuffer(base64.b64decode(data), np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_COLOR)
def get_timestamp_str() -> str:
    """Current local time as ``YYYYMMDD_HHMMSS`` (safe for file names)."""
    return datetime.now().strftime("%Y%m%d_%H%M%S")
def format_duration(seconds: float) -> str:
    """Render a duration in seconds as human-readable Chinese text.

    < 60s -> "N秒"; < 1h -> "M分S秒"; otherwise "H小时M分".
    Fractional seconds are truncated.
    """
    total = int(seconds)
    if seconds < 60:
        return f"{total}秒"
    if seconds < 3600:
        return f"{total // 60}分{total % 60}秒"
    hours, remainder = divmod(total, 3600)
    return f"{hours}小时{remainder // 60}分"
def is_time_in_ranges(
    current_time: datetime,
    time_ranges: List[Dict[str, List[int]]],
) -> bool:
    """True if *current_time* falls inside any half-open [start, end) range.

    Each range is ``{"start": [hour, minute], "end": [hour, minute]}``.
    An empty or None range list means "always active".
    """
    if not time_ranges:
        return True

    now_minutes = current_time.hour * 60 + current_time.minute

    return any(
        tr["start"][0] * 60 + tr["start"][1]
        <= now_minutes
        < tr["end"][0] * 60 + tr["end"][1]
        for tr in time_ranges
    )
class FPSCounter:
    """Sliding-window FPS estimate over the last *window_size* ticks."""

    def __init__(self, window_size: int = 30):
        # most recent tick times (seconds), oldest first, capped at window_size
        self.timestamps: List[float] = []
        self.window_size = window_size

    def update(self):
        """Record one frame/event at the current wall-clock time."""
        import time

        self.timestamps.append(time.time())
        # keep only the newest window_size entries
        self.timestamps = self.timestamps[-self.window_size:]

    @property
    def fps(self) -> float:
        """Frames per second over the window; 0.0 if fewer than 2 samples
        or the samples are not strictly increasing in time."""
        if len(self.timestamps) < 2:
            return 0.0
        span = self.timestamps[-1] - self.timestamps[0]
        if span <= 0:
            return 0.0
        return (len(self.timestamps) - 1) / span
114
utils/logger.py
Normal file
114
utils/logger.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from typing import Optional
|
||||
|
||||
from pythonjsonlogger import jsonlogger
|
||||
|
||||
|
||||
def setup_logger(
|
||||
name: str = "security_monitor",
|
||||
log_file: str = "logs/app.log",
|
||||
level: str = "INFO",
|
||||
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
max_size: str = "100MB",
|
||||
backup_count: int = 5,
|
||||
json_format: bool = False,
|
||||
) -> logging.Logger:
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(getattr(logging, level.upper(), logging.INFO))
|
||||
|
||||
if logger.handlers:
|
||||
for handler in logger.handlers:
|
||||
logger.removeHandler(handler)
|
||||
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
|
||||
if json_format:
|
||||
handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=int(max_size.replace("MB", "")) * 1024 * 1024,
|
||||
backupCount=backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
formatter = jsonlogger.JsonFormatter(
|
||||
"%(asctime)s %(name)s %(levelname)s %(message)s",
|
||||
rename_fields={"levelname": "severity", "asctime": "timestamp"},
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
else:
|
||||
handler = RotatingFileHandler(
|
||||
log_file,
|
||||
maxBytes=int(max_size.replace("MB", "")) * 1024 * 1024,
|
||||
backupCount=backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
formatter = logging.Formatter(format)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger.addHandler(handler)
|
||||
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setFormatter(logging.Formatter(format))
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def get_logger(name: str = "security_monitor") -> logging.Logger:
    """Fetch a named logger (configure it first via ``setup_logger``)."""
    return logging.getLogger(name)
class LoggerMixin:
    """Mixin adding a ``.logger`` property named after the concrete class.

    All instances of the same subclass share one logger (standard
    ``logging.getLogger`` caching by name).
    """

    @property
    def logger(self) -> logging.Logger:
        # resolved per access so subclasses get their own class-named logger
        return get_logger(self.__class__.__name__)
def log_execution_time(logger: Optional[logging.Logger] = None):
    """Decorator factory: log how long the wrapped call took (milliseconds).

    Args:
        logger: logger to write to; when None, a logger named after the
            wrapped function's module is resolved at call time.
    """

    def decorator(func):
        import time
        from functools import wraps

        @wraps(func)
        def wrapper(*args, **kwargs):
            log = logger or get_logger(func.__module__)
            started = time.time()
            result = func(*args, **kwargs)
            elapsed_ms = int((time.time() - started) * 1000)
            log.info(
                f"函数执行完成",
                extra={
                    "function": func.__name__,
                    "execution_time_ms": elapsed_ms,
                },
            )
            return result

        return wrapper

    return decorator
def log_function_call(logger: Optional[logging.Logger] = None):
    """Decorator factory: log every call of the wrapped function with its
    positional and keyword arguments (stringified).

    Args:
        logger: logger to write to; when None, a logger named after the
            wrapped function's module is resolved at call time.
    """

    def decorator(func):
        from functools import wraps

        @wraps(func)
        def wrapper(*args, **kwargs):
            log = logger or get_logger(func.__module__)
            log.info(
                f"函数调用",
                extra={
                    "function": func.__name__,
                    "args": str(args),
                    "kwargs": str(kwargs),
                },
            )
            return func(*args, **kwargs)

        return wrapper

    return decorator
113
utils/metrics.py
Normal file
113
utils/metrics.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import logging
from typing import Optional

from prometheus_client import Counter, Gauge, Histogram, Info, start_http_server

from config import get_config
|
||||
|
||||
# Static build/host information, published once via update_system_info().
SYSTEM_INFO = Info("system", "System information")

# Number of camera streams currently active.
CAMERA_COUNT = Gauge("camera_count", "Number of active cameras")

# Latest frames-per-second reading, labelled by camera id.
CAMERA_FPS = Gauge("camera_fps", "Camera FPS", ["camera_id"])

# Distribution of per-frame inference latency; bucket bounds span 10 ms to 5 s.
INFERENCE_LATENCY = Histogram(
    "inference_latency_seconds",
    "Inference latency in seconds",
    ["camera_id"],
    buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0],
)

# Monotonic count of alerts, broken down by camera and event type.
ALERT_COUNT = Counter(
    "alert_total",
    "Total number of alerts",
    ["camera_id", "event_type"],
)

# Current backlog of the event queue feeding downstream consumers.
EVENT_QUEUE_SIZE = Gauge(
    "event_queue_size",
    "Current size of event queue",
)

# Monotonic count of detections, broken down by camera and ROI.
DETECTION_COUNT = Counter(
    "detection_total",
    "Total number of detections",
    ["camera_id", "roi_id"],
)

# GPU memory currently in use (bytes), per device index.
GPU_MEMORY_USED = Gauge(
    "gpu_memory_used_bytes",
    "GPU memory used",
    ["device"],
)

# GPU utilization percentage, per device index.
GPU_UTILIZATION = Gauge(
    "gpu_utilization_percent",
    "GPU utilization",
    ["device"],
)
|
||||
|
||||
|
||||
class MetricsServer:
    """Thin wrapper around prometheus_client's HTTP exporter.

    Provides helper methods that translate application events into updates of
    the module-level Prometheus metric objects.
    """

    def __init__(self, port: int = 9090):
        # Port is fixed at construction; start() may still be a no-op when
        # monitoring is disabled in the runtime config.
        self.port = port
        self.started = False

    def start(self):
        """Start the metrics HTTP server once; no-op if already running or
        if monitoring is disabled in the config."""
        if self.started:
            return

        config = get_config()
        if not config.monitoring.enabled:
            return

        start_http_server(self.port)
        self.started = True
        # Log via the logging framework (not print) so the message honors the
        # application's logging configuration and handlers.
        logging.getLogger(__name__).info(
            "Prometheus metrics server started on port %s", self.port
        )

    def update_camera_metrics(self, camera_id: int, fps: float):
        """Record the latest FPS reading for one camera."""
        CAMERA_FPS.labels(camera_id=str(camera_id)).set(fps)

    def record_inference(self, camera_id: int, latency: float):
        """Add one inference latency sample (in seconds) for a camera."""
        INFERENCE_LATENCY.labels(camera_id=str(camera_id)).observe(latency)

    def record_alert(self, camera_id: int, event_type: str):
        """Count one alert of ``event_type`` raised by a camera."""
        ALERT_COUNT.labels(camera_id=str(camera_id), event_type=event_type).inc()

    def update_event_queue(self, size: int):
        """Publish the current event-queue depth."""
        EVENT_QUEUE_SIZE.set(size)

    def record_detection(self, camera_id: int, roi_id: str):
        """Count one detection inside ROI ``roi_id`` for a camera."""
        DETECTION_COUNT.labels(camera_id=str(camera_id), roi_id=roi_id).inc()

    def update_gpu_metrics(self, device: int, memory_bytes: float, utilization: float):
        """Publish GPU memory usage (bytes) and utilization (%) for one device."""
        GPU_MEMORY_USED.labels(device=str(device)).set(memory_bytes)
        GPU_UTILIZATION.labels(device=str(device)).set(utilization)
|
||||
|
||||
|
||||
# Process-wide singleton, created lazily by get_metrics_server().
_metrics_server: Optional[MetricsServer] = None


def get_metrics_server() -> MetricsServer:
    """Return the process-wide MetricsServer, creating it on first use.

    The port is read from the monitoring section of the runtime config.
    """
    global _metrics_server
    if _metrics_server is None:
        _metrics_server = MetricsServer(port=get_config().monitoring.port)
    return _metrics_server
|
||||
|
||||
|
||||
def start_metrics_server():
    """Convenience entry point: start the singleton metrics server."""
    get_metrics_server().start()
|
||||
|
||||
|
||||
def update_system_info():
    """Populate the static ``system`` Info metric with host facts.

    ``platform`` and ``psutil`` are imported lazily at call time, as in the
    original, so merely importing this module does not require psutil.
    """
    import platform
    import psutil

    total_mem_gb = round(psutil.virtual_memory().total / (1024**3), 2)
    info = {
        "os": platform.system(),
        "os_version": platform.version(),
        "python_version": platform.python_version(),
        "cpu_count": str(psutil.cpu_count()),
        "memory_total_gb": str(total_mem_gb),
    }
    SYSTEM_INFO.info(info)
|
||||
BIN
yolo11n.onnx
Normal file
BIN
yolo11n.onnx
Normal file
Binary file not shown.
Reference in New Issue
Block a user