"""Global configuration module.

Defines configuration dataclasses for the database, Redis, MQTT, COS, alarm
upload, video streams, inference and the local debug server, plus a
``Settings`` singleton that populates them from environment variables
(optionally loaded from a ``.env`` file when python-dotenv is installed).
"""

import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional

# Private name so modules doing ``from <this module> import *`` do not get
# their own ``logger`` overwritten.
_logger = logging.getLogger(__name__)


@dataclass
class DatabaseConfig:
    """Database configuration (MySQL, cloud side)."""

    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    password: str = ""
    database: str = "edge_inference"
    pool_size: int = 10
    pool_recycle: int = 3600
    echo: bool = False


@dataclass
class SQLiteConfig:
    """SQLite configuration (local storage on the edge device)."""

    db_path: str = "./data/security_events.db"
    image_dir: str = "./data/captures"
    retention_days: int = 7
    wal_mode: bool = True
    batch_size: int = 100
    flush_interval: float = 5.0


@dataclass
class RedisConfig:
    """Redis configuration (local Redis, edge-side cache)."""

    host: str = "localhost"
    port: int = 6379
    db: int = 0
    password: Optional[str] = None
    decode_responses: bool = True
    max_connections: int = 50


@dataclass
class CloudRedisConfig:
    """Cloud Redis configuration (three-tier authority model: cloud tier)."""

    host: str = "localhost"
    port: int = 6379
    db: int = 1
    password: Optional[str] = None
    decode_responses: bool = True
    max_connections: int = 20


@dataclass
class LocalRedisConfig:
    """Local Redis configuration (three-tier authority model: edge cache tier)."""

    host: str = "localhost"
    port: int = 6379
    db: int = 1
    password: Optional[str] = None
    decode_responses: bool = True
    max_connections: int = 20


@dataclass
class MQTTConfig:
    """MQTT configuration.

    The structure is kept, but it is no longer used for alarm reporting.
    """

    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"
    device_id: str = "default"
    username: Optional[str] = None
    password: Optional[str] = None
    keepalive: int = 60
    qos: int = 1
    reconnect_delay: int = 5
    max_reconnect_attempts: int = 10


@dataclass
class COSConfig:
    """Tencent Cloud COS (object storage) configuration."""

    secret_id: str = ""
    secret_key: str = ""
    region: str = "ap-beijing"
    bucket: str = ""


@dataclass
class AlarmUploadConfig:
    """Alarm upload configuration."""

    cloud_api_url: str = "http://localhost:8000"
    wvp_api_url: str = ""  # WVP platform URL (used for heartbeat sync)
    edge_token: str = ""
    retry_max: int = 3
    retry_interval: int = 5


@dataclass
class VideoStreamConfig:
    """Video stream configuration."""

    default_fps: int = 5
    reconnect_max_attempts: int = 5
    reconnect_base_delay: float = 1.0
    reconnect_max_delay: float = 60.0
    frame_buffer_size: int = 30
    connection_timeout: int = 10
    read_timeout: int = 30


@dataclass
class InferenceConfig:
    """Inference configuration."""

    model_path: str = "./models/yolo11n.engine"
    input_width: int = 480
    input_height: int = 480
    batch_size: int = 1
    conf_threshold: float = 0.45
    nms_threshold: float = 0.5
    device_id: int = 0
    fp16_mode: bool = True
    # Per-algorithm confidence-threshold overrides: key = algo_code,
    # value = threshold. Missing keys fall back to ``conf_threshold``.
    algo_conf_thresholds: Dict[str, float] = field(default_factory=dict)

    def get_conf_threshold(self, algo_code: str) -> float:
        """Return the confidence threshold for ``algo_code``.

        Falls back to the global ``conf_threshold`` when no override exists.
        """
        return self.algo_conf_thresholds.get(algo_code, self.conf_threshold)


# ===================== Debug / Local Sync =====================
@dataclass
class DebugConfig:
    """Local debugging configuration."""

    enabled: bool = True
    host: str = "127.0.0.1"
    port: int = 9001
    reload_signal_file: str = "./config/reload.signal"
    local_config_path: str = "./config/local_config.json"


# Model class configuration (edit here to switch models).
# Current: yolo11s_v2plus_20260421 (5-class custom model).
# Previous: official yolo11n (80 COCO classes); see git history for the list.
MODEL_CLASS_NAMES = ["garbage", "person", "car", "bicycle", "motorcycle"]
MODEL_NUM_CLASSES = len(MODEL_CLASS_NAMES)  # number of model classes
MODEL_OUTPUT_CHANNELS = 4 + MODEL_NUM_CLASSES  # YOLO output channels = 4 (xywh) + nc

# Backward compatibility: the COCO_CLASS_NAMES name is kept, but it points at
# the *current* model's classes (not the COCO list).
COCO_CLASS_NAMES = MODEL_CLASS_NAMES

