# security-ai-edge/config/settings.py
"""
Global configuration module.

Defines the configuration parameters for the database, Redis, MQTT,
inference, and other subsystems.
"""
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@dataclass
class DatabaseConfig:
    """Database connection settings (MySQL, cloud side).

    ``password`` is excluded from ``repr()`` so printing or logging the config
    does not leak the credential; it still takes part in equality and in
    ``dataclasses.asdict``.
    """

    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    password: str = field(default="", repr=False)
    database: str = "edge_inference"
    pool_size: int = 10
    # Presumably seconds (SQLAlchemy-style pool_recycle) — confirm with the consumer.
    pool_recycle: int = 3600
    echo: bool = False
@dataclass
class SQLiteConfig:
    """Edge-side local storage backed by SQLite."""

    # Location of the database file and of the captured-image directory.
    db_path: str = "./data/security_events.db"
    image_dir: str = "./data/captures"
    # Retention period in days; enforcement lives in the consumer, not here.
    retention_days: int = 7
    # Toggle for SQLite write-ahead logging.
    wal_mode: bool = True
    # Batched-write tuning knobs, interpreted by the storage layer.
    batch_size: int = 100
    flush_interval: float = 5.0
@dataclass
class RedisConfig:
    """Local Redis settings (edge-side cache).

    ``password`` is excluded from ``repr()`` so it does not end up in logs;
    it still takes part in equality.
    """

    host: str = "localhost"
    port: int = 6379
    db: int = 0
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 50
@dataclass
class CloudRedisConfig:
    """Cloud Redis settings (cloud tier of the three-tier authority model).

    ``password`` is excluded from ``repr()`` so it does not end up in logs;
    it still takes part in equality.
    """

    # NOTE(review): these defaults (localhost:6379, db 1) are identical to
    # LocalRedisConfig's, so without env overrides the cloud and edge tiers
    # share one database — confirm that is intended.
    host: str = "localhost"
    port: int = 6379
    db: int = 1
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 20
@dataclass
class LocalRedisConfig:
    """Local Redis settings (edge-tier cache of the three-tier authority model).

    ``password`` is excluded from ``repr()`` so it does not end up in logs;
    it still takes part in equality.
    """

    # NOTE(review): these defaults (localhost:6379, db 1) are identical to
    # CloudRedisConfig's, so without env overrides both tiers share one
    # database — confirm that is intended.
    host: str = "localhost"
    port: int = 6379
    db: int = 1
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 20
@dataclass
class MQTTConfig:
    """MQTT settings.

    The structure is kept, but it is no longer used for alarm reporting.
    ``password`` is excluded from ``repr()`` so it does not end up in logs.
    """

    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"
    device_id: str = "default"
    username: Optional[str] = None
    password: Optional[str] = field(default=None, repr=False)
    keepalive: int = 60
    qos: int = 1
    reconnect_delay: int = 5
    max_reconnect_attempts: int = 10
@dataclass
class COSConfig:
    """Tencent Cloud COS (object storage) settings.

    ``secret_key`` is excluded from ``repr()`` so it does not end up in logs;
    it still takes part in equality.
    """

    secret_id: str = ""
    secret_key: str = field(default="", repr=False)
    region: str = "ap-beijing"
    bucket: str = ""
@dataclass
class AlarmUploadConfig:
    """Alarm upload settings.

    ``edge_token`` is excluded from ``repr()`` so it does not end up in logs;
    it still takes part in equality.
    """

    cloud_api_url: str = "http://localhost:8000"
    wvp_api_url: str = ""  # WVP platform base URL (used for heartbeat sync)
    edge_token: str = field(default="", repr=False)
    retry_max: int = 3
    # Unit not stated here; presumably seconds — confirm with the uploader.
    retry_interval: int = 5
@dataclass
class VideoStreamConfig:
    """Video stream capture and reconnection settings."""

    default_fps: int = 5
    # Reconnect policy: an attempt limit plus base/max delays (presumably an
    # exponential backoff in the stream reader — confirm there).
    reconnect_max_attempts: int = 5
    reconnect_base_delay: float = 1.0
    reconnect_max_delay: float = 60.0
    frame_buffer_size: int = 30
    # Timeouts; units are not stated here (presumably seconds).
    connection_timeout: int = 10
    read_timeout: int = 30
