Files
security-ai-edge/config/database.py

317 lines
11 KiB
Python
Raw Normal View History

2026-01-29 18:33:12 +08:00
"""
数据库连接配置模块
提供MySQL数据库连接池管理和操作封装
"""
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Generator, List, Optional

from sqlalchemy import create_engine, Column, String, Boolean, Integer, Float, Text, JSON, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import QueuePool

from config.settings import get_settings, DatabaseConfig
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Declarative base shared by every ORM model defined below; its metadata is
# what create_tables()/drop_tables() operate on.
Base = declarative_base()
class CameraInfo(Base):
    """ORM model for the ``camera_info`` table (one row per unique camera_id)."""
    __tablename__ = "camera_info"
    # Surrogate auto-increment primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Business key used for lookups; unique and indexed.
    camera_id = Column(String(64), unique=True, nullable=False, index=True)
    camera_name = Column(String(128), nullable=True)
    rtsp_url = Column(String(512), nullable=False)
    # NOTE(review): the meaning of ``status`` vs ``enabled`` is not defined in
    # this module; only ``enabled`` is used as a query filter here.
    status = Column(Boolean, default=True)
    enabled = Column(Boolean, default=True)
    location = Column(String(256), nullable=True)
    # Free-form JSON payload; its schema is not defined here — TODO confirm.
    extra_params = Column(JSON, nullable=True)
    # No defaults: presumably filled in by whichever service manages camera
    # rows — TODO confirm.
    created_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, nullable=True)
class ROIConfigModel(Base):
    """ORM model for the ``roi_config`` table: one region of interest per row.

    Each ROI is attached to a camera (``camera_id``) and carries a single
    ``algorithm_type`` plus its alerting parameters.
    """
    __tablename__ = "roi_config"
    # Surrogate auto-increment primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    roi_id = Column(String(64), unique=True, nullable=False, index=True)
    # NOTE(review): no ForeignKey is declared; presumably this matches
    # camera_info.camera_id — confirm.
    camera_id = Column(String(64), nullable=False, index=True)
    roi_type = Column(String(32), nullable=False)  # 'polygon' or 'rectangle'
    # Polygon vertices or rectangle coordinates; exact JSON layout is not
    # defined in this module — TODO confirm.
    coordinates = Column(JSON, nullable=False)
    algorithm_type = Column(String(32), nullable=False)  # 'leave_post', 'intrusion', etc.
    # Only defaults are set here; units/semantics are interpreted elsewhere
    # (alert_cooldown is presumably in seconds) — TODO confirm.
    alert_threshold = Column(Integer, default=3)
    alert_cooldown = Column(Integer, default=300)
    enabled = Column(Boolean, default=True)
    extra_params = Column(JSON, nullable=True)
    created_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, nullable=True)
class AlertRecord(Base):
    """ORM model for the ``alert_records`` table, keyed by a unique ``alert_id``."""
    __tablename__ = "alert_records"
    # Surrogate auto-increment primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    alert_id = Column(String(64), unique=True, nullable=False, index=True)
    camera_id = Column(String(64), nullable=False, index=True)
    roi_id = Column(String(64), nullable=False, index=True)
    alert_type = Column(String(32), nullable=False)
    target_class = Column(String(64), nullable=True)
    confidence = Column(Float, nullable=True)
    # Bounding-box JSON; exact layout is not defined here — TODO confirm.
    bbox = Column(JSON, nullable=True)
    message = Column(Text, nullable=True)
    screenshot = Column(Text, nullable=True)  # Base64-encoded screenshot
    # New records start as "pending"; update_alert_status overwrites it.
    status = Column(String(32), default="pending")
    # Stamped client-side at INSERT time unless a value is supplied. The alert
    # writer in this module never sets this column, so without a default every
    # stored alert had a NULL creation time. Uses naive local time, the same
    # convention as the datetime.now() written to processed_at.
    created_at = Column(DateTime, nullable=True, default=datetime.now)
    # Set when the alert's status is updated.
    processed_at = Column(DateTime, nullable=True)