# Classes each algorithm cares about, regardless of whether the current model
# supports them. get_algo_target_classes() filters these down to the subset
# the current model actually detects.
ALGO_INTENT_CLASSES = {
    "leave_post": ["person"],
    "intrusion": ["person"],
    "illegal_parking": ["car", "truck", "bus"],
    "vehicle_congestion": ["car", "truck", "bus", "motorcycle"],
    "non_motor_vehicle_parking": ["bicycle", "motorcycle"],
    "garbage": ["garbage"],
}


def get_algo_target_classes(algo_code: str) -> list:
    """Return the target classes for an algorithm.

    The list is restricted to classes supported by the current model.

    Args:
        algo_code: Algorithm code (leave_post / intrusion / illegal_parking / ...).

    Returns:
        Target class names in ALGO_INTENT_CLASSES order. Every entry is
        guaranteed to be in MODEL_CLASS_NAMES. Unknown algorithm codes yield
        an empty list.
    """
    intent = ALGO_INTENT_CLASSES.get(algo_code, [])
    return [c for c in intent if c in MODEL_CLASS_NAMES]


@dataclass
class CameraConfig:
    """Configuration for a single camera."""

    camera_id: str
    rtsp_url: str
    enabled: bool = True
    roi_ids: List[str] = field(default_factory=list)


@dataclass
class ROIConfig:
    """Region-of-interest configuration."""

    roi_id: str
    camera_id: str
    roi_type: str  # 'polygon' or 'rectangle'
    coordinates: List[List[float]]  # polygon vertices or rectangle coordinates
    algorithm_type: str  # 'leave_post', 'intrusion', etc.
    alert_threshold: int = 3  # trigger an alarm after N consecutive frames
    alert_cooldown: int = 300  # alarm cooldown, in seconds


