"""
Database connection configuration module.

Provides MySQL connection-pool management and thin wrappers around the
queries this service needs (camera info, ROI configs, alert records).
"""
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Generator, List, Optional
from urllib.parse import quote

from sqlalchemy import (
    create_engine,
    update,
    Column,
    String,
    Boolean,
    Integer,
    Float,
    Text,
    JSON,
    DateTime,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool

from config.settings import get_settings, DatabaseConfig

logger = logging.getLogger(__name__)

Base = declarative_base()


class CameraInfo(Base):
    """ORM model for the ``camera_info`` table."""

    __tablename__ = "camera_info"

    id = Column(Integer, primary_key=True, autoincrement=True)
    camera_id = Column(String(64), unique=True, nullable=False, index=True)
    camera_name = Column(String(128), nullable=True)
    rtsp_url = Column(String(512), nullable=False)
    status = Column(Boolean, default=True)
    enabled = Column(Boolean, default=True)
    location = Column(String(256), nullable=True)
    extra_params = Column(JSON, nullable=True)
    created_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, nullable=True)


class ROIConfigModel(Base):
    """ORM model for the ``roi_config`` table."""

    __tablename__ = "roi_config"

    id = Column(Integer, primary_key=True, autoincrement=True)
    roi_id = Column(String(64), unique=True, nullable=False, index=True)
    camera_id = Column(String(64), nullable=False, index=True)
    roi_type = Column(String(32), nullable=False)  # 'polygon' or 'rectangle'
    coordinates = Column(JSON, nullable=False)  # polygon vertices or rectangle coordinates
    algorithm_type = Column(String(32), nullable=False)  # 'leave_post', 'intrusion', etc.
    alert_threshold = Column(Integer, default=3)
    alert_cooldown = Column(Integer, default=300)
    enabled = Column(Boolean, default=True)
    extra_params = Column(JSON, nullable=True)
    created_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, nullable=True)


class AlertRecord(Base):
    """ORM model for the ``alert_records`` table."""

    __tablename__ = "alert_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    alert_id = Column(String(64), unique=True, nullable=False, index=True)
    camera_id = Column(String(64), nullable=False, index=True)
    roi_id = Column(String(64), nullable=False, index=True)
    alert_type = Column(String(32), nullable=False)
    target_class = Column(String(64), nullable=True)
    confidence = Column(Float, nullable=True)
    bbox = Column(JSON, nullable=True)
    message = Column(Text, nullable=True)
    screenshot = Column(Text, nullable=True)  # Base64-encoded screenshot
    status = Column(String(32), default="pending")
    created_at = Column(DateTime, nullable=True)
    processed_at = Column(DateTime, nullable=True)


