Files
security-ai-edge/config/settings.py
16337 003c2885b9 重构: 模型类别配置动态化,适配 v2+ (5 类) 模型
- settings.py: 删除 80 类 COCO_CLASS_NAMES 硬编码,改为 MODEL_CLASS_NAMES
  + 新增 MODEL_NUM_CLASSES / MODEL_OUTPUT_CHANNELS 辅助常量
  + 新增 ALGO_INTENT_CLASSES 字典 + get_algo_target_classes() 辅助函数
  + COCO_CLASS_NAMES 保留名称向后兼容,指向 MODEL_CLASS_NAMES

- postprocessor.py: 4 处硬编码 84 → MODEL_OUTPUT_CHANNELS
  + 支持不同类别数模型切换不改代码

- algorithms.py: 4 处硬编码 target_classes 默认值 → get_algo_target_classes()
  + IllegalParkingAlgorithm / VehicleCongestionAlgorithm /
    NonMotorVehicleParkingAlgorithm / GarbageDetectionAlgorithm
  + 自动过滤当前模型不支持的类(truck/bus 等)

以后换模型只需修改 settings.py 一处 MODEL_CLASS_NAMES。
2026-04-21 14:46:53 +08:00

381 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局配置模块
定义数据库、Redis、MQTT、推理等各项配置参数
"""
import logging
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@dataclass
class DatabaseConfig:
    """Database connection settings (MySQL, cloud side)."""
    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    # repr=False keeps the credential out of repr()/log output; the field is
    # still a normal __init__ argument and still takes part in equality.
    password: str = field(default="", repr=False)
    database: str = "edge_inference"
    pool_size: int = 10
    pool_recycle: int = 3600  # presumably seconds — confirm against the DB client
    echo: bool = False
@dataclass
class SQLiteConfig:
    """SQLite settings for edge-side local storage."""
    db_path: str = "./data/security_events.db"
    image_dir: str = "./data/captures"  # presumably where captured images are stored — confirm
    retention_days: int = 7
    wal_mode: bool = True  # presumably toggles SQLite WAL journal mode — confirm in the storage layer
    batch_size: int = 100
    flush_interval: float = 5.0  # presumably seconds — confirm in the storage layer
@dataclass
class RedisConfig:
    """Redis settings (local Redis, used as the edge-side cache)."""
    host: str = "localhost"
    port: int = 6379
    db: int = 0
    # repr=False keeps the credential out of repr()/log output.
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 50
@dataclass
class CloudRedisConfig:
    """Cloud Redis settings (three-tier authority model - cloud tier)."""
    host: str = "localhost"
    port: int = 6379
    db: int = 1
    # repr=False keeps the credential out of repr()/log output.
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 20
@dataclass
class LocalRedisConfig:
    """Local Redis settings (three-tier authority model - edge cache tier)."""
    host: str = "localhost"
    port: int = 6379
    db: int = 1
    # repr=False keeps the credential out of repr()/log output.
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 20
@dataclass
class MQTTConfig:
    """MQTT settings (structure kept; no longer used for alarm reporting)."""
    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"
    device_id: str = "default"
    username: Optional[str] = None
    # repr=False keeps the credential out of repr()/log output.
    password: Optional[str] = field(default=None, repr=False)
    keepalive: int = 60
    qos: int = 1
    reconnect_delay: int = 5
    max_reconnect_attempts: int = 10
@dataclass
class COSConfig:
    """Tencent Cloud COS (object storage) settings."""
    secret_id: str = ""
    # repr=False keeps the secret key out of repr()/log output.
    secret_key: str = field(default="", repr=False)
    region: str = "ap-beijing"
    bucket: str = ""
@dataclass
class AlarmUploadConfig:
    """Alarm upload (reporting) settings."""
    cloud_api_url: str = "http://localhost:8000"
    wvp_api_url: str = ""  # WVP platform base URL (used for heartbeat sync)
    # repr=False keeps the auth token out of repr()/log output.
    edge_token: str = field(default="", repr=False)
    retry_max: int = 3
    retry_interval: int = 5  # presumably seconds — confirm in the uploader
@dataclass
class VideoStreamConfig:
    """Video stream settings."""
    default_fps: int = 5
    reconnect_max_attempts: int = 5
    # NOTE(review): units are not visible here; presumably seconds for the
    # reconnect delays and timeouts below — confirm against the stream reader.
    reconnect_base_delay: float = 1.0
    reconnect_max_delay: float = 60.0
    frame_buffer_size: int = 30
    connection_timeout: int = 10
    read_timeout: int = 30
@dataclass
class InferenceConfig:
    """Inference settings: model file, input size, thresholds and device."""

    model_path: str = "./models/yolo11n.engine"
    input_width: int = 480
    input_height: int = 480
    batch_size: int = 1
    conf_threshold: float = 0.45
    nms_threshold: float = 0.5
    device_id: int = 0
    fp16_mode: bool = True
    # Per-algorithm confidence overrides, keyed by algo_code.
    # Algorithms without an entry use conf_threshold instead.
    algo_conf_thresholds: Dict[str, float] = field(default_factory=dict)

    def get_conf_threshold(self, algo_code: str) -> float:
        """Return the confidence threshold to use for ``algo_code``.

        An entry in ``algo_conf_thresholds`` wins; otherwise the global
        ``conf_threshold`` is returned.
        """
        overrides = self.algo_conf_thresholds
        if algo_code in overrides:
            return overrides[algo_code]
        return self.conf_threshold
# ===================== Debug / Local Sync =====================
@dataclass
class DebugConfig:
    """Local debugging settings (debug server, reload signal file, local config file)."""
    enabled: bool = True
    host: str = "127.0.0.1"  # loopback address by default
    port: int = 9001
    reload_signal_file: str = "./config/reload.signal"
    local_config_path: str = "./config/local_config.json"
# Model class configuration. Switching to a different model only requires
# editing MODEL_CLASS_NAMES.
# Current: yolo11s_v2plus_20260421 (5-class custom model).
# Previous: official yolo11n (80 COCO classes); that list is in git history.
MODEL_CLASS_NAMES = ["garbage", "person", "car", "bicycle", "motorcycle"]
# Number of classes the model predicts.
MODEL_NUM_CLASSES = len(MODEL_CLASS_NAMES)
# YOLO output channel count = 4 box values (xywh) + one per class.
MODEL_OUTPUT_CHANNELS = MODEL_NUM_CLASSES + 4

# Kept for backward compatibility: the old name is an alias of (the same list
# object as) MODEL_CLASS_NAMES, not a copy.
COCO_CLASS_NAMES = MODEL_CLASS_NAMES

# Business-relevant classes per algorithm, regardless of whether the current
# model can detect them. get_algo_target_classes() narrows each list to the
# classes the deployed model actually supports.
ALGO_INTENT_CLASSES: Dict[str, List[str]] = {
    "leave_post": ["person"],
    "intrusion": ["person"],
    "illegal_parking": ["car", "truck", "bus"],
    "vehicle_congestion": ["car", "truck", "bus", "motorcycle"],
    "non_motor_vehicle_parking": ["bicycle", "motorcycle"],
    "garbage": ["garbage"],
}


def get_algo_target_classes(algo_code: str) -> list:
    """Return the classes an algorithm should track on the current model.

    Takes the algorithm's entry in ALGO_INTENT_CLASSES and keeps only the
    names present in MODEL_CLASS_NAMES, preserving their order. An unknown
    algo_code yields an empty list. A fresh list is built on every call.

    Args:
        algo_code: algorithm code (leave_post / intrusion / illegal_parking / ...)

    Returns:
        List of class names, each guaranteed to appear in MODEL_CLASS_NAMES.
    """
    supported = MODEL_CLASS_NAMES
    targets = []
    for name in ALGO_INTENT_CLASSES.get(algo_code, ()):
        if name in supported:
            targets.append(name)
    return targets
@dataclass
class CameraConfig:
    """Configuration of a single camera."""
    camera_id: str
    rtsp_url: str
    enabled: bool = True
    # ROI identifiers for this camera (presumably ROIConfig.roi_id values);
    # default_factory gives every instance its own list.
    roi_ids: List[str] = field(default_factory=list)
@dataclass
class ROIConfig:
    """Configuration of one region of interest (ROI)."""
    roi_id: str
    camera_id: str
    roi_type: str  # 'polygon' or 'rectangle'
    coordinates: List[List[float]]  # polygon vertices or rectangle coordinates
    algorithm_type: str  # 'leave_post', 'intrusion', etc.
    alert_threshold: int = 3  # trigger an alert after N consecutive frames
    alert_cooldown: int = 300  # alert cooldown time (seconds)
class Settings:
    """Global settings singleton.

    Every ``Settings()`` call returns the same instance; configuration is read
    from the environment (optionally seeded from a ``.env`` file) only on the
    first construction.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Create the shared instance once, then keep handing it back.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ still runs on every Settings() call even though __new__
        # returns the cached instance, so guard the one-time load.
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read an integer env var; unset or blank values yield ``default``.

        Treating a blank value (e.g. ``DB_PORT=`` in a ``.env`` file) as unset
        avoids ``int("")`` raising ValueError at startup. Non-blank values that
        are not valid integers still raise ValueError.
        """
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        return int(raw)

    @staticmethod
    def _env_float(name: str, default: float) -> float:
        """Read a float env var; unset or blank values yield ``default``.

        Non-blank values that are not valid floats still raise ValueError.
        """
        raw = os.getenv(name)
        if raw is None or not raw.strip():
            return default
        return float(raw)

    def _load_env_vars(self):
        """Populate every configuration section from environment variables.

        Relative SQLite and debug paths are resolved against the directory two
        levels above this module, i.e. the parent of the package containing it.
        """
        # python-dotenv is optional: without it only the real process
        # environment is consulted.
        try:
            from dotenv import load_dotenv
            load_dotenv()
        except ImportError:
            pass

        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

        def _abs_path(path: str) -> str:
            # Empty values pass through unchanged; relative values are anchored
            # at base_dir so they do not depend on the launch directory.
            if not path:
                return path
            return path if os.path.isabs(path) else os.path.normpath(os.path.join(base_dir, path))

        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", "localhost"),
            port=self._env_int("DB_PORT", 3306),
            username=os.getenv("DB_USERNAME", "root"),
            password=os.getenv("DB_PASSWORD", ""),
            database=os.getenv("DB_DATABASE", "edge_inference"),
        )
        self.sqlite = SQLiteConfig(
            db_path=_abs_path(os.getenv("SQLITE_DB_PATH", "./data/security_events.db")),
            image_dir=_abs_path(os.getenv("SQLITE_IMAGE_DIR", "./data/captures")),
            retention_days=self._env_int("SQLITE_RETENTION_DAYS", 7),
            wal_mode=os.getenv("SQLITE_WAL_MODE", "1") == "1",
        )
        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=self._env_int("REDIS_PORT", 6379),
            password=os.getenv("REDIS_PASSWORD"),
        )
        # NOTE(review): without overrides, cloud_redis and local_redis both
        # resolve to localhost:6379 db 1 — confirm that sharing is intended.
        self.cloud_redis = CloudRedisConfig(
            host=os.getenv("CLOUD_REDIS_HOST", "localhost"),
            port=self._env_int("CLOUD_REDIS_PORT", 6379),
            db=self._env_int("CLOUD_REDIS_DB", 1),
            password=os.getenv("CLOUD_REDIS_PASSWORD"),
        )
        self.local_redis = LocalRedisConfig(
            host=os.getenv("LOCAL_REDIS_HOST", "localhost"),
            port=self._env_int("LOCAL_REDIS_PORT", 6379),
            db=self._env_int("LOCAL_REDIS_DB", 1),
            password=os.getenv("LOCAL_REDIS_PASSWORD"),
        )
        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"),
            broker_port=self._env_int("MQTT_BROKER_PORT", 1883),
            client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"),
            device_id=os.getenv("EDGE_DEVICE_ID", "default"),
            username=os.getenv("MQTT_USERNAME"),
            password=os.getenv("MQTT_PASSWORD"),
        )
        self.cos = COSConfig(
            secret_id=os.getenv("COS_SECRET_ID", ""),
            secret_key=os.getenv("COS_SECRET_KEY", ""),
            region=os.getenv("COS_REGION", "ap-beijing"),
            bucket=os.getenv("COS_BUCKET", ""),
        )
        self.alarm_upload = AlarmUploadConfig(
            cloud_api_url=os.getenv("CLOUD_API_URL", "http://localhost:8000"),
            wvp_api_url=os.getenv("WVP_API_URL", ""),
            edge_token=os.getenv("EDGE_TOKEN", ""),
            retry_max=self._env_int("ALARM_RETRY_MAX", 3),
            retry_interval=self._env_int("ALARM_RETRY_INTERVAL", 5),
        )
        self.alarm_upload_enabled = os.getenv("ALARM_UPLOAD_ENABLED", "1") == "1"
        self.video_stream = VideoStreamConfig(
            default_fps=self._env_int("VIDEO_DEFAULT_FPS", 5),
            reconnect_max_attempts=self._env_int("VIDEO_RECONNECT_ATTEMPTS", 5),
        )
        # NOTE(review): the default MODEL_PATH still names the yolo11n engine,
        # while MODEL_CLASS_NAMES describes the 5-class v2+ model — confirm
        # deployments set MODEL_PATH to a matching engine.
        # BATCH_SIZE defaults to 4 here, not InferenceConfig's field default of 1.
        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"),
            input_width=self._env_int("INPUT_WIDTH", 480),
            input_height=self._env_int("INPUT_HEIGHT", 480),
            batch_size=self._env_int("BATCH_SIZE", 4),
            conf_threshold=self._env_float("CONF_THRESHOLD", 0.45),
            nms_threshold=self._env_float("NMS_THRESHOLD", 0.5),
            algo_conf_thresholds=self._parse_algo_conf_thresholds(),
        )
        # Normalised to upper case.
        self.config_sync_mode = os.getenv("CONFIG_SYNC_MODE", "LOCAL").upper()
        self.debug = DebugConfig(
            enabled=os.getenv("DEBUG_SERVER_ENABLED", "1") == "1",
            host=os.getenv("DEBUG_SERVER_HOST", "127.0.0.1"),
            port=self._env_int("DEBUG_SERVER_PORT", 9001),
            reload_signal_file=_abs_path(os.getenv("DEBUG_RELOAD_SIGNAL_FILE", "./config/reload.signal")),
            local_config_path=_abs_path(os.getenv("LOCAL_CONFIG_PATH", "./config/local_config.json")),
        )
        # NOTE(review): unlike the SQLite/debug paths, LOG_DIR (and MODEL_PATH
        # above) are not passed through _abs_path, so relative values resolve
        # against the process working directory.
        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        # 10485760 == 10 * 1024 * 1024 (presumably bytes).
        self.log_file_max_size = self._env_int("LOG_FILE_MAX_SIZE", 10485760)
        self.log_file_backup_count = self._env_int("LOG_FILE_BACKUP_COUNT", 5)
        self.working_hours = self._parse_working_hours()
        # Class names of the currently deployed model. COCO_CLASS_NAMES is only
        # a backward-compatible alias of this same list object.
        self.class_names = MODEL_CLASS_NAMES

    @staticmethod
    def _parse_algo_conf_thresholds() -> Dict[str, float]:
        """Collect per-algorithm confidence overrides from ``ALGO_CONF_*`` env vars.

        The suffix after ``ALGO_CONF_`` is lower-cased to form the algo_code::

            ALGO_CONF_LEAVE_POST=0.35  -> {"leave_post": 0.35}
            ALGO_CONF_INTRUSION=0.55   -> {"intrusion": 0.55}

        Empty values are skipped. Values that are not valid floats are skipped
        with a warning rather than aborting startup.
        """
        prefix = "ALGO_CONF_"
        result = {}
        for key, value in os.environ.items():
            if not key.startswith(prefix) or not value:
                continue
            algo_code = key[len(prefix):].lower()
            try:
                result[algo_code] = float(value)
            except ValueError:
                logging.getLogger(__name__).warning(
                    "Ignoring %s=%r: not a valid float", key, value
                )
        return result

    def _parse_working_hours(self) -> List[dict]:
        """Parse ``WORKING_HOURS`` into ``[{"start": [h, m], "end": [h, m]}, ...]``.

        Expected format: ``HH:MM-HH:MM`` periods separated by ``;``. Blank
        segments (e.g. a trailing ``;``) are skipped quietly; malformed periods
        are skipped with a warning. Hour/minute ranges are not validated.
        """
        working_hours_str = os.getenv("WORKING_HOURS", "")
        if not working_hours_str:
            return []
        working_hours = []
        for period in working_hours_str.split(";"):
            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except (ValueError, AttributeError):
                if period.strip():
                    logging.getLogger(__name__).warning(
                        "Ignoring malformed WORKING_HOURS period %r (expected HH:MM-HH:MM)",
                        period,
                    )
                continue
            working_hours.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })
        return working_hours

    @property
    def config_version(self) -> str:
        """Configuration version from ``CONFIG_VERSION`` (read on every access; default "1.0.0")."""
        return os.getenv("CONFIG_VERSION", "1.0.0")
def get_settings() -> Settings:
    """Return the process-wide Settings instance.

    Settings enforces the singleton itself, so this accessor only has to
    construct it: the first call loads the environment, later calls get the
    same already-initialised object back.
    """
    shared = Settings()
    return shared