diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..aa74d34 --- /dev/null +++ b/.env.example @@ -0,0 +1,18 @@ +# Redis 配置(可选,用于分布式部署) + +# 连接配置 +REDIS_HOST=localhost +REDIS_PORT=6379 +REDIS_DB=0 + +# 事件队列配置 +REDIS_EVENT_QUEUE=security_events +REDIS_ALARM_QUEUE=security_alarms + +# 大模型 API 配置 +LLM_API_KEY=your-api-key +LLM_BASE_URL=https://api.openai.com/v1 +LLM_MODEL=gpt-4-vision-preview + +# 其他配置 +LOG_LEVEL=INFO diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..78c0f96 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,39 @@ +FROM nvidia/cuda:12.2-devel-ubuntu22.04 + +LABEL maintainer="security-monitor" + +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update && apt-get install -y \ + python3.10 \ + python3.10-dev \ + python3-pip \ + python3-venv \ + git \ + wget \ + libgl1-mesa-glx \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrender1 \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +RUN python3.10 -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +COPY requirements.txt . +RUN pip install --no-cache-dir --upgrade pip wheel && \ + pip install --no-cache-dir -r requirements.txt + +COPY . . + +RUN mkdir -p /app/models /app/data /app/logs + +EXPOSE 8000 9090 + +ENV PYTHONUNBUFFERED=1 + +CMD ["python", "main.py"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..49c0f0b --- /dev/null +++ b/README.md @@ -0,0 +1,136 @@ +# 安保异常行为识别系统 + +## 项目概述 + +基于 YOLO + TensorRT 的安保异常行为识别系统,支持离岗检测和周界入侵检测。 + +## 快速开始 + +### 1. 安装依赖 + +```bash +pip install -r requirements.txt +``` + +### 2. 生成 TensorRT 引擎 + +```bash +python scripts/build_engine.py --model models/yolo11n.pt --fp16 +``` + +### 3. 配置数据库 + +编辑 `config.yaml` 配置数据库连接。 + +### 4. 启动服务 + +```bash +python main.py +``` + +### 5. 访问前端 + +打开浏览器访问 `http://localhost:3000` + +## 目录结构 + +``` +project_root/ +├── main.py # FastAPI 入口 +├── config.yaml # 配置文件 +├── requirements.txt # Python 依赖 +├── inference/ +│ ├── engine.py # TensorRT 引擎封装 +│ ├── stream.py # RTSP 流处理 +│ ├── pipeline.py # 推理主流程 +│ ├── roi/ +│ │ └── roi_filter.py # ROI 过滤 +│ └── rules/ +│ └── algorithms.py # 规则算法 +├── db/ +│ ├── models.py # SQLAlchemy ORM 模型 +│ └── crud.py # 数据库操作 +├── api/ +│ ├── camera.py # 摄像头管理接口 +│ ├── roi.py # ROI 管理接口 +│ └── alarm.py # 告警管理接口 +├── utils/ +│ ├── logger.py # 日志工具 +│ ├── helpers.py # 辅助函数 +│ └── metrics.py # Prometheus 监控 +├── frontend/ # React 前端 +├── scripts/ +│ └── build_engine.py # TensorRT 引擎构建脚本 +├── tests/ # 单元测试 +└── Dockerfile # Docker 配置 +``` + +## API 接口 + +### 摄像头管理 + +- `GET /api/cameras` - 获取摄像头列表 +- `POST /api/cameras` - 添加摄像头 +- `PUT /api/cameras/{id}` - 更新摄像头 +- `DELETE /api/cameras/{id}` - 删除摄像头 + +### ROI 管理 + +- `GET /api/camera/{id}/rois` - 获取摄像头 ROI 列表 +- `POST /api/camera/{id}/roi` - 添加 ROI +- `PUT /api/camera/{id}/roi/{roi_id}` - 更新 ROI +- `DELETE /api/camera/{id}/roi/{roi_id}` - 删除 ROI + +### 告警管理 + +- `GET /api/alarms` - 获取告警列表 +- `GET /api/alarms/stats` - 获取告警统计 +- `PUT /api/alarms/{id}` - 更新告警状态 +- `POST /api/alarms/{id}/llm-check` - 触发大模型检查 + +### 其他接口 + +- `GET /api/camera/{id}/snapshot` - 获取实时截图 +- `GET /api/camera/{id}/detect` - 获取带检测框的截图 +- `GET /api/pipeline/status` - 获取 Pipeline 状态 +- `GET /health` - 健康检查 + +## 配置说明 + +### config.yaml + +```yaml +database: + dialect: "sqlite" # sqlite 或 mysql + name: "security_monitor" + +model: + engine_path: "models/yolo11n_fp16_480.engine" + imgsz: [480, 480] + batch_size: 8 + half: true + +stream: + buffer_size: 2 + reconnect_delay: 3.0 + +alert: + snapshot_path: "data/alerts" + cooldown_sec: 300 +``` + +## Docker 部署 + +```bash +docker-compose up -d +``` + +## 监控指标 + +系统暴露 Prometheus 格式的监控指标: + +- `camera_count` - 活跃摄像头数量 +- `camera_fps{camera_id="*"}` - 各摄像头 FPS +- `inference_latency_seconds{camera_id="*"}` - 推理延迟 +- `alert_total{camera_id="*", event_type="*"}` - 告警总数 +- `event_queue_size` - 事件队列大小 diff --git a/TRT_BUILD.md b/TRT_BUILD.md new file mode 100644 index 0000000..a9bc44c --- /dev/null +++ b/TRT_BUILD.md @@ -0,0 +1,38 @@ +# TensorRT Engine 生成 +# 使用 trtexec 命令行工具 + +# 1. 导出 ONNX 模型 +trtexec \ + --onnx=yolo11n_480.onnx \ + --saveEngine=yolo11n_fp16_480.engine \ + --fp16 \ + --minShapes=input:1x3x480x480 \ + --optShapes=input:4x3x480x480 \ + --maxShapes=input:8x3x480x480 \ + --workspace=4096 \ + --verbose + +# 或者使用优化器设置 +trtexec \ + --onnx=yolo11n_480.onnx \ + --saveEngine=yolo11n_fp16_480.engine \ + --fp16 \ + --optShapes=input:4x3x480x480 \ + --workspace=4096 \ + --builderOptimizationLevel=5 \ + --refit=False \ + --sparsity=False + +# INT8 量化(需要校准数据) +trtexec \ + --onnx=yolo11n_480.onnx \ + --saveEngine=yolo11n_int8_480.engine \ + --int8 \ + --calib=calibration.bin \ + --minShapes=input:1x3x480x480 \ + --optShapes=input:4x3x480x480 \ + --maxShapes=input:8x3x480x480 \ + --workspace=4096 + +# 验证引擎 +trtexec --loadEngine=yolo11n_fp16_480.engine --dumpOutput diff --git a/api/alarm.py b/api/alarm.py new file mode 100644 index 0000000..eeb362a --- /dev/null +++ b/api/alarm.py @@ -0,0 +1,142 @@ +from datetime import datetime +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.orm import Session + +from db.crud import ( + create_alarm, + get_alarm_stats, + get_alarms, + update_alarm, +) +from db.models import get_db +from inference.pipeline import get_pipeline + +router = APIRouter(prefix="/api/alarms", tags=["告警管理"]) + + +@router.get("", response_model=List[dict]) +def list_alarms( + camera_id: Optional[int] = None, + event_type: Optional[str] = None, + limit: int = Query(default=100, le=1000), + offset: int = Query(default=0, ge=0), + db: Session = Depends(get_db), +): + alarms = get_alarms(db, camera_id=camera_id, event_type=event_type, limit=limit, offset=offset) + return [ + { + "id": alarm.id, + "camera_id": alarm.camera_id, + "roi_id": alarm.roi_id, + "event_type": alarm.event_type, + "confidence": alarm.confidence, + "snapshot_path": alarm.snapshot_path, + "llm_checked": alarm.llm_checked, + "llm_result": alarm.llm_result, + "processed": alarm.processed, + "created_at": alarm.created_at.isoformat() if alarm.created_at else None, + } + for alarm in alarms + ] + + +@router.get("/stats") +def get_alarm_statistics(db: Session = Depends(get_db)): + stats = get_alarm_stats(db) + return stats + + +@router.get("/{alarm_id}", response_model=dict) +def get_alarm(alarm_id: int, db: Session = Depends(get_db)): + from db.crud import get_alarms + alarms = get_alarms(db, limit=1) + alarm = next((a for a in alarms if a.id == alarm_id), None) + if not alarm: + raise HTTPException(status_code=404, detail="告警不存在") + return { + "id": alarm.id, + "camera_id": alarm.camera_id, + "roi_id": alarm.roi_id, + "event_type": alarm.event_type, + "confidence": alarm.confidence, + "snapshot_path": alarm.snapshot_path, + "llm_checked": alarm.llm_checked, + "llm_result": alarm.llm_result, + "processed": alarm.processed, + "created_at": alarm.created_at.isoformat() if alarm.created_at else None, + } + + +@router.put("/{alarm_id}") +def update_alarm_status( + alarm_id: int, + llm_checked: Optional[bool] = None, + llm_result: Optional[str] = None, + processed: Optional[bool] = None, + db: Session = Depends(get_db), +): + alarm = update_alarm(db, alarm_id, llm_checked=llm_checked, llm_result=llm_result, processed=processed) + if not alarm: + raise HTTPException(status_code=404, detail="告警不存在") + return {"message": "更新成功"} + + +@router.post("/{alarm_id}/llm-check") +async def trigger_llm_check(alarm_id: int, db: Session = Depends(get_db)): + from db.crud import get_alarms + alarms = get_alarms(db, limit=1) + alarm = next((a for a in alarms if a.id == alarm_id), None) + if not alarm: + raise HTTPException(status_code=404, detail="告警不存在") + + if not alarm.snapshot_path or not os.path.exists(alarm.snapshot_path): + raise HTTPException(status_code=400, detail="截图不存在") + + try: + from config import get_config + config = get_config() + if not config.llm.enabled: + raise HTTPException(status_code=400, detail="大模型功能未启用") + + import base64 + with open(alarm.snapshot_path, "rb") as f: + img_base64 = base64.b64encode(f.read()).decode("utf-8") + + from openai import OpenAI + client = OpenAI( + api_key=config.llm.api_key, + base_url=config.llm.base_url, + ) + + prompt = """分析这张监控截图,判断是否存在异常行为。请简要说明: +1. 画面中是否有人 +2. 人员位置和行为 +3. 是否存在异常""" + + response = client.chat.completions.create( + model=config.llm.model, + messages=[ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}}, + {"type": "text", "text": prompt}, + ], + } + ], + ) + + result = response.choices[0].message.content + update_alarm(db, alarm_id, llm_checked=True, llm_result=result) + + return {"message": "大模型分析完成", "result": result} + except Exception as e: + raise HTTPException(status_code=500, detail=f"大模型调用失败: {str(e)}") + + +@router.get("/queue/size") +def get_event_queue_size(): + pipeline = get_pipeline() + return {"size": len(pipeline.event_queue), "max_size": pipeline.config.inference.event_queue_maxlen} diff --git a/api/camera.py b/api/camera.py new file mode 100644 index 0000000..c423a09 --- /dev/null +++ b/api/camera.py @@ -0,0 +1,160 @@ +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session + +from db.crud import ( + create_camera, + delete_camera, + get_all_cameras, + get_camera_by_id, + update_camera, +) +from db.models import get_db +from inference.pipeline import get_pipeline + +router = APIRouter(prefix="/api/cameras", tags=["摄像头管理"]) + + +@router.get("", response_model=List[dict]) +def list_cameras( + enabled_only: bool = True, + db: Session = Depends(get_db), +): + cameras = get_all_cameras(db, enabled_only=enabled_only) + return [ + { + "id": cam.id, + "name": cam.name, + "rtsp_url": cam.rtsp_url, + "enabled": cam.enabled, + "fps_limit": cam.fps_limit, + "process_every_n_frames": cam.process_every_n_frames, + "created_at": cam.created_at.isoformat() if cam.created_at else None, + } + for cam in cameras + ] + + +@router.get("/{camera_id}", response_model=dict) +def get_camera(camera_id: int, db: Session = Depends(get_db)): + camera = get_camera_by_id(db, camera_id) + if not camera: + raise HTTPException(status_code=404, detail="摄像头不存在") + return { + "id": camera.id, + "name": camera.name, + "rtsp_url": camera.rtsp_url, + "enabled": camera.enabled, + "fps_limit": camera.fps_limit, + "process_every_n_frames": camera.process_every_n_frames, + "created_at": camera.created_at.isoformat() if camera.created_at else None, + } + + +@router.post("", response_model=dict) +def add_camera( + name: str, + rtsp_url: str, + fps_limit: int = 30, + process_every_n_frames: int = 3, + db: Session = Depends(get_db), +): + camera = create_camera( + db, + name=name, + rtsp_url=rtsp_url, + fps_limit=fps_limit, + process_every_n_frames=process_every_n_frames, + ) + + if camera.enabled: + pipeline = get_pipeline() + pipeline.add_camera(camera) + + return { + "id": camera.id, + "name": camera.name, + "rtsp_url": camera.rtsp_url, + "enabled": camera.enabled, + } + + +@router.put("/{camera_id}", response_model=dict) +def modify_camera( + camera_id: int, + name: Optional[str] = None, + rtsp_url: Optional[str] = None, + fps_limit: Optional[int] = None, + process_every_n_frames: Optional[int] = None, + enabled: Optional[bool] = None, + db: Session = Depends(get_db), +): + camera = update_camera( + db, + camera_id=camera_id, + name=name, + rtsp_url=rtsp_url, + fps_limit=fps_limit, + process_every_n_frames=process_every_n_frames, + enabled=enabled, + ) + if not camera: + raise HTTPException(status_code=404, detail="摄像头不存在") + + pipeline = get_pipeline() + if enabled is True: + pipeline.add_camera(camera) + elif enabled is False: + pipeline.remove_camera(camera_id) + + return { + "id": camera.id, + "name": camera.name, + "rtsp_url": camera.rtsp_url, + "enabled": camera.enabled, + } + + +@router.delete("/{camera_id}") +def remove_camera(camera_id: int, db: Session = Depends(get_db)): + pipeline = get_pipeline() + pipeline.remove_camera(camera_id) + + if not delete_camera(db, camera_id): + raise HTTPException(status_code=404, detail="摄像头不存在") + + return {"message": "删除成功"} + + +@router.get("/{camera_id}/status") +def get_camera_status(camera_id: int, db: Session = Depends(get_db)): + from db.crud import get_camera_status as get_status + + status = get_status(db, camera_id) + pipeline = get_pipeline() + + stream = pipeline.stream_manager.get_stream(str(camera_id)) + if stream: + stream_info = stream.get_info() + else: + stream_info = None + + if status: + return { + "camera_id": camera_id, + "is_running": status.is_running, + "fps": status.fps, + "error_message": status.error_message, + "last_check_time": status.last_check_time.isoformat() if status.last_check_time else None, + "stream": stream_info, + } + else: + return { + "camera_id": camera_id, + "is_running": False, + "fps": 0.0, + "error_message": None, + "last_check_time": None, + "stream": stream_info, + } diff --git a/api/roi.py b/api/roi.py new file mode 100644 index 0000000..5a5cbc9 --- /dev/null +++ b/api/roi.py @@ -0,0 +1,192 @@ +import json +from typing import List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy.orm import Session + +from db.crud import ( + create_roi, + delete_roi, + get_all_rois, + get_roi_by_id, + get_roi_points, + update_roi, +) +from db.models import get_db +from inference.pipeline import get_pipeline +from inference.roi.roi_filter import ROIFilter + +router = APIRouter(prefix="/api/camera", tags=["ROI管理"]) + + +@router.get("/{camera_id}/rois", response_model=List[dict]) +def list_rois(camera_id: int, db: Session = Depends(get_db)): + roi_configs = get_all_rois(db, camera_id) + return [ + { + "id": roi.id, + "roi_id": roi.roi_id, + "name": roi.name, + "type": roi.roi_type, + "points": json.loads(roi.points), + "rule": roi.rule_type, + "direction": roi.direction, + "stay_time": roi.stay_time, + "enabled": roi.enabled, + "threshold_sec": roi.threshold_sec, + "confirm_sec": roi.confirm_sec, + "return_sec": roi.return_sec, + } + for roi in roi_configs + ] + + +@router.get("/{camera_id}/roi/{roi_id}", response_model=dict) +def get_roi(camera_id: int, roi_id: int, db: Session = Depends(get_db)): + roi = get_roi_by_id(db, roi_id) + if not roi: + raise HTTPException(status_code=404, detail="ROI不存在") + return { + "id": roi.id, + "roi_id": roi.roi_id, + "name": roi.name, + "type": roi.roi_type, + "points": json.loads(roi.points), + "rule": roi.rule_type, + "direction": roi.direction, + "stay_time": roi.stay_time, + "enabled": roi.enabled, + "threshold_sec": roi.threshold_sec, + "confirm_sec": roi.confirm_sec, + "return_sec": roi.return_sec, + } + + +@router.post("/{camera_id}/roi", response_model=dict) +def add_roi( + camera_id: int, + roi_id: str, + name: str, + roi_type: str, + points: List[List[float]], + rule_type: str, + direction: Optional[str] = None, + stay_time: Optional[int] = None, + threshold_sec: int = 360, + confirm_sec: int = 30, + return_sec: int = 5, + db: Session = Depends(get_db), +): + roi = create_roi( + db, + camera_id=camera_id, + roi_id=roi_id, + name=name, + roi_type=roi_type, + points=points, + rule_type=rule_type, + direction=direction, + stay_time=stay_time, + threshold_sec=threshold_sec, + confirm_sec=confirm_sec, + return_sec=return_sec, + ) + + pipeline = get_pipeline() + pipeline.roi_filter.update_cache(camera_id, get_roi_points(db, camera_id)) + + return { + "id": roi.id, + "roi_id": roi.roi_id, + "name": roi.name, + "type": roi.roi_type, + "points": points, + "rule": roi.rule_type, + "enabled": roi.enabled, + } + + +@router.put("/{camera_id}/roi/{roi_id}", response_model=dict) +def modify_roi( + camera_id: int, + roi_id: int, + name: Optional[str] = None, + points: Optional[List[List[float]]] = None, + rule_type: Optional[str] = None, + direction: Optional[str] = None, + stay_time: Optional[int] = None, + enabled: Optional[bool] = None, + threshold_sec: Optional[int] = None, + confirm_sec: Optional[int] = None, + return_sec: Optional[int] = None, + db: Session = Depends(get_db), +): + roi = update_roi( + db, + roi_id=roi_id, + name=name, + points=points, + rule_type=rule_type, + direction=direction, + stay_time=stay_time, + enabled=enabled, + threshold_sec=threshold_sec, + confirm_sec=confirm_sec, + return_sec=return_sec, + ) + if not roi: + raise HTTPException(status_code=404, detail="ROI不存在") + + pipeline = get_pipeline() + pipeline.roi_filter.update_cache(camera_id, get_roi_points(db, camera_id)) + + return { + "id": roi.id, + "roi_id": roi.roi_id, + "name": roi.name, + "type": roi.roi_type, + "points": json.loads(roi.points), + "rule": roi.rule_type, + "enabled": roi.enabled, + } + + +@router.delete("/{camera_id}/roi/{roi_id}") +def remove_roi(camera_id: int, roi_id: int, db: Session = Depends(get_db)): + if not delete_roi(db, roi_id): + raise HTTPException(status_code=404, detail="ROI不存在") + + pipeline = get_pipeline() + pipeline.roi_filter.update_cache(camera_id, get_roi_points(db, camera_id)) + + return {"message": "删除成功"} + + +@router.get("/{camera_id}/roi/validate") +def validate_roi( + camera_id: int, + roi_type: str, + points: List[List[float]], + db: Session = Depends(get_db), +): + try: + if roi_type == "polygon": + from shapely.geometry import Polygon + poly = Polygon(points) + is_valid = poly.is_valid + area = poly.area + elif roi_type == "line": + from shapely.geometry import LineString + line = LineString(points) + is_valid = True + area = 0.0 + else: + raise HTTPException(status_code=400, detail="不支持的ROI类型") + + return { + "valid": is_valid, + "area": area, + "message": "有效" if is_valid else "无效:自交或自重叠", + } + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/config.py b/config.py new file mode 100644 index 0000000..faed4e2 --- /dev/null +++ b/config.py @@ -0,0 +1,149 @@ +import os +from pathlib import Path +from typing import Any, Dict, List, Optional + +import yaml +from pydantic import BaseModel, Field + + +class DatabaseConfig(BaseModel): + dialect: str = "sqlite" + host: str = "localhost" + port: int = 3306 + username: str = "root" + password: str = "password" + name: str = "security_monitor" + echo: bool = False + + @property + def url(self) -> str: + if self.dialect == "sqlite": + return f"sqlite:///{self.name}.db" + return f"mysql+pymysql://{self.username}:{self.password}@{self.host}:{self.port}/{self.name}" + + +class ModelConfig(BaseModel): + engine_path: str = "models/yolo11n_fp16_480.engine" + pt_model_path: str = "models/yolo11n.pt" + imgsz: List[int] = [480, 480] + conf_threshold: float = 0.5 + iou_threshold: float = 0.45 + device: int = 0 + batch_size: int = 8 + half: bool = True + + +class StreamConfig(BaseModel): + buffer_size: int = 2 + reconnect_delay: float = 3.0 + timeout: float = 10.0 + fps_limit: int = 30 + + +class InferenceConfig(BaseModel): + queue_maxlen: int = 100 + event_queue_maxlen: int = 1000 + worker_threads: int = 4 + + +class AlertConfig(BaseModel): + snapshot_path: str = "data/alerts" + cooldown_sec: int = 300 + image_quality: int = 85 + + +class ROIConfig(BaseModel): + default_types: List[str] = ["polygon", "line"] + max_points: int = 50 + + +class WorkingHours(BaseModel): + start: List[int] = Field(default_factory=lambda: [8, 30]) + end: List[int] = Field(default_factory=lambda: [17, 30]) + + +class AlgorithmsConfig(BaseModel): + leave_post_threshold_sec: int = 360 + leave_post_confirm_sec: int = 30 + leave_post_return_sec: int = 5 + intrusion_check_interval_sec: float = 1.0 + intrusion_direction_sensitive: bool = False + + +class LoggingConfig(BaseModel): + level: str = "INFO" + format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file: str = "logs/app.log" + max_size: str = "100MB" + backup_count: int = 5 + + +class MonitoringConfig(BaseModel): + enabled: bool = True + port: int = 9090 + path: str = "/metrics" + + +class LLMConfig(BaseModel): + enabled: bool = False + api_key: str = "" + base_url: str = "" + model: str = "qwen-vl-max" + timeout: int = 30 + + +class Config(BaseModel): + database: DatabaseConfig = Field(default_factory=DatabaseConfig) + model: ModelConfig = Field(default_factory=ModelConfig) + stream: StreamConfig = Field(default_factory=StreamConfig) + inference: InferenceConfig = Field(default_factory=InferenceConfig) + alert: AlertConfig = Field(default_factory=AlertConfig) + roi: ROIConfig = Field(default_factory=ROIConfig) + working_hours: List[WorkingHours] = Field(default_factory=lambda: [ + WorkingHours(start=[8, 30], end=[11, 0]), + WorkingHours(start=[12, 0], end=[17, 30]) + ]) + algorithms: AlgorithmsConfig = Field(default_factory=AlgorithmsConfig) + logging: LoggingConfig = Field(default_factory=LoggingConfig) + monitoring: MonitoringConfig = Field(default_factory=MonitoringConfig) + llm: LLMConfig = Field(default_factory=LLMConfig) + + +_config: Optional[Config] = None + + +def load_config(config_path: Optional[str] = None) -> Config: + """加载配置文件""" + global _config + + if _config is not None: + return _config + + if config_path is None: + config_path = os.environ.get( + "CONFIG_PATH", + str(Path(__file__).parent / "config.yaml") + ) + + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + yaml_config = yaml.safe_load(f) or {} + else: + yaml_config = {} + + _config = Config(**yaml_config) + return _config + + +def get_config() -> Config: + """获取配置单例""" + if _config is None: + return load_config() + return _config + + +def reload_config(config_path: Optional[str] = None) -> Config: + """重新加载配置""" + global _config + _config = None + return load_config(config_path) diff --git a/config.yaml b/config.yaml index b15c7ce..0b061ea 100644 --- a/config.yaml +++ b/config.yaml @@ -1,107 +1,87 @@ +# 安保异常行为识别系统 - 核心配置 + +# 数据库配置 +database: + dialect: "sqlite" # sqlite 或 mysql + host: "localhost" + port: 3306 + username: "root" + password: "password" + name: "security_monitor" + echo: true # SQL日志输出 + +# TensorRT模型配置 model: - path: "C:/Users/16337/PycharmProjects/Security/yolo11n.pt" - imgsz: 480 + engine_path: "models/yolo11n_fp16_480.engine" + pt_model_path: "models/yolo11n.pt" + imgsz: [480, 480] conf_threshold: 0.5 - device: "cuda" # cuda, cpu + iou_threshold: 0.45 + device: 0 # GPU设备号 + batch_size: 8 # 最大batch size + half: true # FP16推理 -# 大模型配置 +# RTSP流配置 +stream: + buffer_size: 2 # 每路摄像头帧缓冲大小 + reconnect_delay: 3.0 # 重连延迟(秒) + timeout: 10.0 # 连接超时(秒) + fps_limit: 30 # 最大处理FPS + +# 推理队列配置 +inference: + queue_maxlen: 100 # 检测结果队列最大长度 + event_queue_maxlen: 1000 # 异常事件队列最大长度 + worker_threads: 4 # 处理线程数 + +# 告警配置 +alert: + snapshot_path: "data/alerts" + cooldown_sec: 300 # 同类型告警冷却时间 + image_quality: 85 # JPEG质量 + +# ROI配置 +roi: + default_types: + - "polygon" + - "line" + max_points: 50 # 多边形最大顶点数 + +# 工作时间配置(全局默认) +working_hours: + - start: [8, 30] # 8:30 + end: [11, 0] # 11:00 + - start: [12, 0] # 12:00 + end: [17, 30] # 17:30 + +# 算法默认参数 +algorithms: + leave_post: + default_threshold_sec: 360 # 离岗超时(6分钟) + confirm_sec: 30 # 离岗确认时间 + return_sec: 5 # 上岗确认时间 + intrusion: + check_interval_sec: 1.0 # 检测间隔 + direction_sensitive: false # 方向敏感 + +# 日志配置 +logging: + level: "INFO" + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + file: "logs/app.log" + max_size: "100MB" + backup_count: 5 + +# 监控配置 +monitoring: + enabled: true + port: 9090 # Prometheus metrics端口 + path: "/metrics" + +# 大模型配置(预留) llm: - api_key: "sk-21e61bef09074682b589da3bdbfe07a2" # 请替换为实际的API密钥(阿里云DashScope API Key) - base_url: "https://dashscope.aliyuncs.com/compatible-mode/v1/" - model_name: "qwen3-vl-flash" # 模型名称,可选:qwen-vl-max, qwen-vl-plus, qwen3-vl-flash等 - -common: - # 工作时间段:支持多个时间段,格式为 [开始小时, 开始分钟, 结束小时, 结束分钟] - # 8:30-11:00, 12:00-17:30 - working_hours: - - [8, 30, 11, 0] # 8:30-11:00 - - [12, 0, 17, 30] # 12:00-17:30 - process_every_n_frames: 3 # 每3帧处理1帧(用于人员离岗) - alert_cooldown_sec: 300 # 离岗告警冷却(秒) - off_duty_alert_threshold_sec: 360 # 离岗超过6分钟(360秒)触发告警 - -cameras: - - id: "cam_01" - rtsp_url: "rtsp://admin:admin@172.16.8.19:554/cam/realmonitor?channel=16&subtype=1" - process_every_n_frames: 5 - rois: - - name: "离岗检测区域" - points: [[380, 50], [530, 100], [550, 550], [140, 420]] - algorithms: - - name: "人员离岗" - enabled: true - off_duty_threshold_sec: 300 # 离岗超时告警(秒) - on_duty_confirm_sec: 5 # 上岗确认时间(秒) - off_duty_confirm_sec: 30 # 离岗确认时间(秒) - - name: "周界入侵" - enabled: false - # - name: "周界入侵区域1" - # points: [[100, 100], [200, 100], [200, 300], [100, 300]] - # algorithms: - # - name: "人员离岗" - # enabled: false - # - name: "周界入侵" - # enabled: false - - - id: "cam_02" - rtsp_url: "rtsp://admin:admin@172.16.8.13:554/cam/realmonitor?channel=7&subtype=1" - rois: - - name: "离岗检测区域" - points: [[380, 50], [530, 100], [550, 550], [140, 420]] - algorithms: - - name: "人员离岗" - enabled: true - off_duty_threshold_sec: 600 - on_duty_confirm_sec: 10 - off_duty_confirm_sec: 20 - - name: "周界入侵" - enabled: false - - - id: "cam_03" - rtsp_url: "rtsp://admin:admin@172.16.8.26:554/cam/realmonitor?channel=3&subtype=1" - rois: - - name: "离岗检测区域" - points: [[380, 50], [530, 100], [550, 550], [140, 420]] - algorithms: - - name: "人员离岗" - enabled: true - off_duty_threshold_sec: 600 - on_duty_confirm_sec: 10 - off_duty_confirm_sec: 20 - - name: "周界入侵" - enabled: false - - - id: "cam_04" - rtsp_url: "rtsp://admin:admin@172.16.8.20:554/cam/realmonitor?channel=14&subtype=1" - process_every_n_frames: 5 - rois: - - name: "离岗检测区域" -# points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ] - algorithms: - - name: "人员离岗" - enabled: false - - name: "周界入侵" - enabled: true - - id: "cam_05" - rtsp_url: "rtsp://admin:admin@172.16.8.31:554/cam/realmonitor?channel=15&subtype=1" - process_every_n_frames: 5 - rois: - - name: "离岗检测区域" -# points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ] - algorithms: - - name: "人员离岗" - enabled: false # 离岗确认时间(秒) - - name: "周界入侵" - enabled: true - - - id: "cam_06" - rtsp_url: "rtsp://admin:admin@172.16.8.35:554/cam/realmonitor?channel=13&subtype=1" - process_every_n_frames: 5 - rois: - - name: "离岗检测区域" - points: [ [ 150, 100 ], [ 600, 100 ], [ 600, 500 ], [ 150, 500 ] ] - algorithms: - - name: "人员离岗" - enabled: false - - name: "周界入侵" - enabled: true \ No newline at end of file + enabled: false + api_key: "" + base_url: "" + model: "qwen3-vl-max" + timeout: 30 diff --git a/db/crud.py b/db/crud.py new file mode 100644 index 0000000..421174a --- /dev/null +++ b/db/crud.py @@ -0,0 +1,314 @@ +from datetime import datetime +from typing import List, Optional + +from sqlalchemy.orm import Session + +from db.models import Camera, CameraStatus, ROI, Alarm + + +def get_all_cameras(db: Session, enabled_only: bool = True) -> List[Camera]: + query = db.query(Camera) + if enabled_only: + query = query.filter(Camera.enabled == True) + return query.all() + + +def get_camera_by_id(db: Session, camera_id: int) -> Optional[Camera]: + return db.query(Camera).filter(Camera.id == camera_id).first() + + +def create_camera( + db: Session, + name: str, + rtsp_url: str, + fps_limit: int = 30, + process_every_n_frames: int = 3, +) -> Camera: + camera = Camera( + name=name, + rtsp_url=rtsp_url, + fps_limit=fps_limit, + process_every_n_frames=process_every_n_frames, + ) + db.add(camera) + db.commit() + db.refresh(camera) + return camera + + +def update_camera( + db: Session, + camera_id: int, + name: Optional[str] = None, + rtsp_url: Optional[str] = None, + fps_limit: Optional[int] = None, + process_every_n_frames: Optional[int] = None, + enabled: Optional[bool] = None, +) -> Optional[Camera]: + camera = get_camera_by_id(db, camera_id) + if not camera: + return None + + if name is not None: + camera.name = name + if rtsp_url is not None: + camera.rtsp_url = rtsp_url + if fps_limit is not None: + camera.fps_limit = fps_limit + if process_every_n_frames is not None: + camera.process_every_n_frames = process_every_n_frames + if enabled is not None: + camera.enabled = enabled + + db.commit() + db.refresh(camera) + return camera + + +def delete_camera(db: Session, camera_id: int) -> bool: + camera = get_camera_by_id(db, camera_id) + if not camera: + return False + + db.delete(camera) + db.commit() + return True + + +def get_camera_status(db: Session, camera_id: int) -> Optional[CameraStatus]: + return ( + db.query(CameraStatus).filter(CameraStatus.camera_id == camera_id).first() + ) + + +def update_camera_status( + db: Session, + camera_id: int, + is_running: Optional[bool] = None, + fps: Optional[float] = None, + error_message: Optional[str] = None, +) -> Optional[CameraStatus]: + status = get_camera_status(db, camera_id) + if not status: + status = CameraStatus(camera_id=camera_id) + db.add(status) + + if is_running is not None: + status.is_running = is_running + if fps is not None: + status.fps = fps + if error_message is not None: + status.error_message = error_message + + status.last_check_time = datetime.utcnow() + + db.commit() + db.refresh(status) + return status + + +def get_all_rois(db: Session, camera_id: Optional[int] = None) -> List[ROI]: + query = db.query(ROI) + if camera_id is not None: + query = query.filter(ROI.camera_id == camera_id) + return query.filter(ROI.enabled == True).all() + + +def get_roi_by_id(db: Session, roi_id: int) -> Optional[ROI]: + return db.query(ROI).filter(ROI.id == roi_id).first() + + +def get_roi_by_roi_id(db: Session, camera_id: int, roi_id: str) -> Optional[ROI]: + return ( + db.query(ROI) + .filter(ROI.camera_id == camera_id, ROI.roi_id == roi_id) + .first() + ) + + +def create_roi( + db: Session, + camera_id: int, + roi_id: str, + name: str, + roi_type: str, + points: List[List[float]], + rule_type: str, + direction: Optional[str] = None, + stay_time: Optional[int] = None, + threshold_sec: int = 360, + confirm_sec: int = 30, + return_sec: int = 5, +) -> ROI: + import json + + roi = ROI( + camera_id=camera_id, + roi_id=roi_id, + name=name, + roi_type=roi_type, + points=json.dumps(points), + rule_type=rule_type, + direction=direction, + stay_time=stay_time, + threshold_sec=threshold_sec, + confirm_sec=confirm_sec, + return_sec=return_sec, + ) + db.add(roi) + db.commit() + db.refresh(roi) + return roi + + +def update_roi( + db: Session, + roi_id: int, + name: Optional[str] = None, + points: Optional[List[List[float]]] = None, + rule_type: Optional[str] = None, + direction: Optional[str] = None, + stay_time: Optional[int] = None, + enabled: Optional[bool] = None, + threshold_sec: Optional[int] = None, + confirm_sec: Optional[int] = None, + return_sec: Optional[int] = None, +) -> Optional[ROI]: + import json + + roi = get_roi_by_id(db, roi_id) + if not roi: + return None + + if name is not None: + roi.name = name + if points is not None: + roi.points = json.dumps(points) + if rule_type is not None: + roi.rule_type = rule_type + if direction is not None: + roi.direction = direction + if stay_time is not None: + roi.stay_time = stay_time + if enabled is not None: + roi.enabled = enabled + if threshold_sec is not None: + roi.threshold_sec = threshold_sec + if confirm_sec is not None: + roi.confirm_sec = confirm_sec + if return_sec is not None: + roi.return_sec = return_sec + + db.commit() + db.refresh(roi) + return roi + + +def delete_roi(db: Session, roi_id: int) -> bool: + roi = get_roi_by_id(db, roi_id) + if not roi: + return False + + db.delete(roi) + db.commit() + return True + + +def get_roi_points(db: Session, camera_id: int) -> List[dict]: + import json + + rois = get_all_rois(db, camera_id) + return [ + { + "id": roi.id, + "roi_id": roi.roi_id, + "name": roi.name, + "type": roi.roi_type, + "points": json.loads(roi.points), + "rule": roi.rule_type, + "direction": roi.direction, + "stay_time": roi.stay_time, + "enabled": roi.enabled, + "threshold_sec": roi.threshold_sec, + "confirm_sec": roi.confirm_sec, + "return_sec": roi.return_sec, + } + for roi in rois + ] + + +def create_alarm( + db: Session, + camera_id: int, + event_type: str, + confidence: float = 0.0, + snapshot_path: Optional[str] = None, + roi_id: Optional[str] = None, + llm_checked: bool = False, + llm_result: Optional[str] = None, +) -> Alarm: + alarm = Alarm( + camera_id=camera_id, + roi_id=roi_id, + event_type=event_type, + confidence=confidence, + snapshot_path=snapshot_path, + llm_checked=llm_checked, + llm_result=llm_result, + ) + db.add(alarm) + db.commit() + db.refresh(alarm) + return alarm + + +def get_alarms( + db: Session, + camera_id: Optional[int] = None, + event_type: Optional[str] = None, + limit: int = 100, + offset: int = 0, +) -> List[Alarm]: + query = db.query(Alarm) + if camera_id is not None: + query = query.filter(Alarm.camera_id == camera_id) + if event_type is not None: + query = query.filter(Alarm.event_type == event_type) + return query.order_by(Alarm.created_at.desc()).offset(offset).limit(limit).all() + + +def update_alarm( + db: Session, + alarm_id: int, + llm_checked: Optional[bool] = None, + llm_result: Optional[str] = None, + processed: Optional[bool] = None, +) -> Optional[Alarm]: + alarm = db.query(Alarm).filter(Alarm.id == alarm_id).first() + if not alarm: + return None + + if llm_checked is not None: + alarm.llm_checked = llm_checked + if llm_result is not None: + alarm.llm_result = llm_result + if processed is not None: + alarm.processed = processed + + db.commit() + db.refresh(alarm) + return alarm + + +def get_alarm_stats(db: Session) -> dict: + total = db.query(Alarm).count() + unprocessed = db.query(Alarm).filter(Alarm.processed == False).count() + llm_pending = db.query(Alarm).filter( + Alarm.llm_checked == False, Alarm.processed == False + ).count() + + return { + "total": total, + "unprocessed": unprocessed, + "llm_pending": llm_pending, + } diff --git a/db/models.py b/db/models.py new file mode 100644 index 0000000..5f07b1f --- /dev/null +++ b/db/models.py @@ -0,0 +1,167 @@ +import os +import sys +from datetime import datetime +from typing import List, Optional + +from sqlalchemy import ( + Boolean, + DateTime, + Float, + ForeignKey, + Integer, + String, + Text, + create_engine, + event, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, + relationship, + sessionmaker, +) + +from config import get_config + + +class Base(DeclarativeBase): + pass + + +class Camera(Base): + __tablename__ = "cameras" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + name: Mapped[str] = mapped_column(String(64), nullable=False) + rtsp_url: Mapped[str] = mapped_column(Text, nullable=False) + enabled: Mapped[bool] = mapped_column(Boolean, default=True) + fps_limit: Mapped[int] = mapped_column(Integer, default=30) + process_every_n_frames: Mapped[int] = mapped_column(Integer, default=3) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + rois: Mapped[List["ROI"]] = relationship( + "ROI", back_populates="camera", cascade="all, delete-orphan" + ) + status: Mapped[Optional["CameraStatus"]] = relationship( + "CameraStatus", back_populates="camera", uselist=False + ) + alarms: Mapped[List["Alarm"]] = relationship( + "Alarm", back_populates="camera", cascade="all, delete-orphan" + ) + + +class CameraStatus(Base): + __tablename__ = "camera_status" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + camera_id: Mapped[int] = mapped_column( + Integer, ForeignKey("cameras.id"), unique=True, nullable=False + ) + is_running: Mapped[bool] = mapped_column(Boolean, default=False) + last_frame_time: Mapped[Optional[datetime]] = mapped_column(DateTime) + fps: Mapped[float] = mapped_column(Float, default=0.0) + error_message: Mapped[Optional[str]] = mapped_column(Text) + last_check_time: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + camera: Mapped["Camera"] = relationship("Camera", back_populates="status") + + +class ROI(Base): + __tablename__ = "rois" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + camera_id: Mapped[int] = mapped_column( + Integer, ForeignKey("cameras.id"), nullable=False + ) + roi_id: Mapped[str] = mapped_column(String(64), nullable=False) + name: Mapped[str] = mapped_column(String(128), nullable=False) + roi_type: Mapped[str] = mapped_column(String(32), nullable=False) + points: Mapped[str] = mapped_column(Text, nullable=False) + rule_type: Mapped[str] = mapped_column(String(32), nullable=False) + direction: Mapped[Optional[str]] = mapped_column(String(32)) + stay_time: Mapped[Optional[int]] = mapped_column(Integer) + enabled: Mapped[bool] = mapped_column(Boolean, default=True) + threshold_sec: Mapped[int] = mapped_column(Integer, default=360) + confirm_sec: Mapped[int] = mapped_column(Integer, default=30) + return_sec: Mapped[int] = mapped_column(Integer, default=5) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, default=datetime.utcnow, onupdate=datetime.utcnow + ) + + camera: Mapped["Camera"] = relationship("Camera", back_populates="rois") + + +class Alarm(Base): + __tablename__ = "alarms" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + camera_id: Mapped[int] = mapped_column( + Integer, ForeignKey("cameras.id"), nullable=False + ) + roi_id: Mapped[Optional[str]] = mapped_column(String(64)) + event_type: Mapped[str] = mapped_column(String(32), nullable=False) + confidence: Mapped[float] = mapped_column(Float, default=0.0) + snapshot_path: Mapped[Optional[str]] = mapped_column(Text) + llm_checked: Mapped[bool] = mapped_column(Boolean, default=False) + llm_result: Mapped[Optional[str]] = mapped_column(Text) + processed: Mapped[bool] = mapped_column(Boolean, default=False) + created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow) + + camera: Mapped["Camera"] = relationship("Camera", back_populates="alarms") + + +_engine = None +_SessionLocal = None + + +def get_engine(): + global _engine + if _engine is None: + config = get_config() + _engine = create_engine( + config.database.url, + echo=config.database.echo, + pool_pre_ping=True, + pool_recycle=3600, + ) + return _engine + + +def get_session_factory(): + global _SessionLocal + if _SessionLocal is None: + _SessionLocal = sessionmaker( + autocommit=False, autoflush=False, bind=get_engine() + ) + return _SessionLocal + + +def get_db(): + SessionLocal = get_session_factory() + db = SessionLocal() + try: + yield db + finally: + db.close() + + +def init_db(): + config = get_config() + engine = get_engine() + + Base.metadata.create_all(bind=engine) + + if config.database.dialect == "sqlite": + for table in [Camera, ROI, Alarm]: + table.__table__.create(engine, checkfirst=True) + + +def reset_engine(): + global _engine, _SessionLocal + _engine = None + _SessionLocal = None diff --git a/detector.py b/detector.py deleted file mode 100644 index f729806..0000000 --- a/detector.py +++ /dev/null @@ -1,218 +0,0 @@ -import cv2 -import numpy as np -from ultralytics import YOLO -from sort import Sort -import time -import datetime -import threading -import queue -import torch -from collections import deque - - -class ThreadedFrameReader: - def __init__(self, src, maxsize=1): - self.cap = cv2.VideoCapture(src, cv2.CAP_FFMPEG) - self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) - self.q = queue.Queue(maxsize=maxsize) - self.running = True - self.thread = threading.Thread(target=self._reader) - self.thread.daemon = True - self.thread.start() - - def _reader(self): - while self.running: - ret, frame = self.cap.read() - if not ret: - time.sleep(0.1) - continue - if not self.q.empty(): - try: - self.q.get_nowait() - except queue.Empty: - pass - self.q.put(frame) - - def read(self): - if not self.q.empty(): - return True, self.q.get() - return False, None - - def release(self): - self.running = False - self.cap.release() - - -def is_point_in_roi(x, y, roi): - return cv2.pointPolygonTest(roi, (int(x), int(y)), False) >= 0 - - -class OffDutyCrowdDetector: - def __init__(self, config, model, device, use_half): - self.config = config - self.model = model - self.device = device - self.use_half = use_half - - # 解析 ROI - self.roi = np.array(config["roi_points"], dtype=np.int32) - self.crowd_roi = np.array(config["crowd_roi_points"], dtype=np.int32) - - # 状态变量 - self.tracker = Sort( - max_age=30, - min_hits=2, - iou_threshold=0.3 - ) - - self.is_on_duty = False - self.on_duty_start_time = None - self.is_off_duty = True - self.last_no_person_time = None - self.off_duty_timer_start = None - self.last_alert_time = 0 - - self.last_crowd_alert_time = 0 - self.crowd_history = deque(maxlen=1500) # 自动限制5分钟(假设5fps) - - self.last_person_seen_time = None - self.frame_count = 0 - - # 缓存配置 - self.working_start_min = config.get("working_hours", [9, 17])[0] * 60 - self.working_end_min = config.get("working_hours", [9, 17])[1] * 60 - self.process_every = config.get("process_every_n_frames", 3) - - def in_working_hours(self): - now = datetime.datetime.now() - total_min = now.hour * 60 + now.minute - return self.working_start_min <= total_min <= self.working_end_min - - def count_people_in_roi(self, boxes, roi): - count = 0 - for box in boxes: - x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() - cx, cy = (x1 + x2) / 2, (y1 + y2) / 2 - if is_point_in_roi(cx, cy, roi): - count += 1 - return count - - def run(self): - """主循环:供线程调用""" - frame_reader = ThreadedFrameReader(self.config["rtsp_url"]) - try: - while True: - ret, frame = frame_reader.read() - if not ret: - time.sleep(0.01) - continue - - self.frame_count += 1 - if self.frame_count % self.process_every != 0: - continue - - current_time = time.time() - now = datetime.datetime.now() - - # YOLO 推理 - results = self.model( - frame, - imgsz=self.config.get("imgsz", 480), - conf=self.config.get("conf_thresh", 0.4), - verbose=False, - device=self.device, - half=self.use_half, - classes=[0] # person class - ) - boxes = results[0].boxes - - # 更新 tracker(可选,用于ID跟踪) - dets = [] - for box in boxes: - x1, y1, x2, y2 = box.xyxy[0].cpu().numpy() - conf = float(box.conf) - dets.append([x1, y1, x2, y2, conf]) - dets = np.array(dets) if dets else np.empty((0, 5)) - self.tracker.update(dets) - - # === 离岗检测 === - if self.in_working_hours(): - roi_has_person = self.count_people_in_roi(boxes, self.roi) > 0 - if roi_has_person: - self.last_person_seen_time = current_time - - # 入岗保护期 - effective_on_duty = ( - self.last_person_seen_time is not None and - (current_time - self.last_person_seen_time) < 1.0 - ) - - if effective_on_duty: - self.last_no_person_time = None - if self.is_off_duty: - if self.on_duty_start_time is None: - self.on_duty_start_time = current_time - elif current_time - self.on_duty_start_time >= self.config.get("on_duty_confirm", 5): - self.is_on_duty = True - self.is_off_duty = False - self.on_duty_start_time = None - print(f"[{self.config['id']}] ✅ 上岗确认") - else: - self.on_duty_start_time = None - self.last_person_seen_time = None - if not self.is_off_duty: - if self.last_no_person_time is None: - self.last_no_person_time = current_time - elif current_time - self.last_no_person_time >= self.config.get("off_duty_confirm", 30): - self.is_off_duty = True - self.is_on_duty = False - self.off_duty_timer_start = current_time - print(f"[{self.config['id']}] ⏳ 开始离岗计时") - - # 离岗告警 - if self.is_off_duty and self.off_duty_timer_start: - elapsed = current_time - self.off_duty_timer_start - if elapsed >= self.config.get("off_duty_threshold", 300): - if current_time - self.last_alert_time >= self.config.get("alert_cooldown", 300): - print(f"[{self.config['id']}] 🚨 离岗告警!已离岗 {elapsed/60:.1f} 分钟") - self.last_alert_time = current_time - - # === 聚集检测 === - crowd_count = self.count_people_in_roi(boxes, self.crowd_roi) - self.crowd_history.append((current_time, crowd_count)) - - # 动态阈值 - if crowd_count >= 10: - req_dur = 60 - elif crowd_count >= 7: - req_dur = 120 - elif crowd_count >= 5: - req_dur = 300 - else: - req_dur = float('inf') - - if req_dur < float('inf'): - recent = [(t, c) for t, c in self.crowd_history if current_time - t <= req_dur] - if recent: - valid = [c for t, c in recent if c >= 4] - ratio = len(valid) / len(recent) - if ratio >= 0.9 and (current_time - self.last_crowd_alert_time) >= self.config.get("crowd_cooldown", 180): - print(f"[{self.config['id']}] 🚨 聚集告警!{crowd_count}人持续{req_dur//60}分钟") - self.last_crowd_alert_time = current_time - - # 可视化(可选,部署时可关闭) - if True: # 设为 True 可显示窗口 - vis = results[0].plot() - overlay = vis.copy() - cv2.fillPoly(overlay, [self.roi], (0,255,0)) - cv2.fillPoly(overlay, [self.crowd_roi], (0,0,255)) - cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis) - cv2.imshow(f"Monitor - {self.config['id']}", vis) - if cv2.waitKey(1) & 0xFF == ord('q'): - break - - except Exception as e: - print(f"[{self.config['id']}] Error: {e}") - finally: - frame_reader.release() - cv2.destroyAllWindows() \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..d838e74 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,63 @@ +version: '3.8' + +services: + security-monitor: + build: + context: . + dockerfile: Dockerfile + container_name: security-monitor + runtime: nvidia + ports: + - "8000:8000" + - "9090:9090" + volumes: + - ./models:/app/models + - ./data:/app/data + - ./logs:/app/logs + - ./config.yaml:/app/config.yaml:ro + environment: + - CUDA_VISIBLE_DEVICES=0 + - CONFIG_PATH=/app/config.yaml + restart: unless-stopped + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8000/health"] + interval: 30s + timeout: 10s + retries: 3 + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + + prometheus: + image: prom/prometheus:v2.45.0 + container_name: prometheus + ports: + - "9091:9090" + volumes: + - ./prometheus.yml:/etc/prometheus/prometheus.yml:ro + - prometheus_data:/prometheus + command: + - '--config.file=/etc/prometheus/prometheus.yml' + - '--storage.tsdb.path=/prometheus' + - '--web.enable-lifecycle' + restart: unless-stopped + + grafana: + image: grafana/grafana:10.0.0 + container_name: grafana + ports: + - "3000:3000" + volumes: + - grafana_data:/var/lib/grafana + - ./grafana/provisioning:/etc/grafana/provisioning:ro + environment: + - GF_SECURITY_ADMIN_PASSWORD=admin + restart: unless-stopched + +volumes: + prometheus_data: + grafana_data: diff --git a/inference/engine.py b/inference/engine.py new file mode 100644 index 0000000..1b83982 --- /dev/null +++ b/inference/engine.py @@ -0,0 +1,244 @@ +import os +import time +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import tensorrt as trt +import torch +from ultralytics import YOLO +from ultralytics.engine.results import Results + +from config import get_config + + +class TensorRTEngine: + def __init__(self, engine_path: Optional[str] = None, device: int = 0): + config = get_config() + self.engine_path = engine_path or config.model.engine_path + self.device = device + self.imgsz = tuple(config.model.imgsz) + self.conf_thresh = config.model.conf_threshold + self.iou_thresh = config.model.iou_threshold + self.half = config.model.half + + self.logger = trt.Logger(trt.Logger.INFO) + self.engine = None + self.context = None + self.stream = None + self.input_buffer = None + self.output_buffers = [] + + self._load_engine() + + def _load_engine(self): + if not os.path.exists(self.engine_path): + raise FileNotFoundError(f"TensorRT引擎文件不存在: {self.engine_path}") + + with open(self.engine_path, "rb") as f: + serialized_engine = f.read() + + runtime = trt.Runtime(self.logger) + self.engine = runtime.deserialize_cuda_engine(serialized_engine) + + self.context = self.engine.create_execution_context() + + self.stream = torch.cuda.Stream(device=self.device) + + for i in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(i) + dtype = self.engine.get_tensor_dtype(name) + shape = self.engine.get_tensor_shape(name) + + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + self.context.set_tensor_address(name, None) + else: + if dtype == trt.float16: + buffer = torch.zeros(shape, dtype=torch.float16, device=self.device) + else: + buffer = torch.zeros(shape, dtype=torch.float32, device=self.device) + self.output_buffers.append(buffer) + self.context.set_tensor_address(name, buffer.data_ptr()) + + self.context.set_optimization_profile_async(0, self.stream) + + self.input_buffer = torch.zeros( + (1, 3, self.imgsz[0], self.imgsz[1]), + dtype=torch.float16 if self.half else torch.float32, + device=self.device, + ) + + def preprocess(self, frame: np.ndarray) -> torch.Tensor: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, self.imgsz) + + img = img.transpose(2, 0, 1).astype(np.float32) / 255.0 + + if self.half: + img = img.astype(np.float16) + + tensor = torch.from_numpy(img).unsqueeze(0).to(self.device) + + return tensor + + def inference(self, images: List[np.ndarray]) -> List[Results]: + batch_size = len(images) + if batch_size == 0: + return [] + + input_tensor = self.preprocess(images[0]) + + if batch_size > 1: + for i in range(1, batch_size): + input_tensor = torch.cat( + [input_tensor, self.preprocess(images[i])], dim=0 + ) + + self.context.set_tensor_address( + "input", input_tensor.contiguous().data_ptr() + ) + + torch.cuda.synchronize(self.stream) + self.context.execute_async_v3(self.stream.handle) + torch.cuda.synchronize(self.stream) + + results = [] + for i in range(batch_size): + pred = self.output_buffers[0][i].cpu().numpy() + boxes = pred[:, :4] + scores = pred[:, 4] + classes = pred[:, 5].astype(np.int32) + + mask = scores > self.conf_thresh + boxes = boxes[mask] + scores = scores[mask] + classes = classes[mask] + + indices = cv2.dnn.NMSBoxes( + boxes.tolist(), + scores.tolist(), + self.conf_thresh, + self.iou_thresh, + ) + + if len(indices) > 0: + for idx in indices: + box = boxes[idx] + x1, y1, x2, y2 = box + w, h = x2 - x1, y2 - y1 + conf = scores[idx] + cls = classes[idx] + + orig_h, orig_w = images[i].shape[:2] + scale_x, scale_y = orig_w / self.imgsz[1], orig_h / self.imgsz[0] + box_orig = [ + int(x1 * scale_x), + int(y1 * scale_y), + int(w * scale_x), + int(h * scale_y), + ] + + result = Results( + orig_img=images[i], + path="", + names={0: "person"}, + boxes=Boxes( + torch.tensor([box_orig + [conf, cls]]), + orig_shape=(orig_h, orig_w), + ), + ) + results.append(result) + + return results + + def inference_single(self, frame: np.ndarray) -> List[Results]: + return self.inference([frame]) + + def warmup(self, num_warmup: int = 10): + dummy_frame = np.zeros((480, 640, 3), dtype=np.uint8) + for _ in range(num_warmup): + self.inference_single(dummy_frame) + + def __del__(self): + if self.context: + self.context.synchronize() + if self.stream: + self.stream.synchronize() + + +class Boxes: + def __init__( + self, + data: torch.Tensor, + orig_shape: Tuple[int, int], + is_track: bool = False, + ): + self.data = data + self.orig_shape = orig_shape + self.is_track = is_track + + @property + def xyxy(self): + if self.is_track: + return self.data[:, :4] + return self.data[:, :4] + + @property + def conf(self): + if self.is_track: + return self.data[:, 4] + return self.data[:, 4] + + @property + def cls(self): + if self.is_track: + return self.data[:, 5] + return self.data[:, 5] + + +class YOLOEngine: + def __init__( + self, + model_path: Optional[str] = None, + device: int = 0, + use_trt: bool = True, + ): + self.use_trt = use_trt + self.device = device + self.trt_engine = None + + if not use_trt: + if model_path: + pt_path = model_path + elif hasattr(get_config().model, 'pt_model_path'): + pt_path = get_config().model.pt_model_path + else: + pt_path = get_config().model.engine_path.replace(".engine", ".pt") + self.model = YOLO(pt_path) + self.model.to(device) + else: + try: + self.trt_engine = TensorRTEngine(device=device) + self.trt_engine.warmup() + except Exception as e: + print(f"TensorRT加载失败,回退到PyTorch: {e}") + self.use_trt = False + if model_path: + pt_path = model_path + elif hasattr(get_config().model, 'pt_model_path'): + pt_path = get_config().model.pt_model_path + else: + pt_path = get_config().model.engine_path.replace(".engine", ".pt") + self.model = YOLO(pt_path) + self.model.to(device) + + def __call__(self, frame: np.ndarray, **kwargs) -> List[Results]: + if self.use_trt: + return self.trt_engine.inference_single(frame) + else: + results = self.model(frame, imgsz=get_config().model.imgsz, **kwargs) + return results + + def __del__(self): + if self.trt_engine: + del self.trt_engine diff --git a/inference/pipeline.py b/inference/pipeline.py new file mode 100644 index 0000000..a82a258 --- /dev/null +++ b/inference/pipeline.py @@ -0,0 +1,376 @@ +import asyncio +import json +import os +import threading +import time +from collections import deque +from datetime import datetime +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from config import get_config +from db.crud import ( + create_alarm, + get_all_rois, + update_camera_status, +) +from db.models import init_db +from inference.engine import YOLOEngine +from inference.roi.roi_filter import ROIFilter +from inference.rules.algorithms import AlgorithmManager +from inference.stream import StreamManager + + +class InferencePipeline: + def __init__(self): + self.config = get_config() + + self.db_initialized = False + + self.yolo_engine = YOLOEngine(use_trt=True) + self.stream_manager = StreamManager( + buffer_size=self.config.stream.buffer_size, + reconnect_delay=self.config.stream.reconnect_delay, + ) + self.roi_filter = ROIFilter() + self.algo_manager = AlgorithmManager(working_hours=[ + { + "start": [wh.start[0], wh.start[1]], + "end": [wh.end[0], wh.end[1]], + } + for wh in self.config.working_hours + ]) + + self.camera_threads: Dict[int, threading.Thread] = {} + self.camera_stop_events: Dict[int, threading.Event] = {} + self.camera_latest_frames: Dict[int, Any] = {} + self.camera_frame_times: Dict[int, datetime] = {} + self.camera_process_counts: Dict[int, int] = {} + + self.event_queue: deque = deque(maxlen=self.config.inference.event_queue_maxlen) + + self.running = False + + def _init_database(self): + if not self.db_initialized: + init_db() + self.db_initialized = True + + def _get_db_session(self): + from db.models import get_session_factory + SessionLocal = get_session_factory() + return SessionLocal() + + def _load_cameras(self): + db = self._get_db_session() + try: + from db.crud import get_all_cameras + cameras = get_all_cameras(db, enabled_only=True) + for camera in cameras: + self.add_camera(camera) + finally: + db.close() + + def add_camera(self, camera) -> bool: + camera_id = camera.id + if camera_id in self.camera_threads: + return False + + self.camera_stop_events[camera_id] = threading.Event() + self.camera_process_counts[camera_id] = 0 + + self.stream_manager.add_stream( + str(camera_id), + camera.rtsp_url, + self.config.stream.buffer_size, + ) + + thread = threading.Thread( + target=self._camera_inference_loop, + args=(camera,), + daemon=True, + ) + thread.start() + self.camera_threads[camera_id] = thread + + self._update_camera_status(camera_id, is_running=True) + return True + + def remove_camera(self, camera_id: int): + if camera_id not in self.camera_threads: + return + + self.camera_stop_events[camera_id].set() + self.camera_threads[camera_id].join(timeout=10.0) + + del self.camera_threads[camera_id] + del self.camera_stop_events[camera_id] + + self.stream_manager.remove_stream(str(camera_id)) + self.roi_filter.clear_cache(camera_id) + self.algo_manager.remove_roi(str(camera_id)) + + if camera_id in self.camera_latest_frames: + del self.camera_latest_frames[camera_id] + if camera_id in self.camera_frame_times: + del self.camera_frame_times[camera_id] + if camera_id in self.camera_process_counts: + del self.camera_process_counts[camera_id] + + self._update_camera_status(camera_id, is_running=False) + + def _update_camera_status( + self, + camera_id: int, + is_running: Optional[bool] = None, + fps: Optional[float] = None, + error_message: Optional[str] = None, + ): + try: + db = self._get_db_session() + update_camera_status( + db, + camera_id, + is_running=is_running, + fps=fps, + error_message=error_message, + ) + except Exception as e: + print(f"[{camera_id}] 更新状态失败: {e}") + finally: + db.close() + + def _camera_inference_loop(self, camera): + camera_id = camera.id + stop_event = self.camera_stop_events[camera_id] + + while not stop_event.is_set(): + ret, frame = self.stream_manager.read(str(camera_id)) + if not ret or frame is None: + time.sleep(0.1) + continue + + self.camera_latest_frames[camera_id] = frame + self.camera_frame_times[camera_id] = datetime.now() + + self.camera_process_counts[camera_id] += 1 + + if self.camera_process_counts[camera_id] % camera.process_every_n_frames != 0: + continue + + try: + self._process_frame(camera_id, frame, camera) + except Exception as e: + print(f"[{camera_id}] 处理帧失败: {e}") + + print(f"[{camera_id}] 推理线程已停止") + + def _process_frame(self, camera_id: int, frame: np.ndarray, camera): + from ultralytics.engine.results import Results + + db = self._get_db_session() + try: + rois = get_all_rois(db, camera_id) + roi_configs = [ + { + "id": roi.id, + "roi_id": roi.roi_id, + "type": roi.roi_type, + "points": json.loads(roi.points), + "rule": roi.rule_type, + "direction": roi.direction, + "enabled": roi.enabled, + "threshold_sec": roi.threshold_sec, + "confirm_sec": roi.confirm_sec, + "return_sec": roi.return_sec, + } + for roi in rois + ] + + if roi_configs: + self.roi_filter.update_cache(camera_id, roi_configs) + + for roi_config in roi_configs: + roi_id = roi_config["roi_id"] + rule_type = roi_config["rule"] + + self.algo_manager.register_algorithm( + roi_id, + rule_type, + { + "threshold_sec": roi_config.get("threshold_sec", 360), + "confirm_sec": roi_config.get("confirm_sec", 30), + "return_sec": roi_config.get("return_sec", 5), + }, + ) + + results = self.yolo_engine(frame, verbose=False, classes=[0]) + + if not results: + return + + result = results[0] + detections = [] + if hasattr(result, "boxes") and result.boxes is not None: + boxes = result.boxes.xyxy.cpu().numpy() + confs = result.boxes.conf.cpu().numpy() + for i, box in enumerate(boxes): + detections.append({ + "bbox": box.tolist(), + "conf": float(confs[i]), + "cls": 0, + }) + + if roi_configs: + filtered_detections = self.roi_filter.filter_detections( + detections, roi_configs + ) + else: + filtered_detections = detections + + for detection in filtered_detections: + matched_rois = detection.get("matched_rois", []) + for roi_conf in matched_rois: + roi_id = roi_conf["roi_id"] + rule_type = roi_conf["rule"] + + alerts = self.algo_manager.process( + roi_id, + str(camera_id), + rule_type, + [detection], + datetime.now(), + ) + + for alert in alerts: + self._handle_alert(camera_id, alert, frame, roi_conf) + + finally: + db.close() + + def _handle_alert( + self, + camera_id: int, + alert: Dict[str, Any], + frame: np.ndarray, + roi_config: Dict[str, Any], + ): + try: + snapshot_path = None + bbox = alert.get("bbox", []) + if bbox and len(bbox) >= 4: + x1, y1, x2, y2 = [int(v) for v in bbox] + x1, y1 = max(0, x1), max(0, y1) + x2, y2 = min(frame.shape[1], x2), min(frame.shape[0], y2) + + snapshot_dir = self.config.alert.snapshot_path + os.makedirs(snapshot_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"cam_{camera_id}_{roi_config['roi_id']}_{alert['alert_type']}_{timestamp}.jpg" + snapshot_path = os.path.join(snapshot_dir, filename) + + cv2.imwrite(snapshot_path, frame, [cv2.IMWRITE_JPEG_QUALITY, self.config.alert.image_quality]) + + db = self._get_db_session() + try: + alarm = create_alarm( + db, + camera_id=camera_id, + event_type=alert["alert_type"], + confidence=alert.get("confidence", alert.get("conf", 0.0)), + snapshot_path=snapshot_path, + roi_id=roi_config["roi_id"], + ) + alert["alarm_id"] = alarm.id + finally: + db.close() + + self.event_queue.append({ + "camera_id": camera_id, + "roi_id": roi_config["roi_id"], + "event_type": alert["alert_type"], + "confidence": alert.get("confidence", alert.get("conf", 0.0)), + "message": alert.get("message", ""), + "snapshot_path": snapshot_path, + "timestamp": datetime.now().isoformat(), + "llm_checked": False, + }) + + print(f"[{camera_id}] 🚨 告警: {alert['alert_type']} - {alert.get('message', '')}") + + except Exception as e: + print(f"[{camera_id}] 处理告警失败: {e}") + + def get_latest_frame(self, camera_id: int) -> Optional[np.ndarray]: + return self.camera_latest_frames.get(camera_id) + + def get_camera_fps(self, camera_id: int) -> float: + stream = self.stream_manager.get_stream(str(camera_id)) + if stream: + return stream.fps + return 0.0 + + def get_event_queue(self) -> List[Dict[str, Any]]: + return list(self.event_queue) + + def start(self): + if self.running: + return + + self._init_database() + self._load_cameras() + self.running = True + + def stop(self): + if not self.running: + return + + self.running = False + + for camera_id in list(self.camera_threads.keys()): + self.remove_camera(camera_id) + + self.stream_manager.stop_all() + self.algo_manager.reset_all() + + print("推理pipeline已停止") + + def get_status(self) -> Dict[str, Any]: + return { + "running": self.running, + "camera_count": len(self.camera_threads), + "cameras": { + cid: { + "running": self.camera_stop_events[cid] is not None and not self.camera_stop_events[cid].is_set(), + "fps": self.get_camera_fps(cid), + "frame_time": self.camera_frame_times.get(cid).isoformat() if self.camera_frame_times.get(cid) else None, + } + for cid in self.camera_threads + }, + "event_queue_size": len(self.event_queue), + } + + +_pipeline: Optional[InferencePipeline] = None + + +def get_pipeline() -> InferencePipeline: + global _pipeline + if _pipeline is None: + _pipeline = InferencePipeline() + return _pipeline + + +def start_pipeline(): + pipeline = get_pipeline() + pipeline.start() + return pipeline + + +def stop_pipeline(): + global _pipeline + if _pipeline is not None: + _pipeline.stop() + _pipeline = None diff --git a/inference/roi/roi_filter.py b/inference/roi/roi_filter.py new file mode 100644 index 0000000..f0953aa --- /dev/null +++ b/inference/roi/roi_filter.py @@ -0,0 +1,168 @@ +import json +from typing import Dict, List, Optional, Tuple + +import numpy as np +from shapely.geometry import LineString, Point, Polygon + + +class ROIFilter: + def __init__(self): + self.roi_cache: Dict[int, List[Dict]] = {} + + def parse_points(self, points_json: str) -> List[Tuple[float, float]]: + if isinstance(points_json, str): + points = json.loads(points_json) + else: + points = points_json + return [(float(p[0]), float(p[1])) for p in points] + + def create_polygon(self, points: List[Tuple[float, float]]) -> Polygon: + return Polygon(points) + + def create_line(self, points: List[Tuple[float, float]]) -> LineString: + return LineString(points) + + def is_point_in_polygon(self, point: Tuple[float, float], polygon: Polygon) -> bool: + return polygon.contains(Point(point)) + + def is_line_crossed( + self, + line_start: Tuple[float, float], + line_end: Tuple[float, float], + line_obj: LineString, + ) -> bool: + trajectory = LineString([line_start, line_end]) + return line_obj.intersects(trajectory) + + def get_bbox_center(self, bbox: List[float]) -> Tuple[float, float]: + x1, y1, x2, y2 = bbox[:4] + return ((x1 + x2) / 2, (y1 + y2) / 2) + + def is_bbox_in_roi(self, bbox: List[float], roi_config: Dict) -> bool: + center = self.get_bbox_center(bbox) + roi_type = roi_config.get("type", "polygon") + points = self.parse_points(roi_config["points"]) + + if roi_type == "polygon": + polygon = self.create_polygon(points) + return self.is_point_in_polygon(center, polygon) + elif roi_type == "line": + line = self.create_line(points) + trajectory_start = ( + bbox[0], + bbox[1], + ) + trajectory_end = ( + bbox[2], + bbox[3], + ) + return self.is_line_crossed(trajectory_start, trajectory_end, line) + + return False + + def filter_detections( + self, + detections: List[Dict], + roi_configs: List[Dict], + require_all_rois: bool = False, + ) -> List[Dict]: + if not roi_configs: + return detections + + filtered = [] + for det in detections: + bbox = det.get("bbox", []) + if not bbox: + filtered.append(det) + continue + + matches = [] + for roi_config in roi_configs: + if not roi_config.get("enabled", True): + continue + if self.is_bbox_in_roi(bbox, roi_config): + matches.append(roi_config) + + if matches: + if require_all_rois: + if len(matches) == len(roi_configs): + det["matched_rois"] = matches + filtered.append(det) + else: + det["matched_rois"] = matches + filtered.append(det) + + return filtered + + def filter_by_rule_type( + self, + detections: List[Dict], + roi_configs: List[Dict], + rule_type: str, + ) -> List[Dict]: + rule_rois = [ + roi for roi in roi_configs + if roi.get("rule") == rule_type and roi.get("enabled", True) + ] + return self.filter_detections(detections, rule_rois) + + def update_cache(self, camera_id: int, rois: List[Dict]): + self.roi_cache[camera_id] = rois + + def get_cached_rois(self, camera_id: int) -> List[Dict]: + return self.roi_cache.get(camera_id, []) + + def clear_cache(self, camera_id: Optional[int] = None): + if camera_id is None: + self.roi_cache.clear() + elif camera_id in self.roi_cache: + del self.roi_cache[camera_id] + + def check_direction( + self, + prev_bbox: List[float], + curr_bbox: List[float], + roi_config: Dict, + ) -> bool: + if roi_config.get("direction") is None: + return True + + direction = roi_config["direction"] + prev_center = self.get_bbox_center(prev_bbox) + curr_center = self.get_bbox_center(curr_bbox) + + roi_points = self.parse_points(roi_config["points"]) + if len(roi_points) != 2: + return True + + line = self.create_line(roi_points) + line_start = roi_points[0] + line_end = roi_points[1] + + dx_line = line_end[0] - line_start[0] + dy_line = line_end[1] - line_start[1] + + dx_person = curr_center[0] - prev_center[0] + dy_person = curr_center[1] - prev_center[1] + + dot_product = dx_person * dx_line + dy_person * dy_line + + if direction == "A_to_B": + return dot_product > 0 + elif direction == "B_to_A": + return dot_product < 0 + else: + return True + + def get_detection_roi_info( + self, + bbox: List[float], + roi_configs: List[Dict], + ) -> List[Dict]: + matched = [] + for roi_config in roi_configs: + if not roi_config.get("enabled", True): + continue + if self.is_bbox_in_roi(bbox, roi_config): + matched.append(roi_config) + return matched diff --git a/inference/rules/algorithms.py b/inference/rules/algorithms.py new file mode 100644 index 0000000..cee4667 --- /dev/null +++ b/inference/rules/algorithms.py @@ -0,0 +1,303 @@ +import os +import sys +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from sort import Sort + + +class LeavePostAlgorithm: + def __init__( + self, + threshold_sec: int = 360, + confirm_sec: int = 30, + return_sec: int = 5, + working_hours: Optional[List[Dict]] = None, + ): + self.threshold_sec = threshold_sec + self.confirm_sec = confirm_sec + self.return_sec = return_sec + self.working_hours = working_hours or [] + + self.track_states: Dict[str, Dict[str, Any]] = {} + self.tracker = Sort(max_age=10, min_hits=2, iou_threshold=0.3) + + self.alert_cooldowns: Dict[str, datetime] = {} + self.cooldown_seconds = 300 + + def is_in_working_hours(self, dt: Optional[datetime] = None) -> bool: + if not self.working_hours: + return True + + dt = dt or datetime.now() + current_minutes = dt.hour * 60 + dt.minute + + for period in self.working_hours: + start_minutes = period["start"][0] * 60 + period["start"][1] + end_minutes = period["end"][0] * 60 + period["end"][1] + if start_minutes <= current_minutes < end_minutes: + return True + + return False + + def process( + self, + camera_id: str, + tracks: List[Dict], + current_time: Optional[datetime] = None, + ) -> List[Dict]: + if not self.is_in_working_hours(current_time): + return [] + + if not tracks: + return [] + + detections = [] + for track in tracks: + bbox = track.get("bbox", []) + if len(bbox) >= 4: + detections.append(bbox + [track.get("conf", 0.0)]) + + if not detections: + return [] + + detections = np.array(detections) + tracked = self.tracker.update(detections) + + alerts = [] + current_time = current_time or datetime.now() + + for track_data in tracked: + x1, y1, x2, y2, track_id = track_data + track_id = str(int(track_id)) + + if track_id not in self.track_states: + self.track_states[track_id] = { + "first_seen": current_time, + "last_seen": current_time, + "off_duty_start": None, + "alerted": False, + "last_position": (x1, y1, x2, y2), + } + + state = self.track_states[track_id] + state["last_seen"] = current_time + state["last_position"] = (x1, y1, x2, y2) + + if state["off_duty_start"] is None: + off_duty_duration = (current_time - state["first_seen"]).total_seconds() + if off_duty_duration > self.confirm_sec: + state["off_duty_start"] = current_time + else: + elapsed = (current_time - state["off_duty_start"]).total_seconds() + if elapsed > self.threshold_sec: + if not state["alerted"]: + cooldown_key = f"{camera_id}_{track_id}" + now = datetime.now() + if cooldown_key not in self.alert_cooldowns or ( + now - self.alert_cooldowns[cooldown_key] + ).total_seconds() > self.cooldown_seconds: + alerts.append({ + "track_id": track_id, + "bbox": [x1, y1, x2, y2], + "off_duty_duration": elapsed, + "alert_type": "leave_post", + "message": f"离岗超过 {int(elapsed / 60)} 分钟", + }) + state["alerted"] = True + self.alert_cooldowns[cooldown_key] = now + else: + if elapsed < self.return_sec: + state["off_duty_start"] = None + state["alerted"] = False + + cleanup_time = current_time - timedelta(minutes=5) + for track_id, state in list(self.track_states.items()): + if state["last_seen"] < cleanup_time: + del self.track_states[track_id] + + return alerts + + def reset(self): + self.track_states.clear() + self.alert_cooldowns.clear() + + def get_state(self, track_id: str) -> Optional[Dict[str, Any]]: + return self.track_states.get(track_id) + + +class IntrusionAlgorithm: + def __init__( + self, + check_interval_sec: float = 1.0, + direction_sensitive: bool = False, + ): + self.check_interval_sec = check_interval_sec + self.direction_sensitive = direction_sensitive + + self.last_check_times: Dict[str, float] = {} + self.tracker = Sort(max_age=5, min_hits=1, iou_threshold=0.3) + + self.alert_cooldowns: Dict[str, datetime] = {} + self.cooldown_seconds = 300 + + def process( + self, + camera_id: str, + tracks: List[Dict], + current_time: Optional[datetime] = None, + ) -> List[Dict]: + if not tracks: + return [] + + detections = [] + for track in tracks: + bbox = track.get("bbox", []) + if len(bbox) >= 4: + detections.append(bbox + [track.get("conf", 0.0)]) + + if not detections: + return [] + + current_ts = current_time.timestamp() if current_time else datetime.now().timestamp() + + if camera_id in self.last_check_times: + if current_ts - self.last_check_times[camera_id] < self.check_interval_sec: + return [] + self.last_check_times[camera_id] = current_ts + + detections = np.array(detections) + tracked = self.tracker.update(detections) + + alerts = [] + now = datetime.now() + + for track_data in tracked: + x1, y1, x2, y2, track_id = track_data + cooldown_key = f"{camera_id}_{int(track_id)}" + + if cooldown_key not in self.alert_cooldowns or ( + now - self.alert_cooldowns[cooldown_key] + ).total_seconds() > self.cooldown_seconds: + alerts.append({ + "track_id": str(int(track_id)), + "bbox": [x1, y1, x2, y2], + "alert_type": "intrusion", + "confidence": track_data[4] if len(track_data) > 4 else 0.0, + "message": "检测到周界入侵", + }) + self.alert_cooldowns[cooldown_key] = now + + return alerts + + def reset(self): + self.last_check_times.clear() + self.alert_cooldowns.clear() + + +class AlgorithmManager: + def __init__(self, working_hours: Optional[List[Dict]] = None): + self.algorithms: Dict[str, Dict[str, Any]] = {} + self.working_hours = working_hours or [] + + self.default_params = { + "leave_post": { + "threshold_sec": 360, + "confirm_sec": 30, + "return_sec": 5, + }, + "intrusion": { + "check_interval_sec": 1.0, + "direction_sensitive": False, + }, + } + + def register_algorithm( + self, + roi_id: str, + algorithm_type: str, + params: Optional[Dict[str, Any]] = None, + ): + if roi_id in self.algorithms: + if algorithm_type in self.algorithms[roi_id]: + return + + if roi_id not in self.algorithms: + self.algorithms[roi_id] = {} + + algo_params = self.default_params.get(algorithm_type, {}) + if params: + algo_params.update(params) + + if algorithm_type == "leave_post": + self.algorithms[roi_id]["leave_post"] = LeavePostAlgorithm( + threshold_sec=algo_params.get("threshold_sec", 360), + confirm_sec=algo_params.get("confirm_sec", 30), + return_sec=algo_params.get("return_sec", 5), + working_hours=self.working_hours, + ) + elif algorithm_type == "intrusion": + self.algorithms[roi_id]["intrusion"] = IntrusionAlgorithm( + check_interval_sec=algo_params.get("check_interval_sec", 1.0), + direction_sensitive=algo_params.get("direction_sensitive", False), + ) + + def process( + self, + roi_id: str, + camera_id: str, + algorithm_type: str, + tracks: List[Dict], + current_time: Optional[datetime] = None, + ) -> List[Dict]: + algo = self.algorithms.get(roi_id, {}).get(algorithm_type) + if algo is None: + return [] + return algo.process(camera_id, tracks, current_time) + + def update_roi_params( + self, + roi_id: str, + algorithm_type: str, + params: Dict[str, Any], + ): + if roi_id in self.algorithms and algorithm_type in self.algorithms[roi_id]: + algo = self.algorithms[roi_id][algorithm_type] + for key, value in params.items(): + if hasattr(algo, key): + setattr(algo, key, value) + + def reset_algorithm(self, roi_id: str, algorithm_type: Optional[str] = None): + if roi_id not in self.algorithms: + return + + if algorithm_type: + if algorithm_type in self.algorithms[roi_id]: + self.algorithms[roi_id][algorithm_type].reset() + else: + for algo in self.algorithms[roi_id].values(): + algo.reset() + + def reset_all(self): + for roi_algorithms in self.algorithms.values(): + for algo in roi_algorithms.values(): + algo.reset() + + def remove_roi(self, roi_id: str): + if roi_id in self.algorithms: + self.reset_algorithm(roi_id) + del self.algorithms[roi_id] + + def get_status(self, roi_id: str) -> Dict[str, Any]: + status = {} + if roi_id in self.algorithms: + for algo_type, algo in self.algorithms[roi_id].items(): + status[algo_type] = { + "track_count": len(getattr(algo, "track_states", {})), + } + return status diff --git a/inference/stream.py b/inference/stream.py new file mode 100644 index 0000000..3844ad8 --- /dev/null +++ b/inference/stream.py @@ -0,0 +1,192 @@ +import asyncio +import threading +import time +from collections import deque +from typing import Any, Callable, Dict, List, Optional, Tuple + +import cv2 +import numpy as np + + +class StreamReader: + def __init__( + self, + camera_id: str, + rtsp_url: str, + buffer_size: int = 2, + reconnect_delay: float = 3.0, + timeout: float = 10.0, + ): + self.camera_id = camera_id + self.rtsp_url = rtsp_url + self.buffer_size = buffer_size + self.reconnect_delay = reconnect_delay + self.timeout = timeout + + self.cap = None + self.running = False + self.thread: Optional[threading.Thread] = None + self.frame_buffer: deque = deque(maxlen=buffer_size) + self.lock = threading.Lock() + self.fps = 0.0 + self.last_frame_time = time.time() + self.frame_count = 0 + + def _reader_loop(self): + while self.running: + with self.lock: + if self.cap is None or not self.cap.isOpened(): + self._connect() + if self.cap is None: + time.sleep(self.reconnect_delay) + continue + + ret = False + frame = None + with self.lock: + if self.cap is not None and self.cap.isOpened(): + ret, frame = self.cap.read() + + if ret and frame is not None: + self.frame_buffer.append(frame) + self.frame_count += 1 + current_time = time.time() + + if current_time - self.last_frame_time >= 1.0: + self.fps = self.frame_count / (current_time - self.last_frame_time) + self.frame_count = 0 + self.last_frame_time = current_time + else: + time.sleep(0.1) + + def _connect(self): + if self.cap is not None: + try: + self.cap.release() + except Exception: + pass + self.cap = None + + try: + self.cap = cv2.VideoCapture(self.rtsp_url, cv2.CAP_FFMPEG) + if self.cap.isOpened(): + self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) + print(f"[{self.camera_id}] RTSP连接成功: {self.rtsp_url[:50]}...") + else: + print(f"[{self.camera_id}] 无法打开视频流") + self.cap = None + except Exception as e: + print(f"[{self.camera_id}] 连接失败: {e}") + self.cap = None + + def start(self): + if self.running: + return + + self.running = True + self.thread = threading.Thread(target=self._reader_loop, daemon=True) + self.thread.start() + + def stop(self): + if not self.running: + return + + self.running = False + + if self.thread is not None: + self.thread.join(timeout=5.0) + if self.thread.is_alive(): + print(f"[{self.camera_id}] 警告: 线程未能在5秒内结束") + + with self.lock: + if self.cap is not None: + try: + self.cap.release() + except Exception: + pass + self.cap = None + + self.frame_buffer.clear() + + def read(self) -> Tuple[bool, Optional[np.ndarray]]: + if len(self.frame_buffer) > 0: + return True, self.frame_buffer.popleft() + return False, None + + def get_frame(self) -> Optional[np.ndarray]: + if len(self.frame_buffer) > 0: + return self.frame_buffer[-1] + return None + + def is_connected(self) -> bool: + with self.lock: + return self.cap is not None and self.cap.isOpened() + + def get_info(self) -> Dict[str, Any]: + return { + "camera_id": self.camera_id, + "running": self.running, + "connected": self.is_connected(), + "fps": self.fps, + "buffer_size": len(self.frame_buffer), + "buffer_max": self.buffer_size, + } + + +class StreamManager: + def __init__(self, buffer_size: int = 2, reconnect_delay: float = 3.0): + self.streams: Dict[str, StreamReader] = {} + self.buffer_size = buffer_size + self.reconnect_delay = reconnect_delay + + def add_stream( + self, + camera_id: str, + rtsp_url: str, + buffer_size: Optional[int] = None, + ) -> StreamReader: + if camera_id in self.streams: + self.remove_stream(camera_id) + + stream = StreamReader( + camera_id=camera_id, + rtsp_url=rtsp_url, + buffer_size=buffer_size or self.buffer_size, + reconnect_delay=self.reconnect_delay, + ) + stream.start() + self.streams[camera_id] = stream + return stream + + def remove_stream(self, camera_id: str): + if camera_id in self.streams: + self.streams[camera_id].stop() + del self.streams[camera_id] + + def get_stream(self, camera_id: str) -> Optional[StreamReader]: + return self.streams.get(camera_id) + + def read(self, camera_id: str) -> Tuple[bool, Optional[np.ndarray]]: + stream = self.get_stream(camera_id) + if stream is None: + return False, None + return stream.read() + + def get_frame(self, camera_id: str) -> Optional[np.ndarray]: + stream = self.get_stream(camera_id) + if stream is None: + return None + return stream.get_frame() + + def stop_all(self): + for camera_id in list(self.streams.keys()): + self.remove_stream(camera_id) + + def get_all_info(self) -> List[Dict[str, Any]]: + return [stream.get_info() for stream in self.streams.values()] + + def __contains__(self, camera_id: str) -> bool: + return camera_id in self.streams + + def __len__(self) -> int: + return len(self.streams) diff --git a/logs/app.log b/logs/app.log new file mode 100644 index 0000000..5e3e77c --- /dev/null +++ b/logs/app.log @@ -0,0 +1,62 @@ +2026-01-20 16:14:08,029 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 16:14:08,042 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 16:35:00,666 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 16:35:00,681 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 16:35:35,087 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 16:35:35,099 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 16:36:43,838 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 16:36:43,854 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 16:36:44,074 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 16:40:28,288 - security_monitor - INFO - 正在关闭系统... +2026-01-20 16:40:36,045 - security_monitor - INFO - 系统已关闭 +2026-01-20 16:50:54,485 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 16:50:54,499 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 16:51:08,066 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 16:56:03,282 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 16:56:03,295 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 16:56:15,787 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 16:57:32,966 - security_monitor - INFO - 正在关闭系统... +2026-01-20 16:57:38,237 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:01:03,511 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:01:03,525 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:01:16,086 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:02:31,685 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:02:38,875 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:04:14,608 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:04:14,624 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:04:27,113 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:08:05,883 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:08:13,257 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:09:38,432 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:09:38,445 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:09:50,944 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:10:54,647 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:11:01,985 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:11:15,714 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:11:15,732 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:11:28,265 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:14:17,481 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:14:22,406 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:15:56,666 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:15:56,680 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:16:09,183 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:16:36,221 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:16:36,235 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:16:48,752 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:16:53,218 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:17:08,230 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:17:13,977 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:17:13,991 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:17:26,500 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:18:25,049 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:18:31,834 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:18:37,609 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:18:37,623 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:18:50,135 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:19:25,216 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:19:25,384 - security_monitor - INFO - 系统已关闭 +2026-01-20 17:29:49,295 - security_monitor - INFO - 启动安保异常行为识别系统 +2026-01-20 17:29:49,307 - security_monitor - INFO - 数据库初始化完成 +2026-01-20 17:30:01,820 - security_monitor - INFO - 推理Pipeline启动,活跃摄像头数: 2 +2026-01-20 17:31:10,482 - security_monitor - INFO - 正在关闭系统... +2026-01-20 17:31:10,612 - security_monitor - INFO - 系统已关闭 diff --git a/main.py b/main.py index eb389a0..a1d393f 100644 --- a/main.py +++ b/main.py @@ -1,16 +1,211 @@ -# 这是一个示例 Python 脚本。 +import asyncio +import base64 +import os +import sys +import threading +from contextlib import asynccontextmanager +from datetime import datetime +from typing import Optional -# 按 Shift+F10 执行或将其替换为您的代码。 -# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。 +import cv2 +import numpy as np +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import FileResponse, StreamingResponse +from prometheus_client import start_http_server + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from api.alarm import router as alarm_router +from api.camera import router as camera_router +from api.roi import router as roi_router +from config import get_config, load_config +from db.models import init_db +from inference.pipeline import get_pipeline, start_pipeline, stop_pipeline +from utils.logger import setup_logger +from utils.metrics import get_metrics_server, start_metrics_server, update_system_info -def print_hi(name): - # 在下面的代码行中使用断点来调试脚本。 - print(f'Hi, {name}') # 按 Ctrl+F8 切换断点。 +logger = None -# 按装订区域中的绿色按钮以运行脚本。 -if __name__ == '__main__': - print_hi('PyCharm') +@asynccontextmanager +async def lifespan(app: FastAPI): + global logger -# 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助 + config = get_config() + logger = setup_logger( + name="security_monitor", + log_file=config.logging.file, + level=config.logging.level, + max_size=config.logging.max_size, + backup_count=config.logging.backup_count, + ) + + logger.info("启动安保异常行为识别系统") + + init_db() + logger.info("数据库初始化完成") + + start_metrics_server() + update_system_info() + + pipeline = start_pipeline() + logger.info(f"推理Pipeline启动,活跃摄像头数: {len(pipeline.camera_threads)}") + + try: + yield + except asyncio.CancelledError: + pass + finally: + logger.info("正在关闭系统...") + stop_pipeline() + logger.info("系统已关闭") + + +app = FastAPI( + title="安保异常行为识别系统", + description="基于YOLO和规则算法的智能监控系统", + version="1.0.0", + lifespan=lifespan, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(camera_router) +app.include_router(roi_router) +app.include_router(alarm_router) + + +@app.get("/") +async def root(): + return { + "name": "安保异常行为识别系统", + "version": "1.0.0", + "status": "running", + } + + +@app.get("/health") +async def health_check(): + pipeline = get_pipeline() + return { + "status": "healthy", + "running": pipeline.running, + "cameras": len(pipeline.camera_threads), + "timestamp": datetime.now().isoformat(), + } + + +@app.get("/api/camera/{camera_id}/snapshot") +async def get_snapshot(camera_id: int): + pipeline = get_pipeline() + frame = pipeline.get_latest_frame(camera_id) + + if frame is None: + raise HTTPException(status_code=404, detail="无法获取帧") + + _, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) + return StreamingResponse( + iter([buffer.tobytes()]), + media_type="image/jpeg", + ) + + +@app.get("/api/camera/{camera_id}/snapshot/base64") +async def get_snapshot_base64(camera_id: int): + pipeline = get_pipeline() + frame = pipeline.get_latest_frame(camera_id) + + if frame is None: + raise HTTPException(status_code=404, detail="无法获取帧") + + _, buffer = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) + return {"image": base64.b64encode(buffer).decode("utf-8")} + + +@app.get("/api/camera/{camera_id}/detect") +async def get_detection_with_overlay(camera_id: int): + pipeline = get_pipeline() + frame = pipeline.get_latest_frame(camera_id) + + if frame is None: + raise HTTPException(status_code=404, detail="无法获取帧") + + import json + from db.crud import get_all_rois + from db.models import get_session_factory + + SessionLocal = get_session_factory() + db = SessionLocal() + try: + rois = get_all_rois(db, camera_id) + roi_configs = [ + { + "id": roi.id, + "roi_id": roi.roi_id, + "name": roi.name, + "type": roi.roi_type, + "points": json.loads(roi.points), + "rule": roi.rule_type, + "enabled": roi.enabled, + } + for roi in rois + ] + finally: + db.close() + + from utils.helpers import draw_detections + + overlay_frame = draw_detections(frame, [], roi_configs if roi_configs else None) + + _, buffer = cv2.imencode(".jpg", overlay_frame, [cv2.IMWRITE_JPEG_QUALITY, 85]) + return StreamingResponse( + iter([buffer.tobytes()]), + media_type="image/jpeg", + ) + + +@app.get("/api/pipeline/status") +async def get_pipeline_status(): + pipeline = get_pipeline() + return pipeline.get_status() + + +@app.get("/api/stream/list") +async def list_streams(): + pipeline = get_pipeline() + return {"streams": pipeline.stream_manager.get_all_info()} + + +@app.post("/api/pipeline/reload") +async def reload_pipeline(): + stop_pipeline() + import time + time.sleep(1) + start_pipeline() + return {"message": "Pipeline已重新加载"} + + +def main(): + import uvicorn + + config = load_config() + + uvicorn.run( + "main:app", + host="0.0.0.0", + port=8000, + reload=False, + log_level=config.logging.level.lower(), + ) + + +if __name__ == "__main__": + main() diff --git a/models/yolo11n_fp16_480.engine b/models/yolo11n_fp16_480.engine new file mode 100644 index 0000000..6c9a140 Binary files /dev/null and b/models/yolo11n_fp16_480.engine differ diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..aff1a7c --- /dev/null +++ b/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "Detector", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/prometheus.yml b/prometheus.yml new file mode 100644 index 0000000..10e932b --- /dev/null +++ b/prometheus.yml @@ -0,0 +1,13 @@ +global: + scrape_interval: 15s + evaluation_interval: 15s + +scrape_configs: + - job_name: 'prometheus' + static_configs: + - targets: ['localhost:9090'] + + - job_name: 'security-monitor' + static_configs: + - targets: ['localhost:9090'] + metrics_path: /metrics diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1bdf980 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,51 @@ +# 安保异常行为识别系统 - 依赖配置 + +# 核心框架 +fastapi==0.109.0 +uvicorn[standard]==0.27.0 +click>=8.0.0 +pydantic==2.5.3 +pydantic-settings==2.1.0 + +# 数据库 +sqlalchemy==2.0.25 +alembic==1.13.1 +pymysql==1.1.0 +cryptography==42.0.0 + +# 推理引擎 +torch>=2.0.0 +tensorrt>=10.0.0 +pycuda>=2023.1 +ultralytics>=8.1.0 +numpy>=1.24.0 + +# 视频处理 +opencv-python>=4.8.0 + +# 几何处理 +shapely>=2.0.0 + +# 配置管理 +pyyaml>=6.0.1 +python-dotenv>=1.0.0 + +# 监控与日志 +prometheus-client==0.19.0 +python-json-logger==2.0.7 + +# 异步支持 +asyncio>=3.4.3 + +# 工具库 +python-dateutil>=2.8.2 +pillow>=10.0.0 + +# 前端(React打包后静态文件) +# 静态文件服务用aiofiles +aiofiles>=23.2.1 + +# 开发与测试 +pytest==7.4.4 +pytest-asyncio==0.23.3 +httpx==0.26.0 diff --git a/run_detectors.py b/run_detectors.py deleted file mode 100644 index 3269d8e..0000000 --- a/run_detectors.py +++ /dev/null @@ -1,52 +0,0 @@ -import yaml -import threading -from ultralytics import YOLO -import torch -from detector import OffDutyCrowdDetector -import os - - -def load_config(config_path="config.yaml"): - with open(config_path, "r", encoding="utf-8") as f: - return yaml.safe_load(f) - - -def main(): - config = load_config() - - # 全局模型(共享) - model_path = config["model"]["path"] - device = config["model"].get("device", "cuda" if torch.cuda.is_available() else "cpu") - use_half = (device == "cuda") - - print(f"Loading model {model_path} on {device} (FP16: {use_half})") - model = YOLO(model_path) - model.to(device) - if use_half: - model.model.half() - - # 启动每个摄像头的检测线程 - threads = [] - for cam_cfg in config["cameras"]: - # 合并 common 配置 - full_cfg = {**config.get("common", {}), **cam_cfg} - full_cfg["imgsz"] = config["model"]["imgsz"] - full_cfg["conf_thresh"] = config["model"]["conf_thresh"] - full_cfg["working_hours"] = config["common"]["working_hours"] - - detector = OffDutyCrowdDetector(full_cfg, model, device, use_half) - thread = threading.Thread(target=detector.run, daemon=True) - thread.start() - threads.append(thread) - print(f"Started detector for {cam_cfg['id']}") - - print(f"✅ 已启动 {len(threads)} 路摄像头检测,按 Ctrl+C 退出") - try: - for t in threads: - t.join() - except KeyboardInterrupt: - print("\n🛑 Shutting down...") - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/scripts/build_engine.py b/scripts/build_engine.py new file mode 100644 index 0000000..047e1d7 --- /dev/null +++ b/scripts/build_engine.py @@ -0,0 +1,89 @@ +# TensorRT Engine 生成脚本 +# 使用方法: python scripts/build_engine.py + +import os +import sys +import argparse + +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) + +import torch +from ultralytics import YOLO + + +def build_engine(onnx_path, engine_path, fp16=True, dynamic_batch=True): + """构建TensorRT引擎""" + from tensorrt import Builder, NetworkDefinitionLayer, Runtime + from tensorrt.parsers import onnxparser + + logger = trt.Logger(trt.Logger.INFO) + builder = trt.Builder(logger) + + network_flags = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) + network = builder.create_network(network_flags) + + parser = onnxparser.create_onnx_parser(network) + parser.parse(onnx_path) + parser.report_status() + + # 动态形状配置 + if dynamic_batch: + profile = builder.create_optimization_profile() + min_shape = (1, 3, 480, 480) + opt_shape = (4, 3, 480, 480) + max_shape = (8, 3, 480, 480) + + profile.set_shape("input", min_shape, opt_shape, max_shape) + network.get_input(0).set_dynamic_range(-1.0, 1.0) + network.set_precision_constraints(trt.PrecisionConstraints.PREFER) + + config = builder.create_builder_config() + config.set_memory_allocator(trt.MemoryAllocator()) + config.max_workspace_size = 4 << 30 # 4GB + + if fp16: + config.set_flag(trt.BuilderFlag.FP16) + + serialized_engine = builder.build_serialized_network(network, config) + + with open(engine_path, "wb") as f: + f.write(serialized_engine) + + print(f"✅ TensorRT引擎已保存: {engine_path}") + + +def export_onnx(model_path, onnx_path, imgsz=480): + """导出ONNX模型""" + model = YOLO(model_path) + model.export( + format="onnx", + imgsz=[imgsz, imgsz], + simplify=True, + opset=12, + dynamic=True, + ) + print(f"✅ ONNX模型已导出: {onnx_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="TensorRT Engine Builder") + parser.add_argument("--model", type=str, default="models/yolo11n.pt", + help="YOLO模型路径") + parser.add_argument("--engine", type=str, default="models/yolo11n_fp16_480.engine", + help="输出引擎路径") + parser.add_argument("--onnx", type=str, default="models/yolo11n_480.onnx", + help="临时ONNX路径") + parser.add_argument("--fp16", action="store_true", default=True, + help="启用FP16") + parser.add_argument("--no-dynamic", action="store_true", + help="禁用动态Batch") + + args = parser.parse_args() + + os.makedirs(os.path.dirname(args.engine), exist_ok=True) + + if not os.path.exists(args.onnx): + export_onnx(args.model, args.onnx) + + build_engine(args.onnx, args.engine, args.fp16, not args.no_dynamic) diff --git a/security_monitor.db b/security_monitor.db new file mode 100644 index 0000000..e7aa754 Binary files /dev/null and b/security_monitor.db differ diff --git a/tests/test_core.py b/tests/test_core.py new file mode 100644 index 0000000..c673f2b --- /dev/null +++ b/tests/test_core.py @@ -0,0 +1,168 @@ +import pytest +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def test_roi_filter_parse_points(): + from inference.roi.roi_filter import ROIFilter + + filter = ROIFilter() + points = filter.parse_points("[[100, 200], [300, 400]]") + assert len(points) == 2 + assert points[0] == (100.0, 200.0) + + +def test_roi_filter_polygon(): + from inference.roi.roi_filter import ROIFilter + + filter = ROIFilter() + points = [(0, 0), (100, 0), (100, 100), (0, 100)] + polygon = filter.create_polygon(points) + + assert polygon.area == 10000 + assert filter.is_point_in_polygon((50, 50), polygon) == True + assert filter.is_point_in_polygon((150, 150), polygon) == False + + +def test_roi_filter_bbox_center(): + from inference.roi.roi_filter import ROIFilter + + filter = ROIFilter() + center = filter.get_bbox_center([10, 20, 100, 200]) + assert center == (55.0, 110.0) + + +def test_leave_post_algorithm_init(): + from inference.rules.algorithms import LeavePostAlgorithm + + algo = LeavePostAlgorithm( + threshold_sec=300, + confirm_sec=30, + return_sec=5, + ) + + assert algo.threshold_sec == 300 + assert algo.confirm_sec == 30 + assert algo.return_sec == 5 + + +def test_leave_post_algorithm_process(): + from inference.rules.algorithms import LeavePostAlgorithm + from datetime import datetime + + algo = LeavePostAlgorithm(threshold_sec=360, confirm_sec=30, return_sec=5) + + tracks = [ + {"bbox": [100, 100, 200, 200], "conf": 0.9, "cls": 0}, + ] + + alerts = algo.process("test_cam", tracks, datetime.now()) + assert isinstance(alerts, list) + + +def test_intrusion_algorithm_init(): + from inference.rules.algorithms import IntrusionAlgorithm + + algo = IntrusionAlgorithm( + check_interval_sec=1.0, + direction_sensitive=False, + ) + + assert algo.check_interval_sec == 1.0 + + +def test_algorithm_manager(): + from inference.rules.algorithms import AlgorithmManager + + manager = AlgorithmManager() + + manager.register_algorithm("roi_1", "leave_post", {"threshold_sec": 300}) + + assert "roi_1" in manager.algorithms + assert "leave_post" in manager.algorithms["roi_1"] + + +def test_config_load(): + from config import load_config, get_config + + config = load_config() + + assert config.database.dialect in ["sqlite", "mysql"] + assert config.model.imgsz == [480, 480] + assert config.model.batch_size == 8 + + +def test_database_models(): + from db.models import Camera, ROI, Alarm, init_db + + init_db() + + camera = Camera( + name="测试摄像头", + rtsp_url="rtsp://test.local/cam1", + fps_limit=30, + process_every_n_frames=3, + enabled=True, + ) + + assert camera.name == "测试摄像头" + assert camera.enabled == True + + +def test_camera_crud(): + from db.crud import create_camera, get_camera_by_id + from db.models import get_session_factory, init_db + + init_db() + SessionLocal = get_session_factory() + db = SessionLocal() + + try: + camera = create_camera( + db, + name="测试摄像头", + rtsp_url="rtsp://test.local/cam1", + fps_limit=30, + ) + + assert camera.id is not None + + fetched = get_camera_by_id(db, camera.id) + assert fetched is not None + assert fetched.name == "测试摄像头" + finally: + db.close() + + +def test_stream_reader_init(): + from inference.stream import StreamReader + + reader = StreamReader( + camera_id="test_cam", + rtsp_url="rtsp://test.local/stream", + buffer_size=2, + ) + + assert reader.camera_id == "test_cam" + assert reader.buffer_size == 2 + + +def test_utils_helpers(): + from utils.helpers import draw_bbox, draw_roi, format_duration + + import numpy as np + + image = np.zeros((480, 640, 3), dtype=np.uint8) + + result = draw_bbox(image, [100, 100, 200, 200], "Test") + + assert result.shape == image.shape + + duration = format_duration(125.5) + assert "2分" in duration + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/utils/helpers.py b/utils/helpers.py new file mode 100644 index 0000000..2150f31 --- /dev/null +++ b/utils/helpers.py @@ -0,0 +1,195 @@ +import os +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import numpy as np + + +def draw_bbox( + image: np.ndarray, + bbox: List[float], + label: str = "", + color: Tuple[int, int, int] = (0, 255, 0), + thickness: int = 2, +) -> np.ndarray: + x1, y1, x2, y2 = [int(v) for v in bbox] + cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness) + + if label: + (text_width, text_height), baseline = cv2.getTextSize( + label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, thickness + ) + cv2.rectangle( + image, + (x1, y1 - text_height - 10), + (x1 + text_width + 10, y1), + color, + -1, + ) + cv2.putText( + image, + label, + (x1 + 5, y1 - 5), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (0, 0, 0), + thickness, + ) + + return image + + +def draw_roi( + image: np.ndarray, + points: List[List[float]], + roi_type: str = "polygon", + color: Tuple[int, int, int] = (255, 0, 0), + thickness: int = 2, + label: str = "", +) -> np.ndarray: + points = np.array(points, dtype=np.int32) + + if roi_type == "polygon": + cv2.polylines(image, [points], True, color, thickness) + cv2.fillPoly(image, [points], color=(color[0], color[1], color[2], 30)) + elif roi_type == "line": + cv2.line(image, tuple(points[0]), tuple(points[1]), color, thickness) + elif roi_type == "rectangle": + x1, y1 = points[0] + x2, y2 = points[1] + cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness) + + if label: + cx = int(np.mean(points[:, 0])) + cy = int(np.mean(points[:, 1])) + cv2.putText( + image, + label, + (cx, cy), + cv2.FONT_HERSHEY_SIMPLEX, + 0.7, + color, + thickness, + ) + + return image + + +def draw_detections( + image: np.ndarray, + detections: List[Dict[str, Any]], + roi_configs: Optional[List[Dict[str, Any]]] = None, +) -> np.ndarray: + result = image.copy() + + for detection in detections: + bbox = detection.get("bbox", []) + conf = detection.get("conf", 0.0) + cls = detection.get("cls", 0) + + label = f"Person: {conf:.2f}" + color = (0, 255, 0) + + if roi_configs: + matched_rois = detection.get("matched_rois", []) + for roi_conf in matched_rois: + if roi_conf.get("enabled", True): + roi_points = roi_conf.get("points", []) + roi_type = roi_conf.get("type", "polygon") + roi_label = roi_conf.get("name", "") + roi_color = (255, 0, 0) if roi_conf.get("rule") == "intrusion" else (0, 165, 255) + result = draw_roi(result, roi_points, roi_type, roi_color, 2, roi_label) + + result = draw_bbox(result, bbox, label, color, 2) + + return result + + +def resize_image( + image: np.ndarray, + max_size: int = 480, + maintain_aspect: bool = True, +) -> np.ndarray: + h, w = image.shape[:2] + + if maintain_aspect: + scale = min(max_size / h, max_size / w) + new_h, new_w = int(h * scale), int(w * scale) + else: + new_h, new_w = max_size, max_size + scale = max_size / max(h, w) + + result = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA) + return result, scale + + +def encode_image_base64(image: np.ndarray, quality: int = 85) -> str: + _, buffer = cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, quality]) + return base64.b64encode(buffer).decode("utf-8") + + +def decode_image_base64(data: str) -> np.ndarray: + import base64 + buffer = base64.b64decode(data) + nparr = np.frombuffer(buffer, np.uint8) + return cv2.imdecode(nparr, cv2.IMREAD_COLOR) + + +def get_timestamp_str() -> str: + return datetime.now().strftime("%Y%m%d_%H%M%S") + + +def format_duration(seconds: float) -> str: + if seconds < 60: + return f"{int(seconds)}秒" + elif seconds < 3600: + minutes = int(seconds // 60) + secs = int(seconds % 60) + return f"{minutes}分{secs}秒" + else: + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + return f"{hours}小时{minutes}分" + + +def is_time_in_ranges( + current_time: datetime, + time_ranges: List[Dict[str, List[int]]], +) -> bool: + if not time_ranges: + return True + + current_minutes = current_time.hour * 60 + current_time.minute + + for time_range in time_ranges: + start = time_range["start"] + end = time_range["end"] + start_minutes = start[0] * 60 + start[1] + end_minutes = end[0] * 60 + end[1] + + if start_minutes <= current_minutes < end_minutes: + return True + + return False + + +class FPSCounter: + def __init__(self, window_size: int = 30): + self.timestamps: List[float] = [] + self.window_size = window_size + + def update(self): + import time + now = time.time() + self.timestamps.append(now) + self.timestamps = self.timestamps[-self.window_size:] + + @property + def fps(self) -> float: + if len(self.timestamps) < 2: + return 0.0 + elapsed = self.timestamps[-1] - self.timestamps[0] + if elapsed <= 0: + return 0.0 + return (len(self.timestamps) - 1) / elapsed diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..5ca7c8b --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,114 @@ +import logging +import os +import sys +from datetime import datetime +from logging.handlers import RotatingFileHandler +from typing import Optional + +from pythonjsonlogger import jsonlogger + + +def setup_logger( + name: str = "security_monitor", + log_file: str = "logs/app.log", + level: str = "INFO", + format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + max_size: str = "100MB", + backup_count: int = 5, + json_format: bool = False, +) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(getattr(logging, level.upper(), logging.INFO)) + + if logger.handlers: + for handler in logger.handlers: + logger.removeHandler(handler) + + os.makedirs(os.path.dirname(log_file), exist_ok=True) + + if json_format: + handler = RotatingFileHandler( + log_file, + maxBytes=int(max_size.replace("MB", "")) * 1024 * 1024, + backupCount=backup_count, + encoding="utf-8", + ) + formatter = jsonlogger.JsonFormatter( + "%(asctime)s %(name)s %(levelname)s %(message)s", + rename_fields={"levelname": "severity", "asctime": "timestamp"}, + ) + handler.setFormatter(formatter) + else: + handler = RotatingFileHandler( + log_file, + maxBytes=int(max_size.replace("MB", "")) * 1024 * 1024, + backupCount=backup_count, + encoding="utf-8", + ) + formatter = logging.Formatter(format) + handler.setFormatter(formatter) + + logger.addHandler(handler) + + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setFormatter(logging.Formatter(format)) + logger.addHandler(console_handler) + + return logger + + +def get_logger(name: str = "security_monitor") -> logging.Logger: + return logging.getLogger(name) + + +class LoggerMixin: + @property + def logger(self) -> logging.Logger: + return get_logger(self.__class__.__name__) + + +def log_execution_time(logger: Optional[logging.Logger] = None): + def decorator(func): + import time + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + log = logger or get_logger(func.__module__) + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + log.info( + f"函数执行完成", + extra={ + "function": func.__name__, + "execution_time_ms": int((end_time - start_time) * 1000), + }, + ) + return result + + return wrapper + + return decorator + + +def log_function_call(logger: Optional[logging.Logger] = None): + def decorator(func): + from functools import wraps + + @wraps(func) + def wrapper(*args, **kwargs): + log = logger or get_logger(func.__module__) + log.info( + f"函数调用", + extra={ + "function": func.__name__, + "args": str(args), + "kwargs": str(kwargs), + }, + ) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000..a3f7d4d --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,113 @@ +from typing import Optional + +from prometheus_client import Counter, Gauge, Histogram, Info, start_http_server + +from config import get_config + +SYSTEM_INFO = Info("system", "System information") + +CAMERA_COUNT = Gauge("camera_count", "Number of active cameras") + +CAMERA_FPS = Gauge("camera_fps", "Camera FPS", ["camera_id"]) + +INFERENCE_LATENCY = Histogram( + "inference_latency_seconds", + "Inference latency in seconds", + ["camera_id"], + buckets=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0], +) + +ALERT_COUNT = Counter( + "alert_total", + "Total number of alerts", + ["camera_id", "event_type"], +) + +EVENT_QUEUE_SIZE = Gauge( + "event_queue_size", + "Current size of event queue", +) + +DETECTION_COUNT = Counter( + "detection_total", + "Total number of detections", + ["camera_id", "roi_id"], +) + +GPU_MEMORY_USED = Gauge( + "gpu_memory_used_bytes", + "GPU memory used", + ["device"], +) + +GPU_UTILIZATION = Gauge( + "gpu_utilization_percent", + "GPU utilization", + ["device"], +) + + +class MetricsServer: + def __init__(self, port: int = 9090): + self.port = port + self.started = False + + def start(self): + if self.started: + return + + config = get_config() + if not config.monitoring.enabled: + return + + start_http_server(self.port) + self.started = True + print(f"Prometheus metrics server started on port {self.port}") + + def update_camera_metrics(self, camera_id: int, fps: float): + CAMERA_FPS.labels(camera_id=str(camera_id)).set(fps) + + def record_inference(self, camera_id: int, latency: float): + INFERENCE_LATENCY.labels(camera_id=str(camera_id)).observe(latency) + + def record_alert(self, camera_id: int, event_type: str): + ALERT_COUNT.labels(camera_id=str(camera_id), event_type=event_type).inc() + + def update_event_queue(self, size: int): + EVENT_QUEUE_SIZE.set(size) + + def record_detection(self, camera_id: int, roi_id: str): + DETECTION_COUNT.labels(camera_id=str(camera_id), roi_id=roi_id).inc() + + def update_gpu_metrics(self, device: int, memory_bytes: float, utilization: float): + GPU_MEMORY_USED.labels(device=str(device)).set(memory_bytes) + GPU_UTILIZATION.labels(device=str(device)).set(utilization) + + +_metrics_server: Optional[MetricsServer] = None + + +def get_metrics_server() -> MetricsServer: + global _metrics_server + if _metrics_server is None: + config = get_config() + _metrics_server = MetricsServer(port=config.monitoring.port) + return _metrics_server + + +def start_metrics_server(): + server = get_metrics_server() + server.start() + + +def update_system_info(): + import platform + import psutil + + SYSTEM_INFO.info({ + "os": platform.system(), + "os_version": platform.version(), + "python_version": platform.python_version(), + "cpu_count": str(psutil.cpu_count()), + "memory_total_gb": str(round(psutil.virtual_memory().total / (1024**3), 2)), + }) diff --git a/yolo11n.onnx b/yolo11n.onnx new file mode 100644 index 0000000..004f768 Binary files /dev/null and b/yolo11n.onnx differ