Files
security-ai-edge/config/settings.py
16337 ff3d6e2653 fix: 修复告警上报字段缺失和 MQTT 连接不稳定
- AlertInfo.to_dict() 补充 bind_id、device_id、algorithm 字段
- AlertInfo 新增 device_id 和 algorithm 属性
- MQTTConfig 新增 device_id 配置项(环境变量 EDGE_DEVICE_ID)
- main.py 创建 AlertInfo 时传入 device_id 和 algorithm
- 心跳上报使用配置的 device_id 代替硬编码字符串
- MQTT 协议从 MQTTv5 降级为 MQTTv311 提高兼容性
- MQTT client_id 添加随机后缀防止冲突

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 14:13:49 +08:00

225 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局配置模块
定义数据库、Redis、MQTT、推理等各项配置参数
"""
import os
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class DatabaseConfig:
    """Database connection settings (MySQL, cloud side).

    ``password`` is excluded from ``repr()`` so that logging or printing this
    object does not leak the credential; it still participates in equality.
    """

    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    # repr=False keeps the secret out of logs/tracebacks that format this object.
    password: str = field(default="", repr=False)
    database: str = "edge_inference"
    # Pool tuning fields below are not populated from the environment by Settings.
    pool_size: int = 10
    pool_recycle: int = 3600  # presumably seconds — confirm with the DB layer
    echo: bool = False
@dataclass
class SQLiteConfig:
    """Local SQLite storage settings used on the edge device."""

    # Location of the SQLite database file.
    db_path: str = "./data/security_events.db"
    # Directory for captured images.
    image_dir: str = "./data/captures"
    # Retention period in days.
    retention_days: int = 7
    # WAL journal-mode flag.
    wal_mode: bool = True
    # Batching knobs; Settings does not override these from the environment.
    batch_size: int = 100
    flush_interval: float = 5.0  # presumably seconds — confirm with the writer
@dataclass
class RedisConfig:
    """Redis connection settings.

    ``password`` is excluded from ``repr()`` so the credential is not leaked
    when the config is logged or printed; it still participates in equality.
    """

    host: str = "localhost"
    port: int = 6379
    db: int = 0
    # repr=False keeps the secret out of logs/tracebacks that format this object.
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 50
@dataclass
class MQTTConfig:
    """MQTT broker connection settings.

    ``password`` is excluded from ``repr()`` so the credential is not leaked
    when the config is logged or printed; it still participates in equality.
    """

    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"
    # Edge device identifier; Settings fills it from EDGE_DEVICE_ID.
    device_id: str = "edge-001"
    username: Optional[str] = None
    # repr=False keeps the secret out of logs/tracebacks that format this object.
    password: Optional[str] = field(default=None, repr=False)
    keepalive: int = 60  # presumably seconds — confirm with the MQTT client code
    qos: int = 1
    reconnect_delay: int = 5  # presumably seconds — confirm with the MQTT client code
    max_reconnect_attempts: int = 10
@dataclass
class VideoStreamConfig:
    """Video stream settings: frame rate, reconnect policy, buffering and timeouts."""

    default_fps: int = 5
    # Reconnect policy: attempt limit plus lower/upper bounds on the delay.
    reconnect_max_attempts: int = 5
    reconnect_base_delay: float = 1.0
    reconnect_max_delay: float = 60.0
    frame_buffer_size: int = 30
    # Timeout units are not shown here — presumably seconds; confirm with the reader.
    connection_timeout: int = 10
    read_timeout: int = 30
@dataclass
class InferenceConfig:
    """Model inference settings: engine path, input size and detection thresholds."""

    model_path: str = "./models/yolo11n.engine"
    # Network input resolution in pixels.
    input_width: int = 480
    input_height: int = 480
    # NOTE(review): this default is 1, but Settings falls back to "4" when
    # BATCH_SIZE is unset — confirm which value the engine expects.
    batch_size: int = 1
    # Detection score and NMS thresholds.
    conf_threshold: float = 0.5
    nms_threshold: float = 0.45
    # Presumably the GPU device index — confirm with the inference engine code.
    device_id: int = 0
    fp16_mode: bool = True
# COCO dataset class names (used by the YOLO model).
# NOTE(review): order matters — presumably list index == model class id;
# confirm against the inference post-processing before reordering.
# "hair drier" is spelled as in the COCO label set; do not "correct" it.
COCO_CLASS_NAMES = [
    # people and vehicles
    "person",
    "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat",
    # street objects
    "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
    # animals
    "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
    "zebra", "giraffe",
    # accessories
    "backpack", "umbrella", "handbag", "tie", "suitcase",
    # sports
    "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket",
    # kitchenware
    "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
    # food
    "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog",
    "pizza", "donut", "cake",
    # furniture
    "chair", "couch", "potted plant", "bed", "dining table", "toilet",
    # electronics
    "tv", "laptop", "mouse", "remote", "keyboard", "cell phone",
    # appliances
    "microwave", "oven", "toaster", "sink", "refrigerator",
    # household items
    "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
@dataclass
class CameraConfig:
    """Configuration for a single camera stream."""

    camera_id: str
    rtsp_url: str
    enabled: bool = True
    # IDs of the ROI regions attached to this camera; a fresh list per instance.
    roi_ids: List[str] = field(default_factory=list)
@dataclass
class ROIConfig:
    """Region-of-interest definition attached to a camera."""

    roi_id: str
    camera_id: str
    roi_type: str  # 'polygon' or 'rectangle'
    # Polygon vertices, or the rectangle's coordinates.
    coordinates: List[List[float]]
    algorithm_type: str  # e.g. 'leave_post', 'intrusion'
    # An alert fires after this many consecutive frames.
    alert_threshold: int = 3
    # Cooldown between alerts, in seconds.
    alert_cooldown: int = 300
class Settings:
    """Global settings singleton.

    ``Settings()`` always returns the same instance. Configuration is read from
    environment variables the first time that instance is initialised and is
    then reused. Numeric variables that fail to parse raise ``ValueError`` whose
    message names the offending variable.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Hand back the one shared instance on every construction.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python calls __init__ on every Settings() call, even when __new__
        # returned the existing instance, so only load the environment once.
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    @staticmethod
    def _env_int(name: str, default: str) -> int:
        """Return env var ``name`` (or ``default``) parsed as an int.

        Raises:
            ValueError: if the value is not an integer; the message names the
                variable so misconfiguration is easy to locate.
        """
        raw = os.getenv(name, default)
        try:
            return int(raw)
        except ValueError as err:
            raise ValueError(
                f"environment variable {name} must be an integer, got {raw!r}"
            ) from err

    @staticmethod
    def _env_float(name: str, default: str) -> float:
        """Return env var ``name`` (or ``default``) parsed as a float.

        Raises:
            ValueError: if the value is not a number; the message names the
                variable.
        """
        raw = os.getenv(name, default)
        try:
            return float(raw)
        except ValueError as err:
            raise ValueError(
                f"environment variable {name} must be a number, got {raw!r}"
            ) from err

    @staticmethod
    def _env_bool(name: str, default: str) -> bool:
        """Return env var ``name`` (or ``default``) as a boolean flag.

        "1", "true", "yes" and "on" (case-insensitive, surrounding whitespace
        ignored) are true; every other value, including "0" and "", is false.
        """
        value = os.getenv(name, default).strip().lower()
        return value in ("1", "true", "yes", "on")

    @staticmethod
    def _is_valid_time(hour: int, minute: int) -> bool:
        """Return True for 00:00-23:59, plus 24:00 as an end-of-day marker."""
        if hour == 24:
            return minute == 0
        return 0 <= hour <= 23 and 0 <= minute <= 59

    def _load_env_vars(self):
        """Populate every configuration section from environment variables."""
        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", "localhost"),
            port=self._env_int("DB_PORT", "3306"),
            username=os.getenv("DB_USERNAME", "root"),
            password=os.getenv("DB_PASSWORD", ""),
            database=os.getenv("DB_DATABASE", "edge_inference"),
        )
        self.sqlite = SQLiteConfig(
            db_path=os.getenv("SQLITE_DB_PATH", "./data/security_events.db"),
            image_dir=os.getenv("SQLITE_IMAGE_DIR", "./data/captures"),
            retention_days=self._env_int("SQLITE_RETENTION_DAYS", "7"),
            # Previously only the exact string "1" enabled WAL; now also
            # accepts true/yes/on.
            wal_mode=self._env_bool("SQLITE_WAL_MODE", "1"),
        )
        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=self._env_int("REDIS_PORT", "6379"),
            password=os.getenv("REDIS_PASSWORD"),
        )
        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"),
            broker_port=self._env_int("MQTT_BROKER_PORT", "1883"),
            client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"),
            device_id=os.getenv("EDGE_DEVICE_ID", "edge-001"),
            username=os.getenv("MQTT_USERNAME"),
            password=os.getenv("MQTT_PASSWORD"),
        )
        self.video_stream = VideoStreamConfig(
            default_fps=self._env_int("VIDEO_DEFAULT_FPS", "5"),
            reconnect_max_attempts=self._env_int("VIDEO_RECONNECT_ATTEMPTS", "5"),
        )
        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"),
            input_width=self._env_int("INPUT_WIDTH", "480"),
            input_height=self._env_int("INPUT_HEIGHT", "480"),
            # NOTE(review): InferenceConfig's own batch_size default is 1, but
            # the fallback here is 4 — confirm which the engine expects.
            batch_size=self._env_int("BATCH_SIZE", "4"),
            conf_threshold=self._env_float("CONF_THRESHOLD", "0.5"),
            nms_threshold=self._env_float("NMS_THRESHOLD", "0.45"),
        )
        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        self.log_file_max_size = self._env_int("LOG_FILE_MAX_SIZE", "10485760")
        self.log_file_backup_count = self._env_int("LOG_FILE_BACKUP_COUNT", "5")
        self.working_hours = self._parse_working_hours()
        # Copy the COCO class names so that mutating settings.class_names
        # cannot alter the module-level list shared by other importers.
        self.class_names = list(COCO_CLASS_NAMES)

    def _parse_working_hours(self) -> List[dict]:
        """Parse WORKING_HOURS into ``[{"start": [h, m], "end": [h, m]}, ...]``.

        The expected format is semicolon-separated ``HH:MM-HH:MM`` periods, e.g.
        ``"08:00-12:00;13:30-17:30"``. Malformed periods and periods with an
        out-of-range time are skipped. ``24:00`` is accepted so a period can
        run to the end of the day. Returns ``[]`` when the variable is unset
        or empty.
        """
        working_hours_str = os.getenv("WORKING_HOURS", "")
        if not working_hours_str:
            return []
        working_hours = []
        for period in working_hours_str.split(";"):
            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except (ValueError, AttributeError):
                # Wrong number of parts or non-numeric fields: skip the period.
                continue
            if not (self._is_valid_time(start_h, start_m)
                    and self._is_valid_time(end_h, end_m)):
                continue
            working_hours.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })
        return working_hours

    @property
    def config_version(self) -> str:
        """Configuration version from CONFIG_VERSION (default "1.0.0").

        Read from the environment on every access rather than cached.
        """
        return os.getenv("CONFIG_VERSION", "1.0.0")
def get_settings() -> Settings:
    """Return the shared Settings singleton, creating it on first use."""
    shared = Settings()
    return shared