"""
Global configuration module.

Defines the configuration parameters for the database, SQLite, Redis, MQTT,
video streams, inference and related settings, plus a process-wide
``Settings`` singleton that populates them from environment variables.
"""
import logging
import os
from dataclasses import dataclass, field
from typing import List, Optional

logger = logging.getLogger(__name__)

# Values accepted as "true" by _env_bool (compared case-insensitively, after strip).
_TRUTHY_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_int(name: str, default: int) -> int:
    """Return env var *name* parsed as ``int``, or *default* when it is unset.

    Raises:
        ValueError: if the variable is set but is not a valid integer; the
            message names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {raw!r}"
        ) from None


def _env_float(name: str, default: float) -> float:
    """Return env var *name* parsed as ``float``, or *default* when it is unset.

    Raises:
        ValueError: if the variable is set but is not a valid number; the
            message names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be a number, got {raw!r}"
        ) from None


def _env_bool(name: str, default: bool) -> bool:
    """Return env var *name* as a boolean, or *default* when it is unset.

    ``1``/``true``/``yes``/``on`` (any case) are true; any other value is false.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY_VALUES


def _is_valid_clock_time(hour: int, minute: int) -> bool:
    """Return True if ``hour:minute`` is a time of day in 00:00-23:59."""
    return 0 <= hour <= 23 and 0 <= minute <= 59


@dataclass
class DatabaseConfig:
    """Database configuration (MySQL, cloud side)."""
    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    password: str = ""
    database: str = "edge_inference"
    pool_size: int = 10
    pool_recycle: int = 3600
    echo: bool = False


@dataclass
class SQLiteConfig:
    """SQLite configuration (local storage on the edge device)."""
    db_path: str = "./data/security_events.db"
    image_dir: str = "./data/captures"
    retention_days: int = 7
    wal_mode: bool = True
    batch_size: int = 100
    flush_interval: float = 5.0


@dataclass
class RedisConfig:
    """Redis configuration."""
    host: str = "localhost"
    port: int = 6379
    db: int = 0
    password: Optional[str] = None
    decode_responses: bool = True
    max_connections: int = 50


@dataclass
class MQTTConfig:
    """MQTT configuration."""
    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"
    username: Optional[str] = None
    password: Optional[str] = None
    keepalive: int = 60
    qos: int = 1
    reconnect_delay: int = 5
    max_reconnect_attempts: int = 10


@dataclass
class VideoStreamConfig:
    """Video stream configuration."""
    default_fps: int = 5
    reconnect_max_attempts: int = 5
    reconnect_base_delay: float = 1.0
    reconnect_max_delay: float = 60.0
    frame_buffer_size: int = 30
    connection_timeout: int = 10
    read_timeout: int = 30


@dataclass
class InferenceConfig:
    """Inference configuration."""
    model_path: str = "./models/yolo11n.engine"
    input_width: int = 480
    input_height: int = 480
    batch_size: int = 1
    max_batch_size: int = 8
    conf_threshold: float = 0.5
    nms_threshold: float = 0.45
    device_id: int = 0
    fp16_mode: bool = True


@dataclass
class CameraConfig:
    """Configuration of a single camera."""
    camera_id: str
    rtsp_url: str
    enabled: bool = True
    roi_ids: List[str] = field(default_factory=list)


@dataclass
class ROIConfig:
    """Region-of-interest (ROI) configuration."""
    roi_id: str
    camera_id: str
    roi_type: str  # 'polygon' or 'rectangle'
    coordinates: List[List[float]]  # polygon vertices or rectangle coordinates
    algorithm_type: str  # 'leave_post', 'intrusion', etc.
    alert_threshold: int = 3  # raise an alert after N consecutive frames
    alert_cooldown: int = 300  # alert cooldown time (seconds)


class Settings:
    """Process-wide settings singleton populated from environment variables.

    The environment is read on the first instantiation only; later calls to
    ``Settings()`` return the same, already-populated instance.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every Settings() call; only load the environment once.
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    def _load_env_vars(self):
        """Populate every configuration section from environment variables.

        Unset variables fall back to the corresponding dataclass defaults, so
        each default is defined in exactly one place.

        Raises:
            ValueError: if a numeric variable is set to a non-numeric value.
        """
        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", DatabaseConfig.host),
            port=_env_int("DB_PORT", DatabaseConfig.port),
            username=os.getenv("DB_USERNAME", DatabaseConfig.username),
            password=os.getenv("DB_PASSWORD", DatabaseConfig.password),
            database=os.getenv("DB_DATABASE", DatabaseConfig.database),
        )

        self.sqlite = SQLiteConfig(
            db_path=os.getenv("SQLITE_DB_PATH", SQLiteConfig.db_path),
            image_dir=os.getenv("SQLITE_IMAGE_DIR", SQLiteConfig.image_dir),
            retention_days=_env_int(
                "SQLITE_RETENTION_DAYS", SQLiteConfig.retention_days
            ),
            wal_mode=_env_bool("SQLITE_WAL_MODE", SQLiteConfig.wal_mode),
        )

        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", RedisConfig.host),
            port=_env_int("REDIS_PORT", RedisConfig.port),
            password=os.getenv("REDIS_PASSWORD", RedisConfig.password),
        )

        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", MQTTConfig.broker_host),
            broker_port=_env_int("MQTT_BROKER_PORT", MQTTConfig.broker_port),
            client_id=os.getenv("MQTT_CLIENT_ID", MQTTConfig.client_id),
            username=os.getenv("MQTT_USERNAME", MQTTConfig.username),
            password=os.getenv("MQTT_PASSWORD", MQTTConfig.password),
        )

        self.video_stream = VideoStreamConfig(
            default_fps=_env_int("VIDEO_DEFAULT_FPS", VideoStreamConfig.default_fps),
            reconnect_max_attempts=_env_int(
                "VIDEO_RECONNECT_ATTEMPTS", VideoStreamConfig.reconnect_max_attempts
            ),
        )

        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", InferenceConfig.model_path),
            input_width=_env_int("INPUT_WIDTH", InferenceConfig.input_width),
            input_height=_env_int("INPUT_HEIGHT", InferenceConfig.input_height),
            batch_size=_env_int("BATCH_SIZE", InferenceConfig.batch_size),
            max_batch_size=_env_int("MAX_BATCH_SIZE", InferenceConfig.max_batch_size),
            conf_threshold=_env_float("CONF_THRESHOLD", InferenceConfig.conf_threshold),
            nms_threshold=_env_float("NMS_THRESHOLD", InferenceConfig.nms_threshold),
        )

        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        # Default 10485760 == 10 * 1024 * 1024 (presumably bytes -- confirm with the log setup).
        self.log_file_max_size = _env_int("LOG_FILE_MAX_SIZE", 10485760)
        self.log_file_backup_count = _env_int("LOG_FILE_BACKUP_COUNT", 5)

        self.working_hours = self._parse_working_hours()

    def _parse_working_hours(self) -> List[dict]:
        """Parse the ``WORKING_HOURS`` environment variable.

        The expected format is ``HH:MM-HH:MM`` periods separated by ``;``,
        e.g. ``"08:00-12:00;13:30-18:00"``. Each valid period becomes
        ``{"start": [hour, minute], "end": [hour, minute]}``.

        Hours must be 0-23 and minutes 0-59. ``24:00`` is additionally
        accepted as an end time. Empty segments (such as a trailing ``;``)
        are ignored. Malformed or out-of-range periods are skipped with a
        logged warning rather than aborting start-up.

        Returns:
            The parsed periods in input order; an empty list when the
            variable is unset or empty.
        """
        working_hours_str = os.getenv("WORKING_HOURS", "")
        if not working_hours_str:
            return []

        working_hours = []
        for raw_period in working_hours_str.split(";"):
            period = raw_period.strip()
            if not period:
                continue  # tolerate trailing / doubled separators

            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except ValueError:
                logger.warning("Ignoring malformed WORKING_HOURS period %r", period)
                continue

            start_ok = _is_valid_clock_time(start_h, start_m)
            # "24:00" is a common way to write "end of day" and was accepted before.
            end_ok = _is_valid_clock_time(end_h, end_m) or (end_h, end_m) == (24, 0)
            if not (start_ok and end_ok):
                logger.warning(
                    "Ignoring out-of-range WORKING_HOURS period %r", period
                )
                continue

            working_hours.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })

        return working_hours

    @property
    def config_version(self) -> str:
        """Configuration version string, read from ``CONFIG_VERSION`` on each access."""
        return os.getenv("CONFIG_VERSION", "1.0.0")


def get_settings() -> Settings:
    """Return the global ``Settings`` singleton."""
    return Settings()