Files
security-ai-edge/config/settings.py
16337 745cadc8e7 feat: TensorRT 固定 batch=4 重构
- tensorrt_engine.py 工业级 Buffer Pool
- preprocessor.py 添加 pad_to_batch4()
- postprocessor.py 支持批量输出
- settings.py 固定 batch_size=4
2026-02-02 14:49:47 +08:00

205 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
全局配置模块
定义数据库、Redis、MQTT、推理等各项配置参数
"""
import os
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class DatabaseConfig:
    """Database connection settings (MySQL, cloud side).

    ``password`` is excluded from the generated ``repr()`` so that logging or
    printing this object does not leak the credential. It is still accepted
    by ``__init__`` and still takes part in equality comparison.
    """

    host: str = "localhost"
    port: int = 3306
    username: str = "root"
    password: str = field(default="", repr=False)
    database: str = "edge_inference"
    # Connection-pool tuning consumed outside this module.
    pool_size: int = 10
    pool_recycle: int = 3600  # presumably seconds — confirm against the pool implementation
    echo: bool = False
@dataclass
class SQLiteConfig:
    """Settings for edge-side local persistence in SQLite."""

    # Path of the SQLite database file.
    db_path: str = field(default="./data/security_events.db")
    # Directory intended for captured image files (per the default path).
    image_dir: str = field(default="./data/captures")
    # Number of days of data to retain; enforcement happens outside this class.
    retention_days: int = field(default=7)
    # Whether WAL journal mode should be enabled.
    wal_mode: bool = field(default=True)
    # Write-batching knobs; presumably rows per batch and seconds between
    # flushes — confirm against the storage writer.
    batch_size: int = field(default=100)
    flush_interval: float = field(default=5.0)
@dataclass
class RedisConfig:
    """Redis connection settings.

    ``password`` is excluded from the generated ``repr()`` so the credential
    is not exposed when the config object is logged or printed; it is still
    accepted by ``__init__`` and compared by ``__eq__``.
    """

    host: str = "localhost"
    port: int = 6379
    db: int = 0
    password: Optional[str] = field(default=None, repr=False)
    decode_responses: bool = True
    max_connections: int = 50
@dataclass
class MQTTConfig:
    """MQTT broker connection settings.

    ``password`` is excluded from the generated ``repr()`` so the credential
    is not exposed when the config object is logged or printed; it is still
    accepted by ``__init__`` and compared by ``__eq__``.
    """

    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"
    username: Optional[str] = None
    password: Optional[str] = field(default=None, repr=False)
    keepalive: int = 60
    qos: int = 1
    # Reconnect policy; the delay is presumably in seconds — confirm with the MQTT client code.
    reconnect_delay: int = 5
    max_reconnect_attempts: int = 10
@dataclass
class VideoStreamConfig:
    """Settings for consuming video streams."""

    # Default frame rate used when none is specified elsewhere.
    default_fps: int = field(default=5)
    # Reconnect policy. The base/max delay pair suggests a growing back-off
    # between attempts, presumably in seconds — confirm with the stream reader.
    reconnect_max_attempts: int = field(default=5)
    reconnect_base_delay: float = field(default=1.0)
    reconnect_max_delay: float = field(default=60.0)
    # Capacity of the frame buffer.
    frame_buffer_size: int = field(default=30)
    # Timeouts, presumably seconds — confirm with the stream reader.
    connection_timeout: int = field(default=10)
    read_timeout: int = field(default=30)
@dataclass
class InferenceConfig:
    """Settings for the model inference stage."""

    # Serialized model file; the default has a TensorRT-style ".engine" suffix.
    model_path: str = field(default="./models/yolo11n.engine")
    # Network input resolution in pixels.
    input_width: int = field(default=480)
    input_height: int = field(default=480)
    # Number of frames per inference call.
    batch_size: int = field(default=4)
    # Detection filtering thresholds.
    conf_threshold: float = field(default=0.5)
    nms_threshold: float = field(default=0.45)
    # Accelerator selection and precision.
    device_id: int = field(default=0)
    fp16_mode: bool = field(default=True)
@dataclass
class CameraConfig:
    """Configuration of a single camera."""

    # Required identity and stream location.
    camera_id: str
    rtsp_url: str
    # Optional switches; each instance gets its own fresh roi_ids list.
    enabled: bool = field(default=True)
    roi_ids: List[str] = field(default_factory=list)
@dataclass
class ROIConfig:
    """Configuration of a region of interest (ROI) on one camera."""

    roi_id: str
    camera_id: str
    # Shape kind, e.g. 'polygon' or 'rectangle'.
    roi_type: str
    # Polygon vertices or rectangle coordinates.
    coordinates: List[List[float]]
    # Detection algorithm, e.g. 'leave_post', 'intrusion', etc.
    algorithm_type: str
    # Number of consecutive frames required to trigger an alert.
    alert_threshold: int = field(default=3)
    # Alert cooldown period in seconds.
    alert_cooldown: int = field(default=300)
class Settings:
    """Process-wide settings singleton.

    Every configuration section is populated from environment variables the
    first time the singleton is constructed; later ``Settings()`` calls return
    the same, already-initialised object without re-reading the environment.
    """

    # The single shared instance, created lazily by __new__.
    _instance = None
    # Set to True on the instance after _load_env_vars has run, so repeated
    # Settings() calls (which re-run __init__) do not reload the environment.
    _initialized = False

    # Lower-cased strings that _env_bool treats as "enabled".
    _TRUE_VALUES = frozenset({"1", "true", "yes", "on"})

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python invokes __init__ on every Settings() call, even when __new__
        # hands back the cached instance — hence the guard.
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    @classmethod
    def _env_bool(cls, name: str, default: str) -> bool:
        """Read a boolean flag from the environment.

        Args:
            name: Environment variable name.
            default: Raw string used when the variable is unset.

        Returns:
            True if the stripped, lower-cased value is one of "1", "true",
            "yes" or "on"; False for anything else.
        """
        raw = os.getenv(name, default)
        return raw.strip().lower() in cls._TRUE_VALUES

    def _load_env_vars(self):
        """Populate every configuration section from environment variables.

        Unset variables fall back to the literal defaults below. Numeric
        values go through int()/float(), so a non-numeric value raises
        ValueError during construction.
        """
        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", "localhost"),
            port=int(os.getenv("DB_PORT", "3306")),
            username=os.getenv("DB_USERNAME", "root"),
            password=os.getenv("DB_PASSWORD", ""),
            database=os.getenv("DB_DATABASE", "edge_inference"),
        )

        self.sqlite = SQLiteConfig(
            db_path=os.getenv("SQLITE_DB_PATH", "./data/security_events.db"),
            image_dir=os.getenv("SQLITE_IMAGE_DIR", "./data/captures"),
            retention_days=int(os.getenv("SQLITE_RETENTION_DAYS", "7")),
            # Accept common truthy spellings so SQLITE_WAL_MODE=true does not
            # silently disable WAL ("1" still enables it, "0" disables it).
            wal_mode=self._env_bool("SQLITE_WAL_MODE", "1"),
        )

        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD"),
        )

        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"),
            broker_port=int(os.getenv("MQTT_BROKER_PORT", "1883")),
            client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"),
            username=os.getenv("MQTT_USERNAME"),
            password=os.getenv("MQTT_PASSWORD"),
        )

        self.video_stream = VideoStreamConfig(
            default_fps=int(os.getenv("VIDEO_DEFAULT_FPS", "5")),
            reconnect_max_attempts=int(os.getenv("VIDEO_RECONNECT_ATTEMPTS", "5")),
        )

        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"),
            input_width=int(os.getenv("INPUT_WIDTH", "480")),
            input_height=int(os.getenv("INPUT_HEIGHT", "480")),
            # NOTE(review): if the TensorRT engine is built for a fixed batch
            # of 4, overriding BATCH_SIZE may not match it — confirm.
            batch_size=int(os.getenv("BATCH_SIZE", "4")),
            conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.5")),
            nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.45")),
        )

        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        # Default 10485760 == 10 * 1024 * 1024; presumably bytes — confirm
        # against the logging setup that consumes it.
        self.log_file_max_size = int(os.getenv("LOG_FILE_MAX_SIZE", "10485760"))
        self.log_file_backup_count = int(os.getenv("LOG_FILE_BACKUP_COUNT", "5"))

        self.working_hours = self._parse_working_hours()

    def _parse_working_hours(self) -> List[dict]:
        """Parse the WORKING_HOURS environment variable.

        Expected format: ``HH:MM-HH:MM`` periods separated by ``;``, e.g.
        ``"08:00-12:00;13:30-18:00"``. Each accepted period becomes
        ``{"start": [h, m], "end": [h, m]}``.

        Empty segments (such as one left by a trailing ``;``) are ignored.
        Malformed periods, and periods with an hour outside 0-23 or a minute
        outside 0-59, are skipped with a logged warning rather than aborting
        start-up.

        Returns:
            Accepted periods in their original order; an empty list when
            WORKING_HOURS is unset or empty.
        """
        import logging

        working_hours_str = os.getenv("WORKING_HOURS", "")
        if not working_hours_str:
            return []

        logger = logging.getLogger(__name__)
        working_hours = []
        for raw_period in working_hours_str.split(";"):
            period = raw_period.strip()
            if not period:
                continue
            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except ValueError:
                logger.warning("Ignoring malformed WORKING_HOURS period: %r", period)
                continue

            hours_ok = all(0 <= h <= 23 for h in (start_h, end_h))
            minutes_ok = all(0 <= m <= 59 for m in (start_m, end_m))
            if not (hours_ok and minutes_ok):
                logger.warning("Ignoring out-of-range WORKING_HOURS period: %r", period)
                continue

            working_hours.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })
        return working_hours

    @property
    def config_version(self) -> str:
        """Configuration version string from CONFIG_VERSION (default "1.0.0")."""
        return os.getenv("CONFIG_VERSION", "1.0.0")
def get_settings() -> Settings:
    """Return the shared Settings object.

    Settings is a singleton, so every call yields the same instance and the
    environment is only read the first time it is constructed.
    """
    shared = Settings()
    return shared