@dataclass
class InferenceConfig:
    """Model inference settings."""

    model_path: str = "./models/yolo11n.engine"
    input_width: int = 480
    input_height: int = 480
    # NOTE(review): Settings reads BATCH_SIZE with a default of "4", so the
    # effective default differs from this value when built through Settings.
    batch_size: int = 1
    conf_threshold: float = 0.45
    nms_threshold: float = 0.5
    device_id: int = 0
    fp16_mode: bool = True
    # Per-algorithm confidence overrides keyed by algo_code; codes missing from
    # this mapping fall back to ``conf_threshold``.
    algo_conf_thresholds: Dict[str, float] = field(default_factory=dict)

    def get_conf_threshold(self, algo_code: str) -> float:
        """Return the confidence threshold to apply for ``algo_code``.

        The per-algorithm override wins when present; otherwise the global
        ``conf_threshold`` is returned.
        """
        overrides = self.algo_conf_thresholds
        if algo_code in overrides:
            return overrides[algo_code]
        return self.conf_threshold
# ===================== Debug / Local Sync =====================
@dataclass
class DebugConfig:
    """Settings for the local debug server and local config synchronisation."""

    enabled: bool = True
    host: str = "127.0.0.1"
    port: int = 9001
    # Files used for reload signalling and for the local JSON configuration.
    reload_signal_file: str = "./config/reload.signal"
    local_config_path: str = "./config/local_config.json"
# Model class configuration: edit this list when switching to another model.
# Current model: yolo11s_v2plus_20260421 (custom 5-class model).
# Previous model: official yolo11n (80 COCO classes); that list is in git history.
MODEL_CLASS_NAMES = ["garbage", "person", "car", "bicycle", "motorcycle"]

# Number of classes the model predicts.
MODEL_NUM_CLASSES = len(MODEL_CLASS_NAMES)

# YOLO output channel count: 4 box values (xywh) followed by one score per class.
MODEL_OUTPUT_CHANNELS = MODEL_NUM_CLASSES + 4

# Backward compatibility: the old COCO_CLASS_NAMES name is kept as an alias of
# the current model's class list (the very same list object).
COCO_CLASS_NAMES = MODEL_CLASS_NAMES

# Business-relevant classes per algorithm. This is the full intent and may name
# classes the current model cannot detect; get_algo_target_classes() narrows it
# to the supported subset at deployment time.
ALGO_INTENT_CLASSES = {
    "leave_post": ["person"],
    "intrusion": ["person"],
    "illegal_parking": ["car", "truck", "bus"],
    "vehicle_congestion": ["car", "truck", "bus", "motorcycle"],
    "non_motor_vehicle_parking": ["bicycle", "motorcycle"],
    "garbage": ["garbage"],
}


def get_algo_target_classes(algo_code: str) -> list:
    """Return the classes ``algo_code`` tracks that the current model supports.

    Args:
        algo_code: Algorithm code (leave_post / intrusion / illegal_parking / ...).

    Returns:
        The algorithm's intent classes, in configured order, restricted to names
        present in MODEL_CLASS_NAMES. Unknown codes yield an empty list.
    """
    supported = set(MODEL_CLASS_NAMES)
    wanted = ALGO_INTENT_CLASSES.get(algo_code, [])
    return [name for name in wanted if name in supported]
@dataclass
class CameraConfig:
    """Configuration of a single camera."""

    camera_id: str
    rtsp_url: str
    enabled: bool = True
    # IDs of the ROIs attached to this camera; each instance gets its own list.
    roi_ids: List[str] = field(default_factory=list)
@dataclass
class ROIConfig:
    """Region-of-interest configuration bound to one camera and one algorithm."""

    roi_id: str
    camera_id: str
    roi_type: str  # "polygon" or "rectangle"
    coordinates: List[List[float]]  # polygon vertices or rectangle coordinates
    algorithm_type: str  # e.g. "leave_post", "intrusion"
    alert_threshold: int = 3  # raise an alert after N consecutive frames
    alert_cooldown: int = 300  # cooldown between alerts, in seconds