class DatabaseManager:
    """Singleton wrapper around a pooled SQLAlchemy engine for MySQL.

    Builds the engine and a session factory from ``DatabaseConfig`` and offers
    helpers for the camera, ROI and alert tables. When the connectivity check
    in ``_init_engine`` fails, the manager runs in a degraded "no database"
    mode: ``is_available`` is False, read helpers return ``[]`` and write
    helpers return ``False`` without touching the database.
    """

    # Shared singleton instance returned by every ``DatabaseManager(...)`` call.
    _instance = None
    # Engine / session factory; assigned on the instance by ``_init_engine``.
    _engine = None
    _session_factory = None
    # True only after the test connection in ``_init_engine`` succeeded.
    _available = False

    def __new__(cls, *args, **kwargs):
        # Accept (and ignore) constructor arguments so ``DatabaseManager(config)``
        # reaches ``__init__``; a bare ``__new__(cls)`` rejected any argument
        # with TypeError. Only the first call allocates an instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: Optional[DatabaseConfig] = None):
        """Build the engine the first time the singleton is constructed.

        Args:
            config: Connection settings. Defaults to ``get_settings().database``.
                Ignored once an engine already exists.
        """
        # __init__ runs on every DatabaseManager() call; skip once an engine
        # has been created (even if its connectivity check failed).
        if self._engine is not None:
            return
        if config is None:
            config = get_settings().database
        self._init_engine(config)

    def _init_engine(self, config: DatabaseConfig):
        """Create the engine and session factory, then probe one connection.

        Never raises: any failure is logged as a warning and leaves the
        manager marked unavailable.
        """
        from urllib.parse import quote

        try:
            # Percent-encode credentials so characters such as '@', ':' or '/'
            # in a password cannot corrupt the URL; SQLAlchemy unquotes them
            # when parsing. quote() (not quote_plus) encodes spaces as %20,
            # which unquoting turns back into spaces.
            username = quote(str(config.username), safe="")
            password = quote(str(config.password), safe="")
            connection_string = (
                f"mysql+pymysql://{username}:{password}"
                f"@{config.host}:{config.port}/{config.database}"
                f"?charset=utf8mb4"
            )
            self._engine = create_engine(
                connection_string,
                poolclass=QueuePool,
                pool_size=config.pool_size,
                pool_recycle=config.pool_recycle,
                echo=config.echo,
                pool_pre_ping=True,
                max_overflow=5,
            )
            self._session_factory = sessionmaker(bind=self._engine)
            # Open and release one connection so an unreachable host or bad
            # credentials are detected now rather than on first use.
            with self._engine.connect():
                pass
            self._available = True
            logger.info(
                "数据库引擎初始化成功: %s:%s/%s",
                config.host, config.port, config.database,
            )
        except Exception as e:
            self._available = False
            logger.warning("数据库连接失败,服务将在无数据库模式下运行: %s", e)

    @property
    def is_available(self) -> bool:
        """Whether the startup connectivity check succeeded."""
        return self._available

    @contextmanager
    def get_session(self) -> Generator[Optional[Session], None, None]:
        """Yield a session that is committed on success and rolled back on error.

        The session is always closed afterwards. When the database is
        unavailable a warning is logged and ``None`` is yielded instead, so
        callers must check the yielded value. Exceptions raised inside the
        ``with`` body are logged and re-raised.
        """
        if not self._available:
            logger.warning("数据库不可用,跳过数据库操作")
            yield None
            return
        session = self._session_factory()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error("数据库操作异常: %s", e)
            raise
        finally:
            session.close()

    def get_camera_info(self, camera_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return enabled cameras as dicts, optionally only the given ``camera_id``.

        Returns ``[]`` when the database is unavailable; query errors propagate.
        """
        if not self._available:
            logger.warning("数据库不可用,返回空摄像头列表")
            return []
        with self.get_session() as session:
            if session is None:
                return []
            query = session.query(CameraInfo)
            if camera_id:
                query = query.filter(CameraInfo.camera_id == camera_id)
            # ``==`` (not ``is``) is required to build a SQL expression.
            cameras = query.filter(CameraInfo.enabled == True).all()  # noqa: E712
            # Materialise plain dicts before the session commits and closes.
            return [
                {
                    "camera_id": camera.camera_id,
                    "camera_name": camera.camera_name,
                    "rtsp_url": camera.rtsp_url,
                    "status": camera.status,
                    "location": camera.location,
                    "extra_params": camera.extra_params,
                }
                for camera in cameras
            ]

    def get_roi_configs(self, camera_id: Optional[str] = None,
                        roi_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return enabled ROI configs as dicts, filtered by camera and/or ROI id.

        Returns ``[]`` when the database is unavailable; query errors propagate.
        """
        if not self._available:
            logger.warning("数据库不可用返回空ROI配置列表")
            return []
        with self.get_session() as session:
            if session is None:
                return []
            query = session.query(ROIConfigModel)
            if camera_id:
                query = query.filter(ROIConfigModel.camera_id == camera_id)
            if roi_id:
                query = query.filter(ROIConfigModel.roi_id == roi_id)
            rois = query.filter(ROIConfigModel.enabled == True).all()  # noqa: E712
            # Materialise plain dicts before the session commits and closes.
            return [
                {
                    "roi_id": roi.roi_id,
                    "camera_id": roi.camera_id,
                    "roi_type": roi.roi_type,
                    "coordinates": roi.coordinates,
                    "algorithm_type": roi.algorithm_type,
                    "alert_threshold": roi.alert_threshold,
                    "alert_cooldown": roi.alert_cooldown,
                    "extra_params": roi.extra_params,
                }
                for roi in rois
            ]

    def save_alert_record(self, alert_data: Dict[str, Any]) -> bool:
        """Insert one alert row built from ``alert_data``.

        Only the keys read below are used; ``status`` defaults to "pending".

        Returns:
            True once the row is committed; False if the database is
            unavailable or the insert/commit failed (the error is logged).
        """
        if not self._available:
            logger.warning("数据库不可用,跳过保存告警记录")
            return False
        try:
            with self.get_session() as session:
                if session is None:
                    return False
                session.add(AlertRecord(
                    alert_id=alert_data.get("alert_id"),
                    camera_id=alert_data.get("camera_id"),
                    roi_id=alert_data.get("roi_id"),
                    alert_type=alert_data.get("alert_type"),
                    target_class=alert_data.get("target_class"),
                    confidence=alert_data.get("confidence"),
                    bbox=alert_data.get("bbox"),
                    message=alert_data.get("message"),
                    screenshot=alert_data.get("screenshot"),
                    status=alert_data.get("status", "pending"),
                ))
                # Flush inside the block so constraint violations (e.g. a
                # missing or duplicate alert_id) surface before commit.
                session.flush()
        except Exception as e:
            logger.error("保存告警记录失败: %s", e)
            return False
        # Logged only after get_session() has committed successfully.
        logger.info("告警记录保存成功: %s", alert_data.get("alert_id"))
        return True

    def update_alert_status(self, alert_id: str, status: str) -> bool:
        """Set ``status`` and stamp ``processed_at`` (local time) on one alert.

        Returns:
            True if a row with ``alert_id`` was updated; False if no such row
            exists, the database is unavailable, or the update failed.
        """
        from datetime import datetime

        from sqlalchemy import update

        if not self._available:
            logger.warning("数据库不可用,跳过更新告警状态")
            return False
        try:
            with self.get_session() as session:
                if session is None:
                    return False
                stmt = (
                    update(AlertRecord)
                    .where(AlertRecord.alert_id == alert_id)
                    .values(status=status, processed_at=datetime.now())
                )
                # The MySQL dialect reports matched rows, so 0 means the id
                # does not exist.
                matched = session.execute(stmt).rowcount
        except Exception as e:
            logger.error("更新告警状态失败: %s", e)
            return False
        if not matched:
            logger.warning("告警记录不存在,未更新状态: %s", alert_id)
            return False
        logger.info("告警状态更新成功: %s -> %s", alert_id, status)
        return True

    def create_tables(self):
        """Create the tables registered on ``Base.metadata`` (skipped if unavailable)."""
        if not self._available:
            logger.warning("数据库不可用,跳过创建表")
            return
        Base.metadata.create_all(self._engine)
        logger.info("数据库表创建完成")

    def drop_tables(self):
        """Drop every table registered on ``Base.metadata``. Destructive.

        Skipped, with a warning, when the database is unavailable.
        """
        if not self._available:
            logger.warning("数据库不可用,跳过删除表")
            return
        Base.metadata.drop_all(self._engine)
        logger.info("数据库表删除完成")

    def close(self):
        """Dispose of the engine's connection pool, if an engine was created."""
        if self._engine:
            self._engine.dispose()
            logger.info("数据库连接已关闭")
def get_database_manager() -> DatabaseManager:
    """Return the shared DatabaseManager instance.

    DatabaseManager caches a single instance in ``__new__``, so every call
    returns the same object; its ``__init__`` returns early once an engine
    exists.
    """
    manager = DatabaseManager()
    return manager
def init_database():
    """Fetch the shared manager, create the ORM tables, and return the manager.

    Table creation is skipped (with a warning) when the database is unavailable.
    """
    manager = get_database_manager()
    manager.create_tables()
    return manager