Files
security-ai-edge/config/database.py
16337 b0ddb6ee1a feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from
  edge_inference_service/ to the repository root ai_edge/
- Updated model path in config/settings.py to reflect new structure
- Revised usage paths in __init__.py documentation
2026-01-29 18:43:19 +08:00

317 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据库连接配置模块
提供MySQL数据库连接池管理和操作封装
"""
import logging
from contextlib import contextmanager
from datetime import datetime
from typing import Any, Dict, Generator, List, Optional
from urllib.parse import quote

from sqlalchemy import (
    JSON,
    Boolean,
    Column,
    DateTime,
    Float,
    Integer,
    String,
    Text,
    create_engine,
    update,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import QueuePool

from config.settings import DatabaseConfig, get_settings
# Shared declarative base: the ORM models in this module subclass it, and the
# table create/drop helpers operate on ``Base.metadata``.
Base = declarative_base()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class CameraInfo(Base):
    """Camera information table model (ORM mapping of ``camera_info``).

    ``DatabaseManager.get_camera_info`` reads this table and returns only rows
    whose ``enabled`` flag is true. ``created_at``/``updated_at`` have no
    default here, so this model never fills them in by itself.
    """
    __tablename__ = "camera_info"
    # Auto-increment surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique, required, indexed camera identifier used for lookups.
    camera_id = Column(String(64), unique=True, nullable=False, index=True)
    # Optional human-readable name.
    camera_name = Column(String(128), nullable=True)
    # Required stream URL (column name suggests RTSP; not validated here).
    rtsp_url = Column(String(512), nullable=False)
    # Boolean status, defaults to True. NOTE(review): meaning (online flag?)
    # is not defined in this module - confirm with the writer of this table.
    status = Column(Boolean, default=True)
    # Rows with enabled == False are filtered out by get_camera_info.
    enabled = Column(Boolean, default=True)
    # Optional location description.
    location = Column(String(256), nullable=True)
    # Optional free-form JSON settings; schema not defined in this module.
    extra_params = Column(JSON, nullable=True)
    # Timestamps; nullable and without defaults in this model.
    created_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, nullable=True)
class ROIConfigModel(Base):
    """ROI (region of interest) configuration table model (ORM mapping of ``roi_config``).

    ``DatabaseManager.get_roi_configs`` reads this table and returns only rows
    whose ``enabled`` flag is true.
    """
    __tablename__ = "roi_config"
    # Auto-increment surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique, required, indexed ROI identifier.
    roi_id = Column(String(64), unique=True, nullable=False, index=True)
    # Owning camera's identifier (indexed; no foreign key is declared).
    camera_id = Column(String(64), nullable=False, index=True)
    roi_type = Column(String(32), nullable=False)  # 'polygon' or 'rectangle'
    coordinates = Column(JSON, nullable=False)  # polygon vertices or rectangle coordinates
    algorithm_type = Column(String(32), nullable=False)  # 'leave_post', 'intrusion', etc.
    # Defaults to 3. NOTE(review): unit/meaning (e.g. consecutive detections
    # before alerting?) is not defined in this module - confirm.
    alert_threshold = Column(Integer, default=3)
    # Defaults to 300. NOTE(review): presumably seconds between repeated
    # alerts - confirm.
    alert_cooldown = Column(Integer, default=300)
    # Rows with enabled == False are filtered out by get_roi_configs.
    enabled = Column(Boolean, default=True)
    # Optional free-form JSON settings; schema not defined in this module.
    extra_params = Column(JSON, nullable=True)
    # Timestamps; nullable and without defaults in this model.
    created_at = Column(DateTime, nullable=True)
    updated_at = Column(DateTime, nullable=True)
class AlertRecord(Base):
    """Alert record table model (ORM mapping of ``alert_records``).

    Rows are inserted by ``DatabaseManager.save_alert_record``. Their
    ``status``/``processed_at`` columns are changed by
    ``DatabaseManager.update_alert_status``.
    """
    __tablename__ = "alert_records"
    # Auto-increment surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Unique, required, indexed alert identifier.
    alert_id = Column(String(64), unique=True, nullable=False, index=True)
    # Source camera and ROI identifiers (indexed; no foreign keys declared).
    camera_id = Column(String(64), nullable=False, index=True)
    roi_id = Column(String(64), nullable=False, index=True)
    alert_type = Column(String(32), nullable=False)
    # Optional detection details.
    target_class = Column(String(64), nullable=True)
    confidence = Column(Float, nullable=True)
    bbox = Column(JSON, nullable=True)
    message = Column(Text, nullable=True)
    # Base64-encoded screenshot. Plain MySQL TEXT stores at most 65,535 bytes,
    # which a Base64-encoded camera image can exceed. The length hint
    # (2**24 - 1) makes MySQL create a larger TEXT variant
    # (MEDIUMTEXT/LONGTEXT) when the table is created through
    # Base.metadata.create_all. Existing tables are not altered.
    screenshot = Column(Text(length=16777215), nullable=True)
    # Processing state; "pending" unless the caller supplies another value.
    status = Column(String(32), default="pending")
    # Timestamps; processed_at is set by update_alert_status.
    created_at = Column(DateTime, nullable=True)
    processed_at = Column(DateTime, nullable=True)
class DatabaseManager:
    """Singleton manager for the MySQL engine, sessions and common queries.

    The first construction does three things:

    - builds a SQLAlchemy engine (``QueuePool``);
    - builds a session factory;
    - probes the server with one connection.

    If any of this fails, the manager runs in a degraded "no database" mode.
    In that mode ``is_available`` is False, every helper logs a warning and
    returns an empty/False result, and ``get_session`` yields ``None``. No
    exception is raised.
    """

    _instance = None
    _engine = None
    _session_factory = None
    _available = False

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use.

        ``*args``/``**kwargs`` are accepted so that ``DatabaseManager(config)``
        reaches ``__init__``. With a bare ``__new__(cls)`` signature, that call
        would raise ``TypeError`` before ``__init__`` ever ran.
        """
        if cls._instance is None:
            # object.__new__ must not be handed the constructor arguments.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, config: Optional[DatabaseConfig] = None):
        """Initialise the engine unless one already exists.

        Python calls ``__init__`` on every ``DatabaseManager()`` call, even
        though the instance is shared, so this returns early once an engine
        has been created.

        Args:
            config: Connection settings. Defaults to ``get_settings().database``.
                Ignored once an engine exists.
        """
        if self._engine is not None:
            return
        if config is None:
            config = get_settings().database
        self._init_engine(config)

    def _init_engine(self, config: DatabaseConfig):
        """Create the engine and session factory, then probe connectivity.

        Any failure is logged as a warning and leaves the manager unavailable.
        Nothing is raised. If the engine is created but the probe fails,
        ``self._engine`` stays set, so later constructions do not retry.
        """
        try:
            # Percent-encode credentials so characters such as '@', ':', '/'
            # or '#' in the password cannot corrupt the connection URL.
            user = quote(str(config.username), safe="")
            password = quote(str(config.password), safe="")
            connection_string = (
                f"mysql+pymysql://{user}:{password}"
                f"@{config.host}:{config.port}/{config.database}"
                f"?charset=utf8mb4"
            )
            self._engine = create_engine(
                connection_string,
                poolclass=QueuePool,
                pool_size=config.pool_size,
                pool_recycle=config.pool_recycle,
                echo=config.echo,
                pool_pre_ping=True,  # test pooled connections before handing them out
                max_overflow=5,
            )
            self._session_factory = sessionmaker(bind=self._engine)
            # Open and immediately release one connection to fail fast when the
            # server is unreachable or the credentials are wrong.
            with self._engine.connect():
                pass
            self._available = True
            logger.info(
                "数据库引擎初始化成功: %s:%s/%s",
                config.host, config.port, config.database,
            )
        except Exception as e:
            self._available = False
            logger.warning("数据库连接失败,服务将在无数据库模式下运行: %s", e)

    @property
    def is_available(self) -> bool:
        """True when the initial connectivity probe succeeded."""
        return self._available

    @contextmanager
    def get_session(self) -> Generator[Optional[Session], None, None]:
        """Yield a session that commits on success and rolls back on error.

        Yields ``None`` (after logging a warning) when the database is
        unavailable, so callers must check for ``None``. An exception raised
        inside the ``with`` body is logged, rolled back and re-raised. The
        session is always closed.
        """
        if not self._available:
            logger.warning("数据库不可用,跳过数据库操作")
            yield None
            return
        session = self._session_factory()
        try:
            yield session
            session.commit()
        except Exception as e:
            session.rollback()
            logger.error("数据库操作异常: %s", e)
            raise
        finally:
            session.close()

    def get_camera_info(self, camera_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return enabled cameras, optionally restricted to one camera.

        Args:
            camera_id: When truthy, only the camera with this id is returned.

        Returns:
            One dict per enabled camera with these keys:

            - ``camera_id``
            - ``camera_name``
            - ``rtsp_url``
            - ``status``
            - ``location``
            - ``extra_params``

            The list is empty when the database is unavailable.

        Raises:
            Any database error from the query; ``get_session`` rolls back and
            re-raises it.
        """
        if not self._available:
            logger.warning("数据库不可用,返回空摄像头列表")
            return []
        with self.get_session() as session:
            if session is None:
                return []
            query = session.query(CameraInfo)
            if camera_id:
                query = query.filter(CameraInfo.camera_id == camera_id)
            # `== True` builds a SQL expression; `is True` would not.
            query = query.filter(CameraInfo.enabled == True)  # noqa: E712
            # Build the dicts while the session is open so attribute access
            # never touches a closed session.
            return [
                {
                    "camera_id": camera.camera_id,
                    "camera_name": camera.camera_name,
                    "rtsp_url": camera.rtsp_url,
                    "status": camera.status,
                    "location": camera.location,
                    "extra_params": camera.extra_params,
                }
                for camera in query.all()
            ]

    def get_roi_configs(self, camera_id: Optional[str] = None,
                        roi_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return enabled ROI configurations, optionally filtered.

        Args:
            camera_id: When truthy, only ROIs of this camera are returned.
            roi_id: When truthy, only the ROI with this id is returned.

        Returns:
            One dict per enabled ROI with these keys:

            - ``roi_id``
            - ``camera_id``
            - ``roi_type``
            - ``coordinates``
            - ``algorithm_type``
            - ``alert_threshold``
            - ``alert_cooldown``
            - ``extra_params``

            The list is empty when the database is unavailable.

        Raises:
            Any database error from the query; ``get_session`` rolls back and
            re-raises it.
        """
        if not self._available:
            logger.warning("数据库不可用返回空ROI配置列表")
            return []
        with self.get_session() as session:
            if session is None:
                return []
            query = session.query(ROIConfigModel)
            if camera_id:
                query = query.filter(ROIConfigModel.camera_id == camera_id)
            if roi_id:
                query = query.filter(ROIConfigModel.roi_id == roi_id)
            query = query.filter(ROIConfigModel.enabled == True)  # noqa: E712
            return [
                {
                    "roi_id": roi.roi_id,
                    "camera_id": roi.camera_id,
                    "roi_type": roi.roi_type,
                    "coordinates": roi.coordinates,
                    "algorithm_type": roi.algorithm_type,
                    "alert_threshold": roi.alert_threshold,
                    "alert_cooldown": roi.alert_cooldown,
                    "extra_params": roi.extra_params,
                }
                for roi in query.all()
            ]

    def save_alert_record(self, alert_data: Dict[str, Any]) -> bool:
        """Insert one alert record.

        Args:
            alert_data: Mapping read with ``.get``. The keys used are:

                - ``alert_id``
                - ``camera_id``
                - ``roi_id``
                - ``alert_type``
                - ``target_class``
                - ``confidence``
                - ``bbox``
                - ``message``
                - ``screenshot``
                - ``status`` (defaults to ``"pending"``)

        Returns:
            True once the transaction has committed. False when the database
            is unavailable or any error occurred; errors are logged, not raised.
        """
        if not self._available:
            logger.warning("数据库不可用,跳过保存告警记录")
            return False
        try:
            with self.get_session() as session:
                if session is None:
                    return False
                session.add(AlertRecord(
                    alert_id=alert_data.get("alert_id"),
                    camera_id=alert_data.get("camera_id"),
                    roi_id=alert_data.get("roi_id"),
                    alert_type=alert_data.get("alert_type"),
                    target_class=alert_data.get("target_class"),
                    confidence=alert_data.get("confidence"),
                    bbox=alert_data.get("bbox"),
                    message=alert_data.get("message"),
                    screenshot=alert_data.get("screenshot"),
                    status=alert_data.get("status", "pending"),
                ))
                # Surface constraint violations (e.g. duplicate alert_id) here.
                session.flush()
        except Exception as e:
            logger.error("保存告警记录失败: %s", e)
            return False
        # Only reached after get_session() has committed.
        logger.info("告警记录保存成功: %s", alert_data.get("alert_id"))
        return True

    def update_alert_status(self, alert_id: str, status: str) -> bool:
        """Set ``status`` and stamp ``processed_at`` (local time) on one alert.

        Returns:
            True when a matching alert was updated and committed. False when
            the database is unavailable, no alert has this id, or an error
            occurred; errors are logged, not raised.
        """
        if not self._available:
            logger.warning("数据库不可用,跳过更新告警状态")
            return False
        try:
            with self.get_session() as session:
                if session is None:
                    return False
                stmt = (
                    update(AlertRecord)
                    .where(AlertRecord.alert_id == alert_id)
                    .values(status=status, processed_at=datetime.now())
                )
                matched = session.execute(stmt).rowcount
        except Exception as e:
            logger.error("更新告警状态失败: %s", e)
            return False
        # An UPDATE that matches nothing is not a successful status change.
        if not matched:
            logger.warning("未找到告警记录,状态未更新: %s", alert_id)
            return False
        logger.info("告警状态更新成功: %s -> %s", alert_id, status)
        return True

    def create_tables(self):
        """Create the tables registered on ``Base.metadata``.

        Skipped with a warning when the database is unavailable. SQLAlchemy's
        ``create_all`` checks for existing tables first by default.
        """
        if not self._available:
            logger.warning("数据库不可用,跳过创建表")
            return
        Base.metadata.create_all(self._engine)
        logger.info("数据库表创建完成")

    def drop_tables(self):
        """Drop the tables registered on ``Base.metadata``.

        Does nothing when the database is unavailable.
        """
        if not self._available:
            return
        Base.metadata.drop_all(self._engine)
        logger.info("数据库表删除完成")

    def close(self):
        """Dispose of the engine's connection pool, if an engine was created."""
        if self._engine is not None:
            self._engine.dispose()
            logger.info("数据库连接已关闭")
def get_database_manager() -> DatabaseManager:
    """Return the process-wide ``DatabaseManager``.

    ``DatabaseManager`` is a singleton, so every call returns the same object.
    Engine set-up, using the default settings, is only attempted while no
    engine exists yet.
    """
    manager = DatabaseManager()
    return manager
def init_database():
    """Create the ORM tables (when the database is reachable) and return the manager.

    Returns:
        The shared ``DatabaseManager``. If the database is unavailable,
        ``create_tables`` only logs a warning, and the manager is still
        returned in its degraded mode.
    """
    manager = get_database_manager()
    manager.create_tables()
    return manager