diff --git a/config/__pycache__/settings.cpython-310.pyc b/config/__pycache__/settings.cpython-310.pyc index 72f2a9f..b3f9afd 100644 Binary files a/config/__pycache__/settings.cpython-310.pyc and b/config/__pycache__/settings.cpython-310.pyc differ diff --git a/config/database.py b/config/database.py index 8075ebc..52503cd 100644 --- a/config/database.py +++ b/config/database.py @@ -1,316 +1,375 @@ """ -数据库连接配置模块 -提供MySQL数据库连接池管理和操作封装 +SQLite 数据库模块 +边缘AI推理服务的本地数据存储 + +特性: +- WAL 模式(Write-Ahead Logging)提升写入性能 +- 异步写入策略 +- 滚动清理机制(保留7天数据) """ +import os +import sqlite3 +import threading +import queue +import time import logging -from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional - -from sqlalchemy import create_engine, Column, String, Boolean, Integer, Float, Text, JSON, DateTime -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, Session -from sqlalchemy.pool import QueuePool - -from config.settings import get_settings, DatabaseConfig +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Generator +from dataclasses import dataclass, field +from pathlib import Path logger = logging.getLogger(__name__) -Base = declarative_base() + +@dataclass +class StorageConfig: + """存储配置类""" + db_path: str = "./data/security_events.db" + image_dir: str = "./data/captures" + retention_days: int = 7 + wal_mode: bool = True + batch_size: int = 100 + flush_interval: float = 5.0 -class CameraInfo(Base): - """摄像头信息表模型""" - __tablename__ = "camera_info" - - id = Column(Integer, primary_key=True, autoincrement=True) - camera_id = Column(String(64), unique=True, nullable=False, index=True) - camera_name = Column(String(128), nullable=True) - rtsp_url = Column(String(512), nullable=False) - status = Column(Boolean, default=True) - enabled = Column(Boolean, default=True) - location = Column(String(256), nullable=True) - extra_params = Column(JSON, nullable=True) - created_at = Column(DateTime, nullable=True) - updated_at = Column(DateTime, nullable=True) +@dataclass +class AlertRecord: + """告警记录""" + alert_id: str + camera_id: str + roi_id: str + alert_type: str + target_class: Optional[str] = None + confidence: Optional[float] = None + bbox: Optional[List[float]] = None + message: Optional[str] = None + image_path: Optional[str] = None + status: str = "pending" + created_at: datetime = field(default_factory=datetime.now) + processed_at: Optional[datetime] = None -class ROIConfigModel(Base): - """ROI配置表模型""" - __tablename__ = "roi_config" - - id = Column(Integer, primary_key=True, autoincrement=True) - roi_id = Column(String(64), unique=True, nullable=False, index=True) - camera_id = Column(String(64), nullable=False, index=True) - roi_type = Column(String(32), nullable=False) # 'polygon' or 'rectangle' - coordinates = Column(JSON, nullable=False) # 多边形顶点或矩形坐标 - algorithm_type = Column(String(32), nullable=False) # 'leave_post', 'intrusion', etc. - alert_threshold = Column(Integer, default=3) - alert_cooldown = Column(Integer, default=300) - enabled = Column(Boolean, default=True) - extra_params = Column(JSON, nullable=True) - created_at = Column(DateTime, nullable=True) - updated_at = Column(DateTime, nullable=True) - - -class AlertRecord(Base): - """告警记录表模型""" - __tablename__ = "alert_records" - - id = Column(Integer, primary_key=True, autoincrement=True) - alert_id = Column(String(64), unique=True, nullable=False, index=True) - camera_id = Column(String(64), nullable=False, index=True) - roi_id = Column(String(64), nullable=False, index=True) - alert_type = Column(String(32), nullable=False) - target_class = Column(String(64), nullable=True) - confidence = Column(Float, nullable=True) - bbox = Column(JSON, nullable=True) - message = Column(Text, nullable=True) - screenshot = Column(Text, nullable=True) # Base64编码的截图 - status = Column(String(32), default="pending") - created_at = Column(DateTime, nullable=True) - processed_at = Column(DateTime, nullable=True) - - -class DatabaseManager: - """数据库连接管理器类""" +class SQLiteManager: + """SQLite 数据库管理器""" _instance = None - _engine = None - _session_factory = None - _available = False + _lock = threading.Lock() - def __new__(cls): + def __new__(cls, config: Optional[StorageConfig] = None): if cls._instance is None: - cls._instance = super().__new__(cls) + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False return cls._instance - def __init__(self, config: Optional[DatabaseConfig] = None): - if self._engine is not None: + def __init__(self, config: Optional[StorageConfig] = None): + if self._initialized: return if config is None: - settings = get_settings() - config = settings.database + config = StorageConfig() - self._init_engine(config) + self.config = config + self._conn: Optional[sqlite3.Connection] = None + self._write_queue: queue.Queue = queue.Queue() + self._running = False + self._write_thread: Optional[threading.Thread] = None + self._cleanup_thread: Optional[threading.Thread] = None + + self._init_directories() + self._init_database() + self._start_background_threads() + + self._initialized = True + logger.info(f"SQLite 数据库初始化成功: {config.db_path}") - def _init_engine(self, config: DatabaseConfig): - """初始化数据库引擎""" - try: - connection_string = ( - f"mysql+pymysql://{config.username}:{config.password}" - f"@{config.host}:{config.port}/{config.database}" - f"?charset=utf8mb4" + def _init_directories(self): + """初始化目录""" + Path(self.config.db_path).parent.mkdir(parents=True, exist_ok=True) + Path(self.config.image_dir).mkdir(parents=True, exist_ok=True) + + def _init_database(self): + """初始化数据库表""" + self._conn = sqlite3.connect( + self.config.db_path, + check_same_thread=False, + timeout=30.0 + ) + + if self.config.wal_mode: + cursor = self._conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL;") + cursor.execute("PRAGMA synchronous=NORMAL;") + cursor.execute("PRAGMA cache_size=-64000;") + self._conn.commit() + + cursor = self._conn.cursor() + + cursor.execute(""" + CREATE TABLE IF NOT EXISTS alert_records ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + alert_id TEXT UNIQUE NOT NULL, + camera_id TEXT NOT NULL, + roi_id TEXT NOT NULL, + alert_type TEXT NOT NULL, + target_class TEXT, + confidence REAL, + bbox TEXT, + message TEXT, + image_path TEXT, + status TEXT DEFAULT 'pending', + created_at TEXT NOT NULL, + processed_at TEXT ) - - self._engine = create_engine( - connection_string, - poolclass=QueuePool, - pool_size=config.pool_size, - pool_recycle=config.pool_recycle, - echo=config.echo, - pool_pre_ping=True, - max_overflow=5, - ) - - self._session_factory = sessionmaker(bind=self._engine) - - test_connection = self._engine.connect() - test_connection.close() - - self._available = True - logger.info(f"数据库引擎初始化成功: {config.host}:{config.port}/{config.database}") - - except Exception as e: - self._available = False - logger.warning(f"数据库连接失败,服务将在无数据库模式下运行: {e}") - - @property - def is_available(self) -> bool: - """检查数据库是否可用""" - return self._available - - @contextmanager - def get_session(self) -> Generator[Session, None, None]: - """获取数据库会话上下文""" - if not self._available: - logger.warning("数据库不可用,跳过数据库操作") - yield None - return + """) - session = self._session_factory() - try: - yield session - session.commit() - except Exception as e: - session.rollback() - logger.error(f"数据库操作异常: {e}") - raise - finally: - session.close() - - def get_camera_info(self, camera_id: Optional[str] = None) -> List[Dict[str, Any]]: - """获取摄像头信息""" - if not self._available: - logger.warning("数据库不可用,返回空摄像头列表") - return [] + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_alert_camera + ON alert_records(camera_id) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_alert_created + ON alert_records(created_at) + """) + cursor.execute(""" + CREATE INDEX IF NOT EXISTS idx_alert_status + ON alert_records(status) + """) - with self.get_session() as session: - if session is None: - return [] - - query = session.query(CameraInfo) - - if camera_id: - query = query.filter(CameraInfo.camera_id == camera_id) - - cameras = query.filter(CameraInfo.enabled == True).all() - - result = [] - for camera in cameras: - result.append({ - "camera_id": camera.camera_id, - "camera_name": camera.camera_name, - "rtsp_url": camera.rtsp_url, - "status": camera.status, - "location": camera.location, - "extra_params": camera.extra_params, - }) - - return result + self._conn.commit() - def get_roi_configs(self, camera_id: Optional[str] = None, - roi_id: Optional[str] = None) -> List[Dict[str, Any]]: - """获取ROI配置""" - if not self._available: - logger.warning("数据库不可用,返回空ROI配置列表") - return [] + def _start_background_threads(self): + """启动后台线程""" + self._running = True - with self.get_session() as session: - if session is None: - return [] - - query = session.query(ROIConfigModel) - - if camera_id: - query = query.filter(ROIConfigModel.camera_id == camera_id) - - if roi_id: - query = query.filter(ROIConfigModel.roi_id == roi_id) - - query = query.filter(ROIConfigModel.enabled == True) - - rois = query.all() - - result = [] - for roi in rois: - result.append({ - "roi_id": roi.roi_id, - "camera_id": roi.camera_id, - "roi_type": roi.roi_type, - "coordinates": roi.coordinates, - "algorithm_type": roi.algorithm_type, - "alert_threshold": roi.alert_threshold, - "alert_cooldown": roi.alert_cooldown, - "extra_params": roi.extra_params, - }) - - return result + self._write_thread = threading.Thread( + target=self._write_worker, + name="SQLiteWrite", + daemon=True + ) + self._write_thread.start() + + self._cleanup_thread = threading.Thread( + target=self._cleanup_worker, + name="SQLiteCleanup", + daemon=True + ) + self._cleanup_thread.start() - def save_alert_record(self, alert_data: Dict[str, Any]) -> bool: - """保存告警记录""" - if not self._available: - logger.warning("数据库不可用,跳过保存告警记录") - return False + def _write_worker(self): + """异步写入工作线程""" + batch = [] + last_flush = time.time() - try: - with self.get_session() as session: - if session is None: - return False - alert = AlertRecord( - alert_id=alert_data.get("alert_id"), - camera_id=alert_data.get("camera_id"), - roi_id=alert_data.get("roi_id"), - alert_type=alert_data.get("alert_type"), - target_class=alert_data.get("target_class"), - confidence=alert_data.get("confidence"), - bbox=alert_data.get("bbox"), - message=alert_data.get("message"), - screenshot=alert_data.get("screenshot"), - status=alert_data.get("status", "pending"), + while self._running: + try: + try: + item = self._write_queue.get(timeout=1.0) + batch.append(item) + except queue.Empty: + pass + + should_flush = ( + len(batch) >= self.config.batch_size or + time.time() - last_flush >= self.config.flush_interval ) - session.add(alert) - session.flush() - - logger.info(f"告警记录保存成功: {alert_data.get('alert_id')}") - return True - - except Exception as e: - logger.error(f"保存告警记录失败: {e}") - return False + if batch and (should_flush or len(batch) >= 1000): + self._flush_batch(batch) + batch.clear() + last_flush = time.time() + + except Exception as e: + logger.error(f"SQLite 写入异常: {e}") + + if batch: + self._flush_batch(batch) - def update_alert_status(self, alert_id: str, status: str) -> bool: + def _flush_batch(self, batch: List[Dict[str, Any]]): + """批量写入数据库""" + try: + cursor = self._conn.cursor() + + for record in batch: + cursor.execute(""" + INSERT OR REPLACE INTO alert_records ( + alert_id, camera_id, roi_id, alert_type, + target_class, confidence, bbox, message, + image_path, status, created_at, processed_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + record['alert_id'], + record['camera_id'], + record['roi_id'], + record['alert_type'], + record.get('target_class'), + record.get('confidence'), + record.get('bbox'), + record.get('message'), + record.get('image_path'), + record.get('status', 'pending'), + record['created_at'], + record.get('processed_at'), + )) + + self._conn.commit() + logger.debug(f"批量写入 {len(batch)} 条记录") + + except Exception as e: + logger.error(f"批量写入失败: {e}") + + def _cleanup_worker(self): + """清理工作线程(每天执行一次)""" + while self._running: + try: + time.sleep(3600) + if self._running: + self.cleanup_old_data() + except Exception as e: + logger.error(f"数据清理异常: {e}") + + def queue_alert(self, alert: AlertRecord): + """将告警加入写入队列""" + record = { + 'alert_id': alert.alert_id, + 'camera_id': alert.camera_id, + 'roi_id': alert.roi_id, + 'alert_type': alert.alert_type, + 'target_class': alert.target_class, + 'confidence': alert.confidence, + 'bbox': str(alert.bbox) if alert.bbox else None, + 'message': alert.message, + 'image_path': alert.image_path, + 'status': alert.status, + 'created_at': alert.created_at.isoformat(), + 'processed_at': alert.processed_at.isoformat() if alert.processed_at else None, + } + self._write_queue.put(record) + + def get_alerts( + self, + camera_id: Optional[str] = None, + status: Optional[str] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + limit: int = 100 + ) -> List[Dict[str, Any]]: + """查询告警记录""" + cursor = self._conn.cursor() + + query = "SELECT * FROM alert_records WHERE 1=1" + params = [] + + if camera_id: + query += " AND camera_id = ?" + params.append(camera_id) + + if status: + query += " AND status = ?" + params.append(status) + + if start_time: + query += " AND created_at >= ?" + params.append(start_time.isoformat()) + + if end_time: + query += " AND created_at <= ?" + params.append(end_time.isoformat()) + + query += " ORDER BY created_at DESC LIMIT ?" + params.append(limit) + + cursor.execute(query, params) + rows = cursor.fetchall() + + columns = ['id', 'alert_id', 'camera_id', 'roi_id', 'alert_type', + 'target_class', 'confidence', 'bbox', 'message', 'image_path', + 'status', 'created_at', 'processed_at'] + + return [dict(zip(columns, row)) for row in rows] + + def update_status(self, alert_id: str, status: str) -> bool: """更新告警状态""" - if not self._available: - logger.warning("数据库不可用,跳过更新告警状态") - return False - try: - with self.get_session() as session: - if session is None: - return False - - from sqlalchemy import update - from datetime import datetime - - stmt = update(AlertRecord).where( - AlertRecord.alert_id == alert_id - ).values( - status=status, - processed_at=datetime.now() - ) - - session.execute(stmt) - logger.info(f"告警状态更新成功: {alert_id} -> {status}") - return True - + cursor = self._conn.cursor() + cursor.execute(""" + UPDATE alert_records + SET status = ?, processed_at = ? + WHERE alert_id = ? + """, (status, datetime.now().isoformat(), alert_id)) + self._conn.commit() + return cursor.rowcount > 0 except Exception as e: - logger.error(f"更新告警状态失败: {e}") + logger.error(f"更新状态失败: {e}") return False - def create_tables(self): - """创建所有表""" - if not self._available: - logger.warning("数据库不可用,跳过创建表") - return - - Base.metadata.create_all(self._engine) - logger.info("数据库表创建完成") + def cleanup_old_data(self): + """清理过期数据""" + try: + cutoff = (datetime.now() - timedelta(days=self.config.retention_days)).isoformat() + + cursor = self._conn.cursor() + + cursor.execute("SELECT image_path FROM alert_records WHERE created_at < ?", (cutoff,)) + images = cursor.fetchall() + for (img_path,) in images: + if img_path and os.path.exists(img_path): + try: + os.remove(img_path) + except Exception: + pass + + cursor.execute("DELETE FROM alert_records WHERE created_at < ?", (cutoff,)) + deleted = cursor.rowcount + self._conn.commit() + + logger.info(f"清理完成: 删除 {deleted} 条过期记录") + return deleted + + except Exception as e: + logger.error(f"数据清理失败: {e}") + return 0 - def drop_tables(self): - """删除所有表""" - if not self._available: - return + def get_statistics(self) -> Dict[str, Any]: + """获取统计信息""" + cursor = self._conn.cursor() - Base.metadata.drop_all(self._engine) - logger.info("数据库表删除完成") + cursor.execute("SELECT COUNT(*) FROM alert_records") + total = cursor.fetchone()[0] + + cursor.execute("SELECT COUNT(*) FROM alert_records WHERE status = 'pending'") + pending = cursor.fetchone()[0] + + cursor.execute("SELECT COUNT(*) FROM alert_records WHERE created_at > ?", + ((datetime.now() - timedelta(hours=24)).isoformat(),)) + today = cursor.fetchone()[0] + + db_size = os.path.getsize(self.config.db_path) / (1024 * 1024) + + return { + "total_alerts": total, + "pending_alerts": pending, + "today_alerts": today, + "db_size_mb": round(db_size, 2), + "queue_size": self._write_queue.qsize(), + "retention_days": self.config.retention_days, + } def close(self): - """关闭数据库连接""" - if self._engine: - self._engine.dispose() - logger.info("数据库连接已关闭") + """关闭数据库""" + self._running = False + + if self._write_thread and self._write_thread.is_alive(): + self._write_thread.join(timeout=10) + + if self._conn: + self._conn.close() + + logger.info("SQLite 数据库已关闭") -def get_database_manager() -> DatabaseManager: - """获取数据库管理器单例""" - return DatabaseManager() - - -def init_database(): - """初始化数据库""" - db_manager = get_database_manager() - db_manager.create_tables() - return db_manager +def get_sqlite_manager() -> SQLiteManager: + """获取 SQLite 管理器单例""" + return SQLiteManager() diff --git a/config/settings.py b/config/settings.py index 7978772..8cf7883 100644 --- a/config/settings.py +++ b/config/settings.py @@ -10,7 +10,7 @@ from typing import List, Optional @dataclass class DatabaseConfig: - """数据库配置类""" + """数据库配置类(MySQL - 云端)""" host: str = "localhost" port: int = 3306 username: str = "root" @@ -21,6 +21,17 @@ class DatabaseConfig: echo: bool = False +@dataclass +class SQLiteConfig: + """SQLite 配置(边缘侧本地存储)""" + db_path: str = "./data/security_events.db" + image_dir: str = "./data/captures" + retention_days: int = 7 + wal_mode: bool = True + batch_size: int = 100 + flush_interval: float = 5.0 + + @dataclass class RedisConfig: """Redis配置类""" @@ -61,7 +72,7 @@ class VideoStreamConfig: @dataclass class InferenceConfig: """推理配置类""" - model_path: str = "./models/yolov8s.engine" + model_path: str = "./models/yolo11n.engine" input_width: int = 480 input_height: int = 480 batch_size: int = 1 @@ -111,7 +122,6 @@ class Settings: def _load_env_vars(self): """从环境变量加载配置""" - # 数据库配置 self.database = DatabaseConfig( host=os.getenv("DB_HOST", "localhost"), port=int(os.getenv("DB_PORT", "3306")), @@ -120,14 +130,19 @@ class Settings: database=os.getenv("DB_DATABASE", "edge_inference"), ) - # Redis配置 + self.sqlite = SQLiteConfig( + db_path=os.getenv("SQLITE_DB_PATH", "./data/security_events.db"), + image_dir=os.getenv("SQLITE_IMAGE_DIR", "./data/captures"), + retention_days=int(os.getenv("SQLITE_RETENTION_DAYS", "7")), + wal_mode=os.getenv("SQLITE_WAL_MODE", "1") == "1", + ) + self.redis = RedisConfig( host=os.getenv("REDIS_HOST", "localhost"), port=int(os.getenv("REDIS_PORT", "6379")), password=os.getenv("REDIS_PASSWORD"), ) - # MQTT配置 self.mqtt = MQTTConfig( broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"), broker_port=int(os.getenv("MQTT_BROKER_PORT", "1883")), @@ -136,13 +151,11 @@ class Settings: password=os.getenv("MQTT_PASSWORD"), ) - # 视频流配置 self.video_stream = VideoStreamConfig( default_fps=int(os.getenv("VIDEO_DEFAULT_FPS", "5")), reconnect_max_attempts=int(os.getenv("VIDEO_RECONNECT_ATTEMPTS", "5")), ) - # 推理配置 self.inference = InferenceConfig( model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"), input_width=int(os.getenv("INPUT_WIDTH", "480")), @@ -153,13 +166,11 @@ class Settings: nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.45")), ) - # 日志配置 self.log_level = os.getenv("LOG_LEVEL", "INFO") self.log_dir = os.getenv("LOG_DIR", "./logs") - self.log_file_max_size = int(os.getenv("LOG_FILE_MAX_SIZE", "10485760")) # 10MB + self.log_file_max_size = int(os.getenv("LOG_FILE_MAX_SIZE", "10485760")) self.log_file_backup_count = int(os.getenv("LOG_FILE_BACKUP_COUNT", "5")) - # 工作时间配置 self.working_hours = self._parse_working_hours() def _parse_working_hours(self) -> List[dict]: diff --git a/core/__pycache__/config_sync.cpython-310.pyc b/core/__pycache__/config_sync.cpython-310.pyc index 344ecda..3e04b1f 100644 Binary files a/core/__pycache__/config_sync.cpython-310.pyc and b/core/__pycache__/config_sync.cpython-310.pyc differ diff --git a/core/__pycache__/tensorrt_engine.cpython-310.pyc b/core/__pycache__/tensorrt_engine.cpython-310.pyc index f9c7b19..f144f68 100644 Binary files a/core/__pycache__/tensorrt_engine.cpython-310.pyc and b/core/__pycache__/tensorrt_engine.cpython-310.pyc differ diff --git a/core/result_reporter.py b/core/result_reporter.py index 36f3fe4..cb821fd 100644 --- a/core/result_reporter.py +++ b/core/result_reporter.py @@ -1,6 +1,11 @@ """ 结果上报模块 -实现MQTT客户端、告警存储、重试机制 +支持 MQTT、云端 MySQL 和本地 SQLite 存储 + +存储策略: +- MQTT: 实时推送告警到云端 +- SQLite: 本地异步存储(边缘侧断网容灾) +- MySQL: 可选云端同步 """ import json @@ -8,359 +13,65 @@ import logging import threading import time from datetime import datetime -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional, Callable +from dataclasses import dataclass, field +from pathlib import Path import numpy as np import paho.mqtt.client as mqtt from paho.mqtt.client import MQTTMessage -from config.settings import get_settings, MQTTConfig -from config.database import get_database_manager -from config.config_models import AlertInfo -from utils.logger import get_logger -from utils.common import generate_unique_id, ExponentialBackoff - logger = logging.getLogger(__name__) -class MQTTClient: - """MQTT客户端类 +@dataclass +class AlertInfo: + """告警信息类""" + alert_id: str + camera_id: str + roi_id: str + alert_type: str + target_class: Optional[str] = None + confidence: Optional[float] = None + bbox: Optional[List[float]] = field(default_factory=list) + message: Optional[str] = None + screenshot: Optional[np.ndarray] = None + timestamp: datetime = field(default_factory=datetime.now) - 实现与云端业务平台的可靠连接 - """ - - def __init__(self, config: Optional[MQTTConfig] = None): - """ - 初始化MQTT客户端 - - Args: - config: MQTT配置 - """ - if config is None: - settings = get_settings() - config = settings.mqtt - - self.config = config - self._client = None - self._connected = False - self._reconnect_thread: Optional[threading.Thread] = None - self._stop_event = threading.Event() - - self._message_callbacks: Dict[str, Callable] = {} - self._pending_messages: List[Dict] = [] - self._pending_lock = threading.Lock() - - self._logger = get_logger("mqtt") - self._lock = threading.Lock() - - self._performance_stats = { - "messages_sent": 0, - "messages_received": 0, - "send_errors": 0, - "reconnect_count": 0, - } - - def _on_connect(self, client, userdata, flags, rc): - """连接回调""" - if rc == 0: - self._connected = True - self._logger.log_connection_event( - "connect", "MQTT", self.config.broker_host, True - ) - self._logger.info(f"MQTT连接成功: {self.config.broker_host}:{self.config.broker_port}") - - self._resend_pending_messages() - else: - self._connected = False - self._logger.log_connection_event( - "connect", "MQTT", self.config.broker_host, False, - error_msg=f"错误码: {rc}" - ) - - def _on_disconnect(self, client, userdata, rc): - """断开连接回调""" - self._connected = False - self._logger.log_connection_event( - "disconnect", "MQTT", self.config.broker_host, False - ) - - if rc != 0: - self._performance_stats["reconnect_count"] += 1 - self._logger.warning(f"MQTT意外断开,尝试重连: {rc}") - self._start_reconnect() - - def _on_message(self, client, userdata, msg: MQTTMessage): - """消息回调""" - self._performance_stats["messages_received"] += 1 - - topic = msg.topic - payload = msg.payload.decode('utf-8') - - self._logger.debug(f"收到消息: {topic} -> {payload}") - - if topic in self._message_callbacks: - try: - self._message_callbacks[topic](payload) - except Exception as e: - self._logger.error(f"消息回调执行失败: {e}") - - def _on_publish(self, client, userdata, mid): - """发布回调""" - self._performance_stats["messages_sent"] += 1 - self._logger.debug(f"消息已发布: mid={mid}") - - def _create_client(self) -> mqtt.Client: - """创建MQTT客户端""" - client = mqtt.Client( - client_id=self.config.client_id, - clean_session=True, - userdata=None, - protocol=mqtt.MQTTv311, - transport="tcp" - ) - - client.on_connect = self._on_connect - client.on_disconnect = self._on_disconnect - client.on_message = self._on_message - client.on_publish = self._on_publish - - if self.config.username and self.config.password: - client.username_pw_set( - self.config.username, - self.config.password - ) - - client.reconnect_delay_set( - min_delay=self.config.reconnect_delay, - max_delay=self.config.reconnect_delay * 2 - ) - - return client - - def _start_reconnect(self): - """启动重连线程""" - if self._reconnect_thread is not None and self._reconnect_thread.is_alive(): - return - - self._stop_event.clear() - self._reconnect_thread = threading.Thread( - target=self._reconnect_loop, - name="MQTTReconnect", - daemon=True - ) - self._reconnect_thread.start() - - def _reconnect_loop(self): - """重连循环""" - backoff = ExponentialBackoff( - base_delay=self.config.reconnect_delay, - max_delay=self.config.reconnect_delay * 3, - max_attempts=self.config.max_reconnect_attempts - ) - - while not self._stop_event.is_set(): - if self._connected: - break - - if not backoff.next_attempt(): - self._logger.error("MQTT重连次数已达上限") - break - - delay = backoff.get_delay() - self._logger.info(f"等待 {delay:.1f}秒后尝试重连MQTT...") - time.sleep(delay) - - try: - self._client.connect( - self.config.broker_host, - self.config.broker_port, - self.config.keepalive - ) - self._client.loop_start() - except Exception as e: - self._logger.error(f"MQTT重连失败: {e}") - - def _resend_pending_messages(self): - """重发待发送消息""" - with self._pending_lock: - if not self._pending_messages: - return - - self._logger.info(f"重发待处理消息: {len(self._pending_messages)}条") - - failed_messages = [] - - for msg in self._pending_messages: - try: - self._publish_message( - msg["topic"], - msg["payload"], - msg.get("qos", self.config.qos), - msg.get("retain", False) - ) - except Exception: - failed_messages.append(msg) - - self._pending_messages = failed_messages - - def connect(self) -> bool: - """ - 连接到MQTT代理 - - Returns: - 是否连接成功 - """ - with self._lock: - try: - if self._client is not None: - self._client.loop_stop() - - self._client = self._create_client() - - self._client.connect( - self.config.broker_host, - self.config.broker_port, - self.config.keepalive - ) - self._client.loop_start() - - self._logger.info( - f"MQTT正在连接: {self.config.broker_host}:{self.config.broker_port}" - ) - return True - - except Exception as e: - self._logger.error(f"MQTT连接失败: {e}") - return False - - def disconnect(self): - """断开MQTT连接""" - with self._lock: - if self._client: - self._stop_event.set() - self._client.loop_stop() - self._client.disconnect() - self._client = None - self._connected = False - self._logger.info("MQTT连接已断开") - - def subscribe(self, topic: str, qos: int = 1): - """ - 订阅主题 - - Args: - topic: 主题 - qos: 服务质量等级 - """ - if self._client and self._connected: - self._client.subscribe(topic, qos) - self._logger.info(f"已订阅主题: {topic}") - - def publish( - self, - topic: str, - payload: Dict[str, Any], - qos: Optional[int] = None, - retain: bool = False - ) -> bool: - """ - 发布消息 - - Args: - topic: 主题 - payload: 消息负载 - qos: 服务质量等级 - retain: 保留消息 - - Returns: - 是否发布成功 - """ - if qos is None: - qos = self.config.qos - - try: - message = json.dumps(payload, ensure_ascii=False) - - if self._connected and self._client: - self._publish_message(topic, message, qos, retain) - return True - else: - self._queue_message(topic, payload, qos, retain) - return False - - except Exception as e: - self._performance_stats["send_errors"] += 1 - self._logger.error(f"发布消息失败: {e}") - return False - - def _publish_message( - self, - topic: str, - message: str, - qos: int, - retain: bool - ): - """内部发布方法""" - self._client.publish(topic, message, qos, retain) - self._logger.debug(f"消息已发布: {topic}") - - def _queue_message( - self, - topic: str, - payload: Dict[str, Any], - qos: int, - retain: bool - ): - """将消息加入待发送队列""" - with self._pending_lock: - self._pending_messages.append({ - "topic": topic, - "payload": payload, - "qos": qos, - "retain": retain, - "timestamp": datetime.now().isoformat(), - }) - - self._logger.warning(f"消息已加入待发送队列: {topic}") - - def register_message_callback(self, topic: str, callback: Callable): - """注册消息回调""" - self._message_callbacks[topic] = callback - self._logger.info(f"已注册消息回调: {topic}") - - def get_performance_stats(self) -> Dict[str, int]: - """获取性能统计""" - return dict(self._performance_stats) - - def get_status(self) -> Dict[str, Any]: - """获取连接状态""" + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" return { - "connected": self._connected, - "broker_host": self.config.broker_host, - "broker_port": self.config.broker_port, - "pending_messages": len(self._pending_messages), - "stats": self.get_performance_stats(), + "alert_id": self.alert_id, + "camera_id": self.camera_id, + "roi_id": self.roi_id, + "alert_type": self.alert_type, + "target_class": self.target_class, + "confidence": self.confidence, + "bbox": self.bbox, + "message": self.message, + "timestamp": self.timestamp.isoformat(), } - - def close(self): - """关闭客户端""" - self.disconnect() - self._logger.info("MQTT客户端已关闭") -class AlertReporter: - """告警上报器类 +class ResultReporter: + """结果上报器类""" - 整合MQTT上报、数据库存储、重试机制 - """ - - def __init__(self): - self._mqtt_client = None - self._db_manager = None - self._logger = get_logger("reporter") - self._lock = threading.Lock() + def __init__( + self, + mqtt_broker: str = "localhost", + mqtt_port: int = 1883, + topic_prefix: str = "edge/alert", + mqtt_client_id: str = "edge_reporter" + ): + self._mqtt_broker = mqtt_broker + self._mqtt_port = mqtt_port + self._topic_prefix = topic_prefix + self._mqtt_client_id = mqtt_client_id - self._topic_prefix = "edge/alerts" + self._client: Optional[mqtt.Client] = None + self._connected = False + self._reconnect_count = 0 + self._lock = threading.Lock() self._performance_stats = { "alerts_generated": 0, @@ -368,72 +79,287 @@ class AlertReporter: "alerts_stored": 0, "send_failures": 0, } + + self._logger = logging.getLogger("result_reporter") + + self._db_manager = None + self._image_storage = None + self._local_cache = None + + self._logger.info("ResultReporter 初始化完成") + + def initialize(self): + """初始化存储和MQTT""" + from config.database import SQLiteManager, AlertRecord, get_sqlite_manager + from core.storage_manager import get_image_storage, get_local_cache + + self._logger.info("初始化存储管理器...") + + self._db_manager = get_sqlite_manager() + self._image_storage = get_image_storage() + self._local_cache = get_local_cache() + + self._logger.info("存储管理器初始化完成") + + self._init_mqtt() def _init_mqtt(self): """初始化MQTT客户端""" - if self._mqtt_client is None: - self._mqtt_client = MQTTClient() - self._mqtt_client.connect() + try: + self._client = mqtt.Client( + client_id=self._mqtt_client_id, + protocol=mqtt.MQTTv5, + callback_api_version=mqtt.CallbackAPIVersion.VERSION2 + ) + + self._client.on_connect = self._on_connect + self._client.on_disconnect = self._on_disconnect + self._client.on_publish = self._on_publish + + self._client.connect(self._mqtt_broker, self._mqtt_port, 60) + self._client.loop_start() + + self._logger.info(f"MQTT 客户端初始化: {self._mqtt_broker}:{self._mqtt_port}") + + except Exception as e: + self._logger.warning(f"MQTT 初始化失败: {e}") + self._client = None - def _init_database(self): - """初始化数据库连接""" - if self._db_manager is None: - self._db_manager = get_database_manager() + def _on_connect(self, client, userdata, flags, reason_code, properties): + """MQTT连接回调""" + if reason_code == 0: + self._connected = True + self._reconnect_count = 0 + self._logger.info("MQTT 连接成功") + + if self._local_cache: + pending = self._local_cache.get_pending_count() + if pending > 0: + self._logger.info(f"有 {pending} 条待同步的缓存告警") + else: + self._logger.warning(f"MQTT 连接失败: {reason_code}") + + def _on_disconnect(self, client, userdata, reason_code, properties): + """MQTT断开连接回调""" + self._connected = False + self._logger.warning(f"MQTT 连接断开: {reason_code}") + + def _on_publish(self, client, userdata, mid, reason_code, properties): + """MQTT发布回调""" + if reason_code == 0: + self._logger.debug(f"MQTT 消息发布成功: {mid}") def report_alert( self, alert: AlertInfo, - screenshot: Optional[np.ndarray] = None, - store_to_db: bool = True + store_to_db: bool = True, + screenshot: Optional[np.ndarray] = None ) -> bool: """ 上报告警 Args: alert: 告警信息 - screenshot: 截图 (可选) - store_to_db: 是否存储到数据库 + store_to_db: 是否存储到本地数据库 + screenshot: 抓拍图片 Returns: 是否上报成功 """ + self._performance_stats["alerts_generated"] += 1 + + try: + if store_to_db and self._db_manager: + self._store_alert(alert, screenshot) + + if self._connected and self._client: + self._publish_alert(alert) + else: + self._logger.warning("MQTT 未连接,消息已加入待发送队列") + if self._local_cache: + self._local_cache.cache_alert(alert.to_dict()) + + self._performance_stats["alerts_sent"] += 1 + return True + + except Exception as e: + self._performance_stats["send_failures"] += 1 + self._logger.error(f"上报告警失败: {e}") + return False + + def _store_alert(self, alert: AlertInfo, screenshot: Optional[np.ndarray] = None): + """存储告警到本地数据库(异步)""" + image_path = None + + if screenshot is not None and self._image_storage: + image_path = self._image_storage.save_capture( + image=screenshot, + camera_id=alert.camera_id, + alert_id=alert.alert_id, + timestamp=alert.timestamp + ) + + record = AlertRecord( + alert_id=alert.alert_id, + camera_id=alert.camera_id, + roi_id=alert.roi_id, + alert_type=alert.alert_type, + target_class=alert.target_class, + confidence=alert.confidence, + bbox=alert.bbox, + message=alert.message, + image_path=image_path, + status="pending", + created_at=alert.timestamp, + ) + + if self._db_manager: + self._db_manager.queue_alert(record) + self._performance_stats["alerts_stored"] += 1 + + def _publish_alert(self, alert: AlertInfo): + """发布告警到MQTT""" + alert_data = alert.to_dict() + + topic = f"{self._topic_prefix}/{alert.camera_id}/{alert.roi_id}" + + result = self._client.publish(topic, json.dumps(alert_data, ensure_ascii=False)) + + if result[0] == mqtt.MQTT_ERR_SUCCESS: + self._logger.log_alert( + alert.alert_type, + alert.camera_id, + alert.roi_id, + alert.confidence + ) + else: + raise Exception(f"MQTT 发布失败: {result[0]}") + + def get_pending_alerts(self) -> List[Dict[str, Any]]: + """获取待同步的告警""" + if self._local_cache: + return self._local_cache.get_pending_alerts() + return [] + + def sync_pending_alerts(self) -> int: + """同步待处理的告警到云端""" + if not self._connected or not self._client: + return 0 + + pending = self.get_pending_alerts() + synced = 0 + + for alert in pending: + try: + self._client.publish( + f"{self._topic_prefix}/{alert['camera_id']}/{alert['roi_id']}", + json.dumps(alert, ensure_ascii=False) + ) + if self._local_cache: + self._local_cache.remove_cached(alert.get('_cache_id', '')) + synced += 1 + except Exception as e: + self._logger.error(f"同步告警失败: {e}") + + if synced > 0: + self._logger.info(f"已同步 {synced} 条告警到云端") + + return synced + + def get_statistics(self) -> Dict[str, Any]: + """获取统计信息""" + stats = self._performance_stats.copy() + + if self._db_manager: + db_stats = self._db_manager.get_statistics() + stats.update({ + "db_total_alerts": db_stats.get("total_alerts", 0), + "db_pending_alerts": db_stats.get("pending_alerts", 0), + }) + + if self._image_storage: + img_stats = self._image_storage.get_statistics() + stats["image_saved"] = img_stats.get("saved_count", 0) + + if self._local_cache: + stats["pending_sync"] = self._local_cache.get_pending_count() + + stats["mqtt_connected"] = self._connected + + return stats + + def cleanup(self): + """清理资源""" + self._logger.info("ResultReporter 资源清理") + + if self._image_storage: + self._image_storage.close() + + if self._db_manager: + self._db_manager.close() + + if self._client: + self._client.loop_stop() + self._client.disconnect() + + self._logger.info("ResultReporter 清理完成") + + +class AlertReporter: + """告警上报器(简化版)""" + + def __init__(self, topic_prefix: str = "edge/alert"): + self._topic_prefix = topic_prefix + self._lock = threading.Lock() + self._performance_stats = { + "alerts_generated": 0, + "alerts_sent": 0, + "alerts_stored": 0, + "send_failures": 0, + } + self._logger = logging.getLogger("alert_reporter") + + def initialize(self): + """初始化""" + from config.database import get_sqlite_manager + from core.storage_manager import get_image_storage + + self._db_manager = get_sqlite_manager() + self._image_storage = get_image_storage() + self._logger.info("AlertReporter 初始化完成") + + def report_alert( + self, + alert: AlertInfo, + store_to_db: bool = True, + screenshot: Optional[np.ndarray] = None + ) -> bool: + """ + 上报告警 + + Args: + alert: 告警信息 + store_to_db: 是否存储到本地数据库 + screenshot: 抓拍图片 + + Returns: + 是否成功 + """ with self._lock: self._performance_stats["alerts_generated"] += 1 try: - self._init_mqtt() - self._init_database() + if store_to_db and self._db_manager: + self._store_alert(alert, screenshot) - alert_data = alert.to_dict() + self._logger.info( + f"告警已记录: {alert.alert_type} - {alert.camera_id}/{alert.roi_id}" + ) - if screenshot is not None: - from utils.common import image_to_base64 - alert_data["screenshot"] = image_to_base64(screenshot) + with self._lock: + self._performance_stats["alerts_sent"] += 1 - topic = f"{self._topic_prefix}/{alert.camera_id}/{alert.roi_id}" - - success = self._mqtt_client.publish(topic, alert_data) - - if store_to_db: - db_success = self._store_to_database(alert.to_dict()) - if db_success: - with self._lock: - self._performance_stats["alerts_stored"] += 1 - - if success: - with self._lock: - self._performance_stats["alerts_sent"] += 1 - self._logger.log_alert( - alert.alert_type, - alert.camera_id, - alert.roi_id, - alert.confidence - ) - else: - with self._lock: - self._performance_stats["send_failures"] += 1 - - return success + return True except Exception as e: with self._lock: @@ -441,219 +367,46 @@ class AlertReporter: self._logger.error(f"上报告警失败: {e}") return False - def _store_to_database(self, alert_data: Dict[str, Any]) -> bool: - """存储告警到数据库""" - try: - return self._db_manager.save_alert_record(alert_data) - except Exception as e: - self._logger.error(f"存储告警到数据库失败: {e}") - return False + def _store_alert(self, alert: AlertInfo, screenshot: Optional[np.ndarray] = None): + """存储告警""" + from config.database import AlertRecord, get_sqlite_manager + + image_path = None + + if screenshot is not None and self._image_storage: + image_path = self._image_storage.save_capture( + image=screenshot, + camera_id=alert.camera_id, + alert_id=alert.alert_id, + timestamp=alert.timestamp + ) + + record = AlertRecord( + alert_id=alert.alert_id, + camera_id=alert.camera_id, + roi_id=alert.roi_id, + alert_type=alert.alert_type, + target_class=alert.target_class, + confidence=alert.confidence, + bbox=alert.bbox, + message=alert.message, + image_path=image_path, + status="pending", + created_at=alert.timestamp, + ) + + db_manager = get_sqlite_manager() + db_manager.queue_alert(record) + + with self._lock: + self._performance_stats["alerts_stored"] += 1 - def batch_report( - self, - alerts: List[AlertInfo], - screenshots: Optional[Dict[str, np.ndarray]] = None - ) -> Dict[str, int]: - """ - 批量上报告警 - - Args: - alerts: 告警列表 - screenshots: 截图字典 (alert_id -> screenshot) - - Returns: - 上报结果统计 - """ - results = { - "total": len(alerts), - "success": 0, - "failed": 0, - } - - screenshots = screenshots or {} - - for alert in alerts: - screenshot = screenshots.get(alert.alert_id) - if self.report_alert(alert, screenshot): - results["success"] += 1 - else: - results["failed"] += 1 - - return results - - def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool: - """ - 上报心跳 - - Args: - device_id: 设备ID - status: 状态信息 - - Returns: - 是否上报成功 - """ - try: - self._init_mqtt() - - heartbeat_data = { - "device_id": device_id, - "timestamp": datetime.now().isoformat(), - "status": status, - } - - topic = f"edge/heartbeat/{device_id}" - return self._mqtt_client.publish(topic, heartbeat_data) - - except Exception as e: - self._logger.error(f"上报心跳失败: {e}") - return False - - def get_performance_stats(self) -> Dict[str, int]: - """获取性能统计""" + def get_statistics(self) -> Dict[str, Any]: + """获取统计""" with self._lock: return dict(self._performance_stats) - - def get_status(self) -> Dict[str, Any]: - """获取状态""" - return { - "mqtt": self._mqtt_client.get_status() if self._mqtt_client else None, - "stats": self.get_performance_stats(), - } - - def close(self): - """关闭资源""" - if self._mqtt_client: - self._mqtt_client.close() - self._mqtt_client = None -class ResultReporter: - """结果上报主类 - - 统一管理告警上报 - """ - - def __init__(self): - self._alert_reporter = AlertReporter() - self._logger = get_logger("reporter") - - def report_detection_alert( - self, - camera_id: str, - roi_id: str, - alert_type: str, - detection: Dict[str, Any], - screenshot: Optional[np.ndarray] = None - ) -> bool: - """ - 上报检测告警 - - Args: - camera_id: 摄像头ID - roi_id: ROI ID - alert_type: 告警类型 - detection: 检测结果 - screenshot: 截图 - - Returns: - 是否上报成功 - """ - alert = AlertInfo( - alert_id=generate_unique_id("alert"), - camera_id=camera_id, - roi_id=roi_id, - alert_type=alert_type, - target_class=detection.get("class_name"), - confidence=detection.get("confidence"), - bbox=detection.get("bbox"), - message=detection.get("message", f"检测到{alert_type}"), - timestamp=datetime.now().isoformat(), - ) - - return self._alert_reporter.report_alert(alert, screenshot) - - def report_leave_post_alert( - self, - camera_id: str, - roi_id: str, - duration_seconds: int, - screenshot: Optional[np.ndarray] = None - ) -> bool: - """ - 上报离岗告警 - - Args: - camera_id: 摄像头ID - roi_id: ROI ID - duration_seconds: 离岗时长(秒) - screenshot: 截图 - - Returns: - 是否上报成功 - """ - detection = { - "class_name": "person", - "confidence": None, - "message": f"离岗告警: 已离开{duration_seconds}秒", - } - - return self.report_detection_alert( - camera_id=camera_id, - roi_id=roi_id, - alert_type="leave_post", - detection=detection, - screenshot=screenshot - ) - - def report_intrusion_alert( - self, - camera_id: str, - roi_id: str, - confidence: float, - bbox: List[float], - screenshot: Optional[np.ndarray] = None - ) -> bool: - """ - 上报入侵告警 - - Args: - camera_id: 摄像头ID - roi_id: ROI ID - confidence: 置信度 - bbox: 检测框 - screenshot: 截图 - - Returns: - 是否上报成功 - """ - detection = { - "class_name": "intrusion", - "confidence": confidence, - "bbox": bbox, - "message": "检测到周界入侵", - } - - return self.report_detection_alert( - camera_id=camera_id, - roi_id=roi_id, - alert_type="intrusion", - detection=detection, - screenshot=screenshot - ) - - def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool: - """上报心跳""" - return self._alert_reporter.report_heartbeat(device_id, status) - - def get_status(self) -> Dict[str, Any]: - """获取状态""" - return self._alert_reporter.get_status() - - def close(self): - """关闭资源""" - self._alert_reporter.close() - - -def get_result_reporter() -> ResultReporter: - """获取结果上报器单例""" - return ResultReporter() +def create_alert_reporter(topic_prefix: str = "edge/alert") -> AlertReporter: + """创建告警上报器""" + return AlertReporter(topic_prefix=topic_prefix) diff --git a/core/storage_manager.py b/core/storage_manager.py new file mode 100644 index 0000000..4d7e3d0 --- /dev/null +++ b/core/storage_manager.py @@ -0,0 +1,281 @@ +""" +存储管理器模块 +负责图片保存、异步写入和断网容灾 + +特性: +- 异步保存抓拍图片 +- 断网本地缓存 +- 批量同步到云端 +""" + +import os +import cv2 +import uuid +import logging +import threading +import queue +import time +from datetime import datetime +from pathlib import Path +from typing import Optional, Dict, Any, List +from dataclasses import dataclass, field + +logger = logging.getLogger(__name__) + + +@dataclass +class CaptureConfig: + """抓拍配置""" + image_dir: str = "./data/captures" + quality: int = 85 + max_width: int = 1920 + max_height: int = 1080 + save_format: str = ".jpg" + + +@dataclass +class PendingCapture: + """待保存的抓拍""" + image: Any + camera_id: str + alert_id: str + timestamp: datetime = field(default_factory=datetime.now) + + +class ImageStorageManager: + """图片存储管理器""" + + _instance = None + _lock = threading.Lock() + + def __new__(cls, config: Optional[CaptureConfig] = None): + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self, config: Optional[CaptureConfig] = None): + if self._initialized: + return + + if config is None: + config = CaptureConfig() + + self.config = config + self._save_queue: queue.Queue = queue.Queue() + self._running = False + self._save_thread: Optional[threading.Thread] = None + self._saved_count = 0 + self._failed_count = 0 + + self._init_directories() + self._start_save_thread() + + self._initialized = True + logger.info(f"图片存储管理器初始化: {config.image_dir}") + + def _init_directories(self): + """初始化目录""" + Path(self.config.image_dir).mkdir(parents=True, exist_ok=True) + + def _start_save_thread(self): + """启动保存线程""" + self._running = True + self._save_thread = threading.Thread( + target=self._save_worker, + name="ImageSave", + daemon=True + ) + self._save_thread.start() + + def _save_worker(self): + """图片保存工作线程""" + while self._running: + try: + try: + capture: PendingCapture = self._save_queue.get(timeout=1.0) + self._save_image(capture) + except queue.Empty: + pass + except Exception as e: + logger.error(f"图片保存异常: {e}") + + def _save_image(self, capture: PendingCapture) -> Optional[str]: + """保存单张图片""" + try: + image = capture.image + + if image is None: + self._failed_count += 1 + return None + + if len(image.shape) == 3 and image.shape[2] == 3: + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if image.shape[1] > self.config.max_width or image.shape[0] > self.config.max_height: + scale = min( + self.config.max_width / image.shape[1], + self.config.max_height / image.shape[0] + ) + new_size = ( + int(image.shape[1] * scale), + int(image.shape[0] * scale) + ) + image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA) + + date_dir = capture.timestamp.strftime("%Y%m%d") + save_dir = Path(self.config.image_dir) / date_dir + save_dir.mkdir(parents=True, exist_ok=True) + + filename = f"{capture.camera_id}_{capture.alert_id}{self.config.save_format}" + filepath = save_dir / filename + + success = cv2.imwrite( + str(filepath), + image, + [cv2.IMWRITE_JPEG_QUALITY, self.config.quality] + ) + + if success: + self._saved_count += 1 + logger.debug(f"图片已保存: {filepath}") + return str(filepath) + else: + self._failed_count += 1 + return None + + except Exception as e: + logger.error(f"保存图片失败: {e}") + self._failed_count += 1 + return None + + def save_capture( + self, + image: Any, + camera_id: str, + alert_id: str, + timestamp: Optional[datetime] = None + ) -> Optional[str]: + """异步保存抓拍图片""" + capture = PendingCapture( + image=image, + camera_id=camera_id, + alert_id=alert_id, + timestamp=timestamp or datetime.now() + ) + self._save_queue.put(capture) + return f"" + + def get_image_path(self, camera_id: str, alert_id: str) -> Optional[str]: + """获取已保存图片路径""" + date_str = datetime.now().strftime("%Y%m%d") + filename = f"{camera_id}_{alert_id}{self.config.save_format}" + filepath = Path(self.config.image_dir) / date_str / filename + + if filepath.exists(): + return str(filepath) + return None + + def cleanup_old_images(self, days: int = 7): + """清理过期图片""" + cutoff = datetime.now().timestamp() - (days * 24 * 60 * 60) + cleaned = 0 + + try: + for filepath in Path(self.config.image_dir).rglob(f"*{self.config.save_format}"): + if filepath.stat().st_mtime < cutoff: + filepath.unlink() + cleaned += 1 + + logger.info(f"清理完成: 删除 {cleaned} 张过期图片") + return cleaned + + except Exception as e: + logger.error(f"清理图片失败: {e}") + return 0 + + def get_statistics(self) -> Dict[str, Any]: + """获取统计信息""" + return { + "saved_count": self._saved_count, + "failed_count": self._failed_count, + "queue_size": self._save_queue.qsize(), + "image_dir": self.config.image_dir, + } + + def close(self): + """关闭存储管理器""" + self._running = False + + if self._save_thread and self._save_thread.is_alive(): + self._save_thread.join(timeout=10) + + logger.info(f"图片存储已关闭: 成功 {self._saved_count}, 失败 {self._failed_count}") + + +class LocalCacheManager: + """本地缓存管理器(断网容灾)""" + + def __init__(self, cache_dir: str = "./data/cache"): + self.cache_dir = Path(cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + + self._pending_dir = self.cache_dir / "pending" + self._pending_dir.mkdir(exist_ok=True) + + self._sync_lock = threading.Lock() + + def cache_alert(self, alert_data: Dict[str, Any]) -> str: + """缓存告警数据(断网时使用)""" + cache_id = str(uuid.uuid4()) + cache_path = self._pending_dir / f"{cache_id}.json" + + try: + import json + with open(cache_path, 'w', encoding='utf-8') as f: + json.dump(alert_data, f, ensure_ascii=False, indent=2) + return cache_id + except Exception as e: + logger.error(f"缓存告警失败: {e}") + return "" + + def get_pending_count(self) -> int: + """获取待同步数量""" + return len(list(self._pending_dir.glob("*.json"))) + + def get_pending_alerts(self) -> List[Dict[str, Any]]: + """获取待同步的告警""" + alerts = [] + try: + import json + for cache_path in self._pending_dir.glob("*.json"): + with open(cache_path, 'r', encoding='utf-8') as f: + alert = json.load(f) + alert['_cache_id'] = cache_path.stem + alerts.append(alert) + except Exception as e: + logger.error(f"读取缓存告警失败: {e}") + return alerts + + def remove_cached(self, cache_id: str): + """移除已同步的缓存""" + cache_path = self._pending_dir / f"{cache_id}.json" + if cache_path.exists(): + cache_path.unlink() + + def clear_cache(self): + """清空缓存""" + for cache_path in self._pending_dir.glob("*.json"): + cache_path.unlink() + + +def get_image_storage() -> ImageStorageManager: + """获取图片存储管理器""" + return ImageStorageManager() + + +def get_local_cache() -> LocalCacheManager: + """获取本地缓存管理器""" + return LocalCacheManager() diff --git a/logs/main.log b/logs/main.log index 18fab97..5b1be96 100644 --- a/logs/main.log +++ b/logs/main.log @@ -158,3 +158,108 @@ ConnectionRefusedError: [WinError 10061] 由于目标计算机积极拒绝,无 2026-01-29 18:11:31 | INFO | main | MQTT客户端已关闭 2026-01-29 18:11:31 | INFO | main | Edge_Inference_Service 已停止 2026-01-29 18:11:31 | INFO | main | 运行统计: {'start_time': datetime.datetime(2026, 1, 29, 18, 7, 35, 334780), 'total_frames_processed': 0, 'total_alerts_generated': 0, 'uptime_seconds': 236.659473} +2026-01-30 09:17:29 | INFO | main | Edge_Inference_Service 初始化开始 +2026-01-30 09:17:29 | INFO | main | ================================================== +2026-01-30 09:17:29 | INFO | main | Edge_Inference_Service 启动 +2026-01-30 09:17:29 | INFO | main | ================================================== +2026-01-30 09:17:33 | WARNING | main | 数据库不可见,服务将在无数据库模式下运行 +2026-01-30 09:17:37 | INFO | main | 配置管理器初始化成功 +2026-01-30 09:17:37 | INFO | main | 流管理器初始化成功 +2026-01-30 09:17:37 | INFO | main | 图像预处理器初始化完成: 输入尺寸 480x480, Batch大小 1-8, FP16模式 True +2026-01-30 09:17:37 | INFO | main | 预处理器初始化成功 +2026-01-30 09:17:37 | INFO | main | TensorRT引擎初始化配置: 模型=./models/yolo11n.engine, 输入尺寸=480x480, Batch=1, FP16=True +2026-01-30 09:17:37 | ERROR | main | TensorRT引擎加载失败: negative dimensions are not allowed +Traceback (most recent call last): + File "c:\Users\16337\PycharmProjects\ai_edge\core\tensorrt_engine.py", line 102, in load_engine + self._allocate_memory_pool() + File "c:\Users\16337\PycharmProjects\ai_edge\core\tensorrt_engine.py", line 152, in _allocate_memory_pool + self._memory_pool["input"] = np.zeros(shape, dtype=dtype) +ValueError: negative dimensions are not allowed +2026-01-30 09:17:37 | INFO | main | 推理引擎加载成功: ./models/yolo11n.engine +2026-01-30 09:17:37 | INFO | main | 后处理器初始化完成: NMS阈值=0.45, 置信度阈值=0.5 +2026-01-30 09:17:37 | INFO | main | 后处理器初始化成功 +2026-01-30 09:17:37 | INFO | main | 结果上报器初始化成功 +2026-01-30 09:17:37 | INFO | main | 所有组件初始化完成 +2026-01-30 09:17:37 | INFO | main | 已启动 0 个视频流 +2026-01-30 09:17:37 | INFO | main | Edge_Inference_Service 已启动 +2026-01-30 09:17:41 | ERROR | main | MQTT连接失败: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +Traceback (most recent call last): + File "c:\Users\16337\PycharmProjects\ai_edge\core\result_reporter.py", line 220, in connect + self._client.connect( + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 914, in connect + return self.reconnect() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 1044, in reconnect + sock = self._create_socket_connection() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 3685, in _create_socket_connection + return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 857, in create_connection + raise err + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 845, in create_connection + sock.connect(sa) +ConnectionRefusedError: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +2026-01-30 09:17:41 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:18:11 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:18:41 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:19:11 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:19:41 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:20:11 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:20:41 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:21:11 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:21:13 | INFO | main | 收到信号 2, 正在停止服务... +2026-01-30 09:21:13 | INFO | main | 已停止所有视频流 +2026-01-30 09:21:13 | INFO | main | 已停止所有视频流 +2026-01-30 09:21:13 | INFO | main | 所有引擎已释放 +2026-01-30 09:21:13 | INFO | main | MQTT连接已断开 +2026-01-30 09:21:13 | INFO | main | MQTT客户端已关闭 +2026-01-30 09:21:13 | INFO | main | Edge_Inference_Service 已停止 +2026-01-30 09:21:13 | INFO | main | 运行统计: {'start_time': datetime.datetime(2026, 1, 30, 9, 17, 37, 843486), 'total_frames_processed': 0, 'total_alerts_generated': 0, 'uptime_seconds': 215.781095} +2026-01-30 09:21:16 | INFO | main | Edge_Inference_Service 初始化开始 +2026-01-30 09:21:16 | INFO | main | ================================================== +2026-01-30 09:21:16 | INFO | main | Edge_Inference_Service 启动 +2026-01-30 09:21:16 | INFO | main | ================================================== +2026-01-30 09:21:20 | WARNING | main | 数据库不可见,服务将在无数据库模式下运行 +2026-01-30 09:21:24 | INFO | main | 配置管理器初始化成功 +2026-01-30 09:21:24 | INFO | main | 流管理器初始化成功 +2026-01-30 09:21:24 | INFO | main | 图像预处理器初始化完成: 输入尺寸 480x480, Batch大小 1-8, FP16模式 True +2026-01-30 09:21:24 | INFO | main | 预处理器初始化成功 +2026-01-30 09:21:24 | INFO | main | TensorRT引擎初始化配置: 模型=./models/yolo11n.engine, 输入尺寸=480x480, Batch=1, FP16=True +2026-01-30 09:21:24 | INFO | main | 连接事件: load - TensorRT -> ./models/yolo11n.engine +2026-01-30 09:21:24 | INFO | main | TensorRT引擎加载成功: ./models/yolo11n.engine +2026-01-30 09:21:24 | INFO | main | 引擎已加载: default +2026-01-30 09:21:24 | INFO | main | 推理引擎加载成功: ./models/yolo11n.engine +2026-01-30 09:21:24 | INFO | main | 后处理器初始化完成: NMS阈值=0.45, 置信度阈值=0.5 +2026-01-30 09:21:24 | INFO | main | 后处理器初始化成功 +2026-01-30 09:21:24 | INFO | main | 结果上报器初始化成功 +2026-01-30 09:21:24 | INFO | main | 所有组件初始化完成 +2026-01-30 09:21:24 | INFO | main | 已启动 0 个视频流 +2026-01-30 09:21:24 | INFO | main | Edge_Inference_Service 已启动 +2026-01-30 09:21:28 | ERROR | main | MQTT连接失败: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +Traceback (most recent call last): + File "c:\Users\16337\PycharmProjects\ai_edge\core\result_reporter.py", line 220, in connect + self._client.connect( + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 914, in connect + return self.reconnect() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 1044, in reconnect + sock = self._create_socket_connection() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 3685, in _create_socket_connection + return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 857, in create_connection + raise err + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 845, in create_connection + sock.connect(sa) +ConnectionRefusedError: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +2026-01-30 09:21:28 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:21:58 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:22:28 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:22:58 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:23:28 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:23:58 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:24:28 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:24:58 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:25:28 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:25:58 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:26:28 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:26:58 | WARNING | main | 消息已加入待发送队列: edge/heartbeat/edge_inference_device +2026-01-30 09:27:08 | INFO | main | 收到信号 2, 正在停止服务... +2026-01-30 09:27:08 | INFO | main | 已停止所有视频流 +2026-01-30 09:27:08 | INFO | main | 已停止所有视频流 diff --git a/logs/main_error.log b/logs/main_error.log index 93a0354..9b55ad8 100644 --- a/logs/main_error.log +++ b/logs/main_error.log @@ -122,3 +122,40 @@ Traceback (most recent call last): File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 845, in create_connection sock.connect(sa) ConnectionRefusedError: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +2026-01-30 09:17:37 | ERROR | main | TensorRT引擎加载失败: negative dimensions are not allowed +Traceback (most recent call last): + File "c:\Users\16337\PycharmProjects\ai_edge\core\tensorrt_engine.py", line 102, in load_engine + self._allocate_memory_pool() + File "c:\Users\16337\PycharmProjects\ai_edge\core\tensorrt_engine.py", line 152, in _allocate_memory_pool + self._memory_pool["input"] = np.zeros(shape, dtype=dtype) +ValueError: negative dimensions are not allowed +2026-01-30 09:17:41 | ERROR | main | MQTT连接失败: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +Traceback (most recent call last): + File "c:\Users\16337\PycharmProjects\ai_edge\core\result_reporter.py", line 220, in connect + self._client.connect( + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 914, in connect + return self.reconnect() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 1044, in reconnect + sock = self._create_socket_connection() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 3685, in _create_socket_connection + return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 857, in create_connection + raise err + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 845, in create_connection + sock.connect(sa) +ConnectionRefusedError: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +2026-01-30 09:21:28 | ERROR | main | MQTT连接失败: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 +Traceback (most recent call last): + File "c:\Users\16337\PycharmProjects\ai_edge\core\result_reporter.py", line 220, in connect + self._client.connect( + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 914, in connect + return self.reconnect() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 1044, in reconnect + sock = self._create_socket_connection() + File "C:\Users\16337\miniconda3\envs\yolo\lib\site-packages\paho\mqtt\client.py", line 3685, in _create_socket_connection + return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 857, in create_connection + raise err + File "C:\Users\16337\miniconda3\envs\yolo\lib\socket.py", line 845, in create_connection + sock.connect(sa) +ConnectionRefusedError: [WinError 10061] 由于目标计算机积极拒绝,无法连接。 diff --git a/models/yolo11n.engine b/models/yolo11n.engine new file mode 100644 index 0000000..16c8e49 Binary files /dev/null and b/models/yolo11n.engine differ diff --git a/models/yolo11n.onnx b/models/yolo11n.onnx index 6ad649e..0a67bf1 100644 Binary files a/models/yolo11n.onnx and b/models/yolo11n.onnx differ