class Settings:
    """Process-wide configuration singleton.

    The first ``Settings()`` call loads every section from environment
    variables (after loading a ``.env`` file when python-dotenv is installed);
    later calls return the same, already-loaded instance.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Always hand back the one shared instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ still runs on every Settings() call, so guard the load.
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    @staticmethod
    def _env_flag(name: str, default: str) -> bool:
        """Read a boolean environment variable.

        "1", "true", "yes" and "on" (case-insensitive, surrounding whitespace
        ignored) are true; any other value is false. ``default`` is used when
        the variable is unset. Previously only the literal "1" counted as true,
        so e.g. ``ALARM_UPLOAD_ENABLED=true`` silently disabled uploads.
        """
        return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")

    def _load_env_vars(self):
        """Populate every configuration section from environment variables."""
        # Pick up a .env file when python-dotenv is installed; it is optional.
        try:
            from dotenv import load_dotenv
            load_dotenv()
        except ImportError:
            pass

        # Two directory levels above this file; relative paths passed through
        # _abs_path are anchored here rather than at the working directory.
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

        def _abs_path(path: str) -> str:
            # Empty values pass through unchanged; absolute paths are kept as-is.
            if not path:
                return path
            return path if os.path.isabs(path) else os.path.normpath(os.path.join(base_dir, path))

        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", "localhost"),
            port=int(os.getenv("DB_PORT", "3306")),
            username=os.getenv("DB_USERNAME", "root"),
            password=os.getenv("DB_PASSWORD", ""),
            database=os.getenv("DB_DATABASE", "edge_inference"),
        )
        self.sqlite = SQLiteConfig(
            db_path=_abs_path(os.getenv("SQLITE_DB_PATH", "./data/security_events.db")),
            image_dir=_abs_path(os.getenv("SQLITE_IMAGE_DIR", "./data/captures")),
            retention_days=int(os.getenv("SQLITE_RETENTION_DAYS", "7")),
            wal_mode=self._env_flag("SQLITE_WAL_MODE", "1"),
        )

        # Edge Redis now honours REDIS_DB like the cloud/local tiers do; the
        # default "0" matches RedisConfig.db, so unset behaviour is unchanged.
        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            db=int(os.getenv("REDIS_DB", "0")),
            password=os.getenv("REDIS_PASSWORD"),
        )
        self.cloud_redis = CloudRedisConfig(
            host=os.getenv("CLOUD_REDIS_HOST", "localhost"),
            port=int(os.getenv("CLOUD_REDIS_PORT", "6379")),
            db=int(os.getenv("CLOUD_REDIS_DB", "1")),
            password=os.getenv("CLOUD_REDIS_PASSWORD"),
        )
        self.local_redis = LocalRedisConfig(
            host=os.getenv("LOCAL_REDIS_HOST", "localhost"),
            port=int(os.getenv("LOCAL_REDIS_PORT", "6379")),
            db=int(os.getenv("LOCAL_REDIS_DB", "1")),
            password=os.getenv("LOCAL_REDIS_PASSWORD"),
        )

        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"),
            broker_port=int(os.getenv("MQTT_BROKER_PORT", "1883")),
            client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"),
            device_id=os.getenv("EDGE_DEVICE_ID", "default"),
            username=os.getenv("MQTT_USERNAME"),
            password=os.getenv("MQTT_PASSWORD"),
        )
        self.cos = COSConfig(
            secret_id=os.getenv("COS_SECRET_ID", ""),
            secret_key=os.getenv("COS_SECRET_KEY", ""),
            region=os.getenv("COS_REGION", "ap-beijing"),
            bucket=os.getenv("COS_BUCKET", ""),
        )
        self.alarm_upload = AlarmUploadConfig(
            cloud_api_url=os.getenv("CLOUD_API_URL", "http://localhost:8000"),
            wvp_api_url=os.getenv("WVP_API_URL", ""),
            edge_token=os.getenv("EDGE_TOKEN", ""),
            retry_max=int(os.getenv("ALARM_RETRY_MAX", "3")),
            retry_interval=int(os.getenv("ALARM_RETRY_INTERVAL", "5")),
        )
        self.alarm_upload_enabled = self._env_flag("ALARM_UPLOAD_ENABLED", "1")

        self.video_stream = VideoStreamConfig(
            default_fps=int(os.getenv("VIDEO_DEFAULT_FPS", "5")),
            reconnect_max_attempts=int(os.getenv("VIDEO_RECONNECT_ATTEMPTS", "5")),
        )
        # NOTE(review): MODEL_PATH is not passed through _abs_path, so a relative
        # value resolves against the working directory; BATCH_SIZE defaults to
        # "4" here while InferenceConfig.batch_size defaults to 1 — confirm both.
        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"),
            input_width=int(os.getenv("INPUT_WIDTH", "480")),
            input_height=int(os.getenv("INPUT_HEIGHT", "480")),
            batch_size=int(os.getenv("BATCH_SIZE", "4")),
            conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.45")),
            nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.5")),
            algo_conf_thresholds=self._parse_algo_conf_thresholds(),
        )
        self.config_sync_mode = os.getenv("CONFIG_SYNC_MODE", "LOCAL").upper()
        self.debug = DebugConfig(
            enabled=self._env_flag("DEBUG_SERVER_ENABLED", "1"),
            host=os.getenv("DEBUG_SERVER_HOST", "127.0.0.1"),
            port=int(os.getenv("DEBUG_SERVER_PORT", "9001")),
            reload_signal_file=_abs_path(os.getenv("DEBUG_RELOAD_SIGNAL_FILE", "./config/reload.signal")),
            local_config_path=_abs_path(os.getenv("LOCAL_CONFIG_PATH", "./config/local_config.json")),
        )

        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        # Default 10485760 = 10 MiB, assuming the consumer treats it as bytes.
        self.log_file_max_size = int(os.getenv("LOG_FILE_MAX_SIZE", "10485760"))
        self.log_file_backup_count = int(os.getenv("LOG_FILE_BACKUP_COUNT", "5"))
        self.working_hours = self._parse_working_hours()
        # Class names of the currently deployed model (COCO_CLASS_NAMES is the
        # backward-compatible alias of MODEL_CLASS_NAMES).
        self.class_names = COCO_CLASS_NAMES

    @staticmethod
    def _parse_algo_conf_thresholds() -> Dict[str, float]:
        """Collect per-algorithm confidence thresholds from ALGO_CONF_* env vars.

        Naming rule: ``ALGO_CONF_{ALGO_CODE}``; the suffix is lower-cased to
        form the algorithm code, e.g.::

            ALGO_CONF_LEAVE_POST=0.35   -> {"leave_post": 0.35}
            ALGO_CONF_INTRUSION=0.55    -> {"intrusion": 0.55}

        Empty values are ignored. Entries with no algorithm suffix, a
        non-numeric value, or a value outside [0, 1] (including NaN) are
        skipped with a logged warning instead of being silently dropped or,
        worse, silently accepted.

        Returns:
            Mapping of lower-case algorithm code to threshold.
        """
        log = logging.getLogger(__name__)
        prefix = "ALGO_CONF_"
        result: Dict[str, float] = {}
        for key, raw in os.environ.items():
            if not key.startswith(prefix) or not raw:
                continue
            algo_code = key[len(prefix):].lower()
            if not algo_code:
                log.warning("Ignoring %s: missing algorithm code suffix", key)
                continue
            try:
                threshold = float(raw)
            except ValueError:
                log.warning("Ignoring %s=%r: not a number", key, raw)
                continue
            # This comparison is also False for NaN, so NaN is rejected here.
            if not 0.0 <= threshold <= 1.0:
                log.warning("Ignoring %s=%r: threshold must be within [0, 1]", key, raw)
                continue
            result[algo_code] = threshold
        return result

    def _parse_working_hours(self) -> List[dict]:
        """Parse WORKING_HOURS into ``[{"start": [h, m], "end": [h, m]}, ...]``.

        Format: semicolon-separated ``HH:MM-HH:MM`` periods, e.g.
        ``08:00-12:00;13:30-18:00``. Blank segments (such as a trailing ``;``)
        are skipped quietly; malformed periods are skipped with a logged
        warning. Hour/minute ranges are not validated.
        """
        raw = os.getenv("WORKING_HOURS", "")
        if not raw:
            return []
        log = logging.getLogger(__name__)
        working_hours = []
        for period in raw.split(";"):
            period = period.strip()
            if not period:
                continue
            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except ValueError:
                log.warning("Ignoring malformed WORKING_HOURS period %r (expected HH:MM-HH:MM)", period)
                continue
            working_hours.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })
        return working_hours

    @property
    def config_version(self) -> str:
        """Configuration version string, re-read from CONFIG_VERSION on every access."""
        return os.getenv("CONFIG_VERSION", "1.0.0")
def get_settings() -> Settings:
    """Return the shared Settings instance, creating and loading it on first use."""
    settings = Settings()
    return settings