Files
security-ai-edge/config/settings.py
16337 0bb9f7ebd2 feat(inference): 提高置信度阈值到0.5 - 减少误检
- conf_threshold: 0.4 → 0.5
- 提高检测精度,减少误报
- 同时更新环境变量默认值

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 11:13:45 +08:00

331 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局配置模块
定义数据库、Redis、MQTT、推理等各项配置参数
"""
import os
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class DatabaseConfig:
    """Database connection settings (MySQL, cloud side).

    ``password`` is excluded from the generated ``__repr__`` so that printing
    or logging the config object does not leak the credential. It still
    takes part in construction and equality comparison as before.
    """

    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    # repr=False keeps the credential out of the auto-generated __repr__.
    password: str = field(default="", repr=False)
    database: str = "edge_inference"
    # Connection-pool tuning; pool_recycle is presumably in seconds — confirm
    # against the database layer that consumes it.
    pool_size: int = 10
    pool_recycle: int = 3600
    # NOTE(review): presumably toggles SQL statement echo/logging — confirm.
    echo: bool = False
@dataclass
class SQLiteConfig:
    """Settings for the edge-side local SQLite store."""

    # --- storage locations (relative paths as given here) ---
    db_path: str = "./data/security_events.db"
    image_dir: str = "./data/captures"

    # --- housekeeping ---
    # Presumably the number of days records are kept — confirm with the store.
    retention_days: int = 7
    wal_mode: bool = True

    # --- write batching ---
    # NOTE(review): flush_interval is presumably in seconds — confirm.
    batch_size: int = 100
    flush_interval: float = 5.0
@dataclass
class RedisConfig:
    """Redis connection settings (local Redis used as the edge-side cache).

    ``password`` is excluded from the generated ``__repr__`` so the
    credential is not leaked when the config is printed or logged.
    """

    host: str = "localhost"
    port: int = 6379
    db: int = 0
    # repr=False keeps the credential out of the auto-generated __repr__.
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 50
@dataclass
class CloudRedisConfig:
    """Cloud Redis settings (three-tier authority model – cloud tier).

    ``password`` is excluded from the generated ``__repr__`` so the
    credential is not leaked when the config is printed or logged.
    """

    host: str = "localhost"
    port: int = 6379
    db: int = 0
    # repr=False keeps the credential out of the auto-generated __repr__.
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 20
@dataclass
class LocalRedisConfig:
    """Local Redis settings (three-tier authority model – edge-tier cache).

    ``password`` is excluded from the generated ``__repr__`` so the
    credential is not leaked when the config is printed or logged.
    """

    host: str = "localhost"
    port: int = 6379
    # Defaults to logical DB 1 (the cloud-tier config defaults to DB 0).
    db: int = 1
    # repr=False keeps the credential out of the auto-generated __repr__.
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 20
@dataclass
class MQTTConfig:
    """MQTT settings (structure retained; no longer used for alarm reporting).

    ``password`` is excluded from the generated ``__repr__`` so the
    credential is not leaked when the config is printed or logged.
    """

    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"
    device_id: str = "edge"
    username: Optional[str] = None
    # repr=False keeps the credential out of the auto-generated __repr__.
    password: Optional[str] = field(default=None, repr=False)
    keepalive: int = 60
    qos: int = 1
    # NOTE(review): reconnect_delay is presumably in seconds — confirm.
    reconnect_delay: int = 5
    max_reconnect_attempts: int = 10
@dataclass
class COSConfig:
    """Tencent Cloud COS (object storage) settings.

    ``secret_id`` and ``secret_key`` are excluded from the generated
    ``__repr__`` so the API credentials are not leaked when the config is
    printed or logged. Both still take part in construction and equality.
    """

    # repr=False keeps the API credentials out of the auto-generated __repr__.
    secret_id: str = field(default="", repr=False)
    secret_key: str = field(default="", repr=False)
    region: str = "ap-beijing"
    bucket: str = ""
@dataclass
class AlarmUploadConfig:
    """Settings for uploading alarms to the cloud API.

    ``edge_token`` is excluded from the generated ``__repr__`` so the token
    is not leaked when the config is printed or logged.
    """

    cloud_api_url: str = "http://localhost:8000"
    # repr=False keeps the token out of the auto-generated __repr__.
    edge_token: str = field(default="", repr=False)
    retry_max: int = 3
    # NOTE(review): retry_interval is presumably in seconds — confirm.
    retry_interval: int = 5
@dataclass
class VideoStreamConfig:
    """Video stream ingestion settings."""

    default_fps: int = 5

    # --- reconnection ---
    # NOTE(review): base/max delay suggest a capped backoff, presumably in
    # seconds — confirm against the stream reader.
    reconnect_max_attempts: int = 5
    reconnect_base_delay: float = 1.0
    reconnect_max_delay: float = 60.0

    # --- buffering and timeouts ---
    frame_buffer_size: int = 30
    connection_timeout: int = 10
    read_timeout: int = 30
@dataclass
class InferenceConfig:
    """Detector inference settings."""

    # --- model and input geometry ---
    # The ".engine" suffix suggests a TensorRT engine file — confirm.
    model_path: str = "./models/yolo11n.engine"
    input_width: int = 480
    input_height: int = 480
    batch_size: int = 1

    # --- post-processing thresholds ---
    # Confidence threshold was raised from 0.4 to 0.5 to reduce false detections.
    conf_threshold: float = 0.5
    nms_threshold: float = 0.45

    # --- execution ---
    device_id: int = 0
    fp16_mode: bool = True
# ===================== Debug / Local Sync =====================
@dataclass
class DebugConfig:
    """Settings for the local debug server and local config sync."""

    # --- debug server ---
    enabled: bool = True
    host: str = "127.0.0.1"
    port: int = 9001

    # --- local files ---
    reload_signal_file: str = "./config/reload.signal"
    local_config_path: str = "./config/local_config.json"
# Class names of the COCO dataset (used with the YOLO model), in the standard
# 80-class order. Presumably indexed by the detector's class id — confirm
# against the post-processing code. Rows are annotated with index ranges.
COCO_CLASS_NAMES = [
    # 0-9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    # 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    # 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    # 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    # 40-49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    # 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    # 60-69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 70-79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
@dataclass
class CameraConfig:
    """Configuration of a single camera."""

    # Required identity and stream source.
    camera_id: str
    rtsp_url: str
    # Optional flags; each instance gets its own fresh ROI id list.
    enabled: bool = True
    roi_ids: List[str] = field(default_factory=list)
@dataclass
class ROIConfig:
    """Region-of-interest (ROI) configuration."""

    # --- identity ---
    roi_id: str
    camera_id: str

    # --- geometry ---
    roi_type: str  # 'polygon' or 'rectangle'
    coordinates: List[List[float]]  # polygon vertices or rectangle coordinates

    # --- detection behaviour ---
    algorithm_type: str  # 'leave_post', 'intrusion', etc.
    alert_threshold: int = 3  # alarm fires after N consecutive frames
    alert_cooldown: int = 300  # alarm cooldown period (seconds)
class Settings:
    """Process-wide configuration singleton.

    Configuration is read from environment variables (after loading a
    ``.env`` file when ``python-dotenv`` is installed) the first time the
    class is instantiated. Every later ``Settings()`` call returns the same
    object without reloading.
    """

    _instance = None
    _initialized = False

    # Case-insensitive spellings that ``_env_bool`` treats as "true".
    _TRUTHY = frozenset({"1", "true", "yes", "on"})

    def __new__(cls):
        # Singleton: all Settings() calls share one instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every Settings() call, so only load once.
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    @classmethod
    def _env_bool(cls, name: str, default: str) -> bool:
        """Read a boolean flag from the environment variable ``name``.

        ``"1"``, ``"true"``, ``"yes"`` and ``"on"`` (case-insensitive, with
        surrounding whitespace ignored) mean True; anything else, including
        ``"0"`` and the empty string, means False. Previously only the exact
        string ``"1"`` was accepted, so e.g. ``SQLITE_WAL_MODE=true``
        silently disabled the feature.

        Args:
            name: Environment variable to read.
            default: Raw string used when the variable is unset.

        Returns:
            The parsed flag.
        """
        return os.getenv(name, default).strip().lower() in cls._TRUTHY

    def _load_env_vars(self):
        """Populate every configuration section from environment variables.

        Relative SQLite and debug-file paths are resolved against the project
        root (the parent of this file's directory) via ``_abs_path``.
        """
        # Load a .env file if python-dotenv is available; it is optional.
        try:
            from dotenv import load_dotenv
            load_dotenv()
        except ImportError:
            pass

        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

        def _abs_path(path: str) -> str:
            # Empty and absolute paths pass through; relative ones are anchored at base_dir.
            if not path:
                return path
            return path if os.path.isabs(path) else os.path.normpath(os.path.join(base_dir, path))

        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", "localhost"),
            port=int(os.getenv("DB_PORT", "3306")),
            username=os.getenv("DB_USERNAME", "root"),
            password=os.getenv("DB_PASSWORD", ""),
            database=os.getenv("DB_DATABASE", "edge_inference"),
        )
        self.sqlite = SQLiteConfig(
            db_path=_abs_path(os.getenv("SQLITE_DB_PATH", "./data/security_events.db")),
            image_dir=_abs_path(os.getenv("SQLITE_IMAGE_DIR", "./data/captures")),
            retention_days=int(os.getenv("SQLITE_RETENTION_DAYS", "7")),
            wal_mode=self._env_bool("SQLITE_WAL_MODE", "1"),
        )
        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD"),
        )
        self.cloud_redis = CloudRedisConfig(
            host=os.getenv("CLOUD_REDIS_HOST", "localhost"),
            port=int(os.getenv("CLOUD_REDIS_PORT", "6379")),
            db=int(os.getenv("CLOUD_REDIS_DB", "0")),
            password=os.getenv("CLOUD_REDIS_PASSWORD"),
        )
        self.local_redis = LocalRedisConfig(
            host=os.getenv("LOCAL_REDIS_HOST", "localhost"),
            port=int(os.getenv("LOCAL_REDIS_PORT", "6379")),
            db=int(os.getenv("LOCAL_REDIS_DB", "1")),
            password=os.getenv("LOCAL_REDIS_PASSWORD"),
        )
        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"),
            broker_port=int(os.getenv("MQTT_BROKER_PORT", "1883")),
            client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"),
            device_id=os.getenv("EDGE_DEVICE_ID", "edge"),
            username=os.getenv("MQTT_USERNAME"),
            password=os.getenv("MQTT_PASSWORD"),
        )
        self.cos = COSConfig(
            secret_id=os.getenv("COS_SECRET_ID", ""),
            secret_key=os.getenv("COS_SECRET_KEY", ""),
            region=os.getenv("COS_REGION", "ap-beijing"),
            bucket=os.getenv("COS_BUCKET", ""),
        )
        self.alarm_upload = AlarmUploadConfig(
            cloud_api_url=os.getenv("CLOUD_API_URL", "http://localhost:8000"),
            edge_token=os.getenv("EDGE_TOKEN", ""),
            retry_max=int(os.getenv("ALARM_RETRY_MAX", "3")),
            retry_interval=int(os.getenv("ALARM_RETRY_INTERVAL", "5")),
        )
        self.alarm_upload_enabled = self._env_bool("ALARM_UPLOAD_ENABLED", "1")
        self.video_stream = VideoStreamConfig(
            default_fps=int(os.getenv("VIDEO_DEFAULT_FPS", "5")),
            reconnect_max_attempts=int(os.getenv("VIDEO_RECONNECT_ATTEMPTS", "5")),
        )
        # NOTE(review): the BATCH_SIZE fallback ("4") differs from the
        # InferenceConfig dataclass default (1) — confirm which one matches
        # the deployed engine before changing either.
        # NOTE(review): MODEL_PATH (and LOG_DIR below) are not passed through
        # _abs_path, so relative values depend on the working directory.
        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"),
            input_width=int(os.getenv("INPUT_WIDTH", "480")),
            input_height=int(os.getenv("INPUT_HEIGHT", "480")),
            batch_size=int(os.getenv("BATCH_SIZE", "4")),
            conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.5")),
            nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.45")),
        )
        # Normalised to upper case so the value can be compared case-insensitively.
        self.config_sync_mode = os.getenv("CONFIG_SYNC_MODE", "LOCAL").upper()
        self.debug = DebugConfig(
            enabled=self._env_bool("DEBUG_SERVER_ENABLED", "1"),
            host=os.getenv("DEBUG_SERVER_HOST", "127.0.0.1"),
            port=int(os.getenv("DEBUG_SERVER_PORT", "9001")),
            reload_signal_file=_abs_path(
                os.getenv("DEBUG_RELOAD_SIGNAL_FILE", "./config/reload.signal")
            ),
            local_config_path=_abs_path(
                os.getenv("LOCAL_CONFIG_PATH", "./config/local_config.json")
            ),
        )
        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        # Default 10485760 = 10 * 1024 * 1024; presumably bytes per log file,
        # used with a rotating handler keeping LOG_FILE_BACKUP_COUNT backups — confirm.
        self.log_file_max_size = int(os.getenv("LOG_FILE_MAX_SIZE", "10485760"))
        self.log_file_backup_count = int(os.getenv("LOG_FILE_BACKUP_COUNT", "5"))
        self.working_hours = self._parse_working_hours()
        # Use the COCO class names.
        self.class_names = COCO_CLASS_NAMES

    def _parse_working_hours(self) -> List[dict]:
        """Parse the ``WORKING_HOURS`` environment variable.

        The expected format is ``HH:MM-HH:MM`` periods separated by ``;``,
        e.g. ``"08:00-12:00;13:30-17:30"``. Each well-formed period becomes
        ``{"start": [h, m], "end": [h, m]}``. Malformed periods (wrong number
        of ``-``/``:`` parts, non-integer fields, empty segments) are skipped.

        Returns:
            The parsed periods in input order, or an empty list when the
            variable is unset or empty.
        """
        working_hours_str = os.getenv("WORKING_HOURS", "")
        if not working_hours_str:
            return []
        working_hours = []
        for period in working_hours_str.split(";"):
            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except (ValueError, AttributeError):
                # Best-effort parsing: ignore this period and keep the rest.
                continue
            # NOTE(review): hour/minute ranges are not validated (e.g. "25:99" is accepted).
            working_hours.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })
        return working_hours

    @property
    def config_version(self) -> str:
        """Configuration version string from ``CONFIG_VERSION`` (default ``"1.0.0"``).

        Read from the environment on every access.
        """
        return os.getenv("CONFIG_VERSION", "1.0.0")
def get_settings() -> Settings:
    """Return the shared global :class:`Settings` instance.

    ``Settings`` is a singleton, so repeated calls hand back the same object;
    environment variables are only read on the first instantiation.
    """
    shared = Settings()
    return shared