class Settings:
    """Global settings singleton.

    Every ``Settings()`` call returns the same instance. Environment variables
    are read only on the first initialization.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    def _load_env_vars(self):
        """Populate all configuration sections from environment variables."""
        # Load a .env file when python-dotenv is available. Its absence is not
        # an error.
        try:
            from dotenv import load_dotenv
            load_dotenv()
        except ImportError:
            pass

        # Relative paths are resolved against the parent of the directory
        # containing this file, not against the process working directory.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

        def _abs_path(path: str) -> str:
            if not path:
                return path
            return path if os.path.isabs(path) else os.path.normpath(os.path.join(base_dir, path))

        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", "localhost"),
            port=int(os.getenv("DB_PORT", "3306")),
            username=os.getenv("DB_USERNAME", "root"),
            password=os.getenv("DB_PASSWORD", ""),
            database=os.getenv("DB_DATABASE", "edge_inference"),
        )

        self.sqlite = SQLiteConfig(
            db_path=_abs_path(os.getenv("SQLITE_DB_PATH", "./data/security_events.db")),
            image_dir=_abs_path(os.getenv("SQLITE_IMAGE_DIR", "./data/captures")),
            retention_days=int(os.getenv("SQLITE_RETENTION_DAYS", "7")),
            wal_mode=os.getenv("SQLITE_WAL_MODE", "1") == "1",
        )

        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD"),
        )

        self.cloud_redis = CloudRedisConfig(
            host=os.getenv("CLOUD_REDIS_HOST", "localhost"),
            port=int(os.getenv("CLOUD_REDIS_PORT", "6379")),
            db=int(os.getenv("CLOUD_REDIS_DB", "1")),
            password=os.getenv("CLOUD_REDIS_PASSWORD"),
        )

        self.local_redis = LocalRedisConfig(
            host=os.getenv("LOCAL_REDIS_HOST", "localhost"),
            port=int(os.getenv("LOCAL_REDIS_PORT", "6379")),
            db=int(os.getenv("LOCAL_REDIS_DB", "1")),
            password=os.getenv("LOCAL_REDIS_PASSWORD"),
        )

        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"),
            broker_port=int(os.getenv("MQTT_BROKER_PORT", "1883")),
            client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"),
            device_id=os.getenv("EDGE_DEVICE_ID", "default"),
            username=os.getenv("MQTT_USERNAME"),
            password=os.getenv("MQTT_PASSWORD"),
        )

        self.cos = COSConfig(
            secret_id=os.getenv("COS_SECRET_ID", ""),
            secret_key=os.getenv("COS_SECRET_KEY", ""),
            region=os.getenv("COS_REGION", "ap-beijing"),
            bucket=os.getenv("COS_BUCKET", ""),
        )

        self.alarm_upload = AlarmUploadConfig(
            cloud_api_url=os.getenv("CLOUD_API_URL", "http://localhost:8000"),
            wvp_api_url=os.getenv("WVP_API_URL", ""),
            edge_token=os.getenv("EDGE_TOKEN", ""),
            retry_max=int(os.getenv("ALARM_RETRY_MAX", "3")),
            retry_interval=int(os.getenv("ALARM_RETRY_INTERVAL", "5")),
        )

        self.alarm_upload_enabled = os.getenv("ALARM_UPLOAD_ENABLED", "1") == "1"

        self.video_stream = VideoStreamConfig(
            default_fps=int(os.getenv("VIDEO_DEFAULT_FPS", "5")),
            reconnect_max_attempts=int(os.getenv("VIDEO_RECONNECT_ATTEMPTS", "5")),
        )

        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"),
            input_width=int(os.getenv("INPUT_WIDTH", "480")),
            input_height=int(os.getenv("INPUT_HEIGHT", "480")),
            # NOTE(review): env default "4" differs from the
            # InferenceConfig.batch_size default of 1. Confirm which is intended.
            batch_size=int(os.getenv("BATCH_SIZE", "4")),
            conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.45")),
            nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.5")),
            algo_conf_thresholds=self._parse_algo_conf_thresholds(),
        )

        self.config_sync_mode = os.getenv("CONFIG_SYNC_MODE", "LOCAL").upper()

        self.debug = DebugConfig(
            enabled=os.getenv("DEBUG_SERVER_ENABLED", "1") == "1",
            host=os.getenv("DEBUG_SERVER_HOST", "127.0.0.1"),
            port=int(os.getenv("DEBUG_SERVER_PORT", "9001")),
            reload_signal_file=_abs_path(os.getenv("DEBUG_RELOAD_SIGNAL_FILE", "./config/reload.signal")),
            local_config_path=_abs_path(os.getenv("LOCAL_CONFIG_PATH", "./config/local_config.json")),
        )

        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        # Default 10485760 = 10 MiB (presumably a byte count).
        self.log_file_max_size = int(os.getenv("LOG_FILE_MAX_SIZE", "10485760"))
        self.log_file_backup_count = int(os.getenv("LOG_FILE_BACKUP_COUNT", "5"))

        self.working_hours = self._parse_working_hours()

        # Class names of the currently configured model (MODEL_CLASS_NAMES,
        # exposed under its legacy COCO_CLASS_NAMES alias).
        self.class_names = COCO_CLASS_NAMES

    @staticmethod
    def _parse_algo_conf_thresholds() -> Dict[str, float]:
        """Parse ``ALGO_CONF_*`` env vars into ``{algo_code: threshold}``.

        Naming rule: ``ALGO_CONF_{ALGO_CODE}``, for example::

            ALGO_CONF_LEAVE_POST=0.35
            ALGO_CONF_INTRUSION=0.55

        The suffix is lower-cased to form the algo_code. Empty values are
        ignored. Non-numeric values, and values outside [0, 1], are skipped
        with a warning.

        Returns:
            Mapping of algo_code to threshold. Empty if no valid overrides
            are set.
        """
        prefix = "ALGO_CONF_"
        result: Dict[str, float] = {}
        for key, value in os.environ.items():
            if not key.startswith(prefix) or not value:
                continue
            algo_code = key[len(prefix):].lower()
            try:
                threshold = float(value)
            except ValueError:
                _logger.warning("Ignoring %s=%r: not a number", key, value)
                continue
            # float() also accepts "nan"/"inf". A NaN threshold makes every
            # comparison false, so the range check doubles as a NaN/inf guard.
            if not 0.0 <= threshold <= 1.0:
                _logger.warning("Ignoring %s=%r: threshold must be within [0, 1]", key, value)
                continue
            result[algo_code] = threshold
        return result

    def _parse_working_hours(self) -> List[dict]:
        """Parse the ``WORKING_HOURS`` env var into a list of periods.

        Format: ``HH:MM-HH:MM`` periods separated by ``;``, e.g.
        ``"08:00-12:00;13:30-18:00"``. Whitespace around a period is ignored.
        Empty, malformed or out-of-range periods are skipped with a warning.
        ``24:00`` is still accepted as an end time (as before), so periods
        like ``"18:00-24:00"`` keep working.

        Returns:
            A list of ``{"start": [h, m], "end": [h, m]}`` dicts. Empty when
            the variable is unset or empty.
        """
        working_hours_str = os.getenv("WORKING_HOURS", "")
        if not working_hours_str:
            return []

        working_hours = []
        for period in working_hours_str.split(";"):
            period = period.strip()
            if not period:
                continue
            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except ValueError:
                _logger.warning("Ignoring malformed WORKING_HOURS period %r", period)
                continue

            start_ok = 0 <= start_h <= 23 and 0 <= start_m <= 59
            end_ok = (0 <= end_h <= 23 and 0 <= end_m <= 59) or (end_h, end_m) == (24, 0)
            if not (start_ok and end_ok):
                _logger.warning("Ignoring out-of-range WORKING_HOURS period %r", period)
                continue

            working_hours.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })
        return working_hours

    @property
    def config_version(self) -> str:
        """Return the configuration version string (env ``CONFIG_VERSION``)."""
        return os.getenv("CONFIG_VERSION", "1.0.0")


def get_settings() -> Settings:
    """Return the global ``Settings`` singleton."""
    return Settings()