class DatabaseManager:
    """Process-wide manager for the MySQL engine and session factory.

    Singleton: every ``DatabaseManager(...)`` call returns the same instance.
    The engine is created by the first construction whose ``create_engine``
    call succeeds. Later constructions return immediately and ignore any
    ``config`` passed in.

    If the startup connectivity check fails, ``is_available`` stays False.
    The query helpers then log a warning and return an empty/False result
    instead of touching the database.
    """

    _instance = None
    _engine = None
    _session_factory = None
    _available = False

    def __new__(cls, *args: Any, **kwargs: Any) -> "DatabaseManager":
        # Constructor arguments (e.g. ``config``) are forwarded here by
        # Python before ``__init__``. Accept them so ``DatabaseManager(config)``
        # works, but do not pass them on to ``object.__new__``, which would
        # reject them.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: Optional[DatabaseConfig] = None):
        """Initialise the engine on first use.

        Args:
            config: Connection settings. When omitted, ``get_settings().database``
                is used. Ignored once an engine already exists.
        """
        if self._engine is not None:
            return
        if config is None:
            config = get_settings().database
        self._init_engine(config)

    @staticmethod
    def _build_connection_string(config: DatabaseConfig) -> str:
        """Return the SQLAlchemy MySQL URL for *config*.

        Username and password are percent-encoded (``safe=""``), so characters
        such as ``@``, ``:``, ``/``, ``#`` or ``%`` in credentials cannot break
        the URL structure. SQLAlchemy decodes them again when parsing.
        """
        username = quote(str(config.username), safe="")
        password = quote(str(config.password), safe="")
        return (
            f"mysql+pymysql://{username}:{password}"
            f"@{config.host}:{config.port}/{config.database}"
            f"?charset=utf8mb4"
        )

    def _init_engine(self, config: DatabaseConfig):
        """Create the pooled engine and session factory, then probe connectivity.

        Any failure is logged and leaves the manager in no-database mode
        (``_available = False``) instead of raising.
        """
        try:
            self._engine = create_engine(
                self._build_connection_string(config),
                poolclass=QueuePool,
                pool_size=config.pool_size,
                pool_recycle=config.pool_recycle,
                echo=config.echo,
                pool_pre_ping=True,
                max_overflow=5,
            )
            self._session_factory = sessionmaker(bind=self._engine)
            # Open and immediately release one connection to verify that the
            # server is reachable and the credentials are accepted.
            with self._engine.connect():
                pass
            self._available = True
            logger.info(f"数据库引擎初始化成功: {config.host}:{config.port}/{config.database}")
        except Exception as e:
            self._available = False
            logger.warning(f"数据库连接失败,服务将在无数据库模式下运行: {e}")

    @property
    def is_available(self) -> bool:
        """Whether the startup connectivity check succeeded."""
        return self._available

    @contextmanager
    def get_session(self) -> Generator[Optional[Session], None, None]:
        """Yield a session that commits on success and rolls back on error.

        When the database is unavailable, yields ``None`` instead, so callers
        must check for it. Exceptions raised inside the ``with`` body are
        logged, rolled back and re-raised. The session is always closed.
        """
        if not self._available:
            logger.warning("数据库不可用,跳过数据库操作")
            yield None
            return

        session = self._session_factory()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error(f"数据库操作异常: {e}")
            raise
        finally:
            session.close()

    def get_camera_info(self, camera_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return enabled cameras as plain dicts.

        Args:
            camera_id: If given, only the camera with this id is returned.

        Returns:
            A list of dicts (built while the session is open, so they stay
            usable afterwards). Empty when the database is unavailable.
        """
        if not self._available:
            logger.warning("数据库不可用,返回空摄像头列表")
            return []

        with self.get_session() as session:
            if session is None:
                return []

            query = session.query(CameraInfo)
            if camera_id:
                query = query.filter(CameraInfo.camera_id == camera_id)
            # ``== True`` builds a SQL expression; ``is True`` would not.
            cameras = query.filter(CameraInfo.enabled == True).all()  # noqa: E712

            return [
                {
                    "camera_id": camera.camera_id,
                    "camera_name": camera.camera_name,
                    "rtsp_url": camera.rtsp_url,
                    "status": camera.status,
                    "location": camera.location,
                    "extra_params": camera.extra_params,
                }
                for camera in cameras
            ]

    def get_roi_configs(self, camera_id: Optional[str] = None,
                        roi_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return enabled ROI configurations as plain dicts.

        Args:
            camera_id: If given, restrict to ROIs of this camera.
            roi_id: If given, restrict to this ROI.

        Returns:
            A list of dicts. Empty when the database is unavailable.
        """
        if not self._available:
            logger.warning("数据库不可用,返回空ROI配置列表")
            return []

        with self.get_session() as session:
            if session is None:
                return []

            query = session.query(ROIConfigModel)
            if camera_id:
                query = query.filter(ROIConfigModel.camera_id == camera_id)
            if roi_id:
                query = query.filter(ROIConfigModel.roi_id == roi_id)
            # ``== True`` builds a SQL expression; ``is True`` would not.
            rois = query.filter(ROIConfigModel.enabled == True).all()  # noqa: E712

            return [
                {
                    "roi_id": roi.roi_id,
                    "camera_id": roi.camera_id,
                    "roi_type": roi.roi_type,
                    "coordinates": roi.coordinates,
                    "algorithm_type": roi.algorithm_type,
                    "alert_threshold": roi.alert_threshold,
                    "alert_cooldown": roi.alert_cooldown,
                    "extra_params": roi.extra_params,
                }
                for roi in rois
            ]

    def save_alert_record(self, alert_data: Dict[str, Any]) -> bool:
        """Insert one alert record.

        Args:
            alert_data: Field values keyed by ``AlertRecord`` column name.
                ``status`` defaults to ``"pending"``. ``created_at`` defaults
                to the current local time, because the column has no default
                of its own.

        Returns:
            True on success; False if the database is unavailable or the
            insert failed (the error is logged).
        """
        if not self._available:
            logger.warning("数据库不可用,跳过保存告警记录")
            return False

        try:
            with self.get_session() as session:
                if session is None:
                    return False

                alert = AlertRecord(
                    alert_id=alert_data.get("alert_id"),
                    camera_id=alert_data.get("camera_id"),
                    roi_id=alert_data.get("roi_id"),
                    alert_type=alert_data.get("alert_type"),
                    target_class=alert_data.get("target_class"),
                    confidence=alert_data.get("confidence"),
                    bbox=alert_data.get("bbox"),
                    message=alert_data.get("message"),
                    screenshot=alert_data.get("screenshot"),
                    status=alert_data.get("status", "pending"),
                    created_at=alert_data.get("created_at") or datetime.now(),
                )
                session.add(alert)
                # Flush inside the session so constraint errors (e.g. a
                # duplicate alert_id) surface here and are reported as False.
                session.flush()

                logger.info(f"告警记录保存成功: {alert_data.get('alert_id')}")
                return True
        except Exception as e:
            logger.error(f"保存告警记录失败: {e}")
            return False

    def update_alert_status(self, alert_id: str, status: str) -> bool:
        """Set an alert's status and stamp ``processed_at`` with the local time.

        Returns:
            True if a matching record was updated. False if the database is
            unavailable, no record has this ``alert_id``, or the update
            failed (the error is logged).
        """
        if not self._available:
            logger.warning("数据库不可用,跳过更新告警状态")
            return False

        try:
            with self.get_session() as session:
                if session is None:
                    return False

                stmt = (
                    update(AlertRecord)
                    .where(AlertRecord.alert_id == alert_id)
                    .values(status=status, processed_at=datetime.now())
                )
                result = session.execute(stmt)
                # Do not report success for an alert_id that matched nothing.
                if result.rowcount == 0:
                    logger.warning(f"告警记录不存在,未更新状态: {alert_id}")
                    return False

                logger.info(f"告警状态更新成功: {alert_id} -> {status}")
                return True
        except Exception as e:
            logger.error(f"更新告警状态失败: {e}")
            return False

    def create_tables(self):
        """Create all tables declared on ``Base`` (skipped if the DB is unavailable)."""
        if not self._available:
            logger.warning("数据库不可用,跳过创建表")
            return
        Base.metadata.create_all(self._engine)
        logger.info("数据库表创建完成")

    def drop_tables(self):
        """Drop all tables declared on ``Base`` (silently skipped if the DB is unavailable)."""
        if not self._available:
            return
        Base.metadata.drop_all(self._engine)
        logger.info("数据库表删除完成")

    def close(self):
        """Dispose of the engine's connection pool, if an engine was created."""
        if self._engine:
            self._engine.dispose()
            logger.info("数据库连接已关闭")


def get_database_manager() -> DatabaseManager:
    """Return the process-wide ``DatabaseManager`` singleton."""
    return DatabaseManager()


def init_database() -> DatabaseManager:
    """Obtain the manager, create all tables, and return the manager."""
    db_manager = get_database_manager()
    db_manager.create_tables()
    return db_manager