- settings: batch_size=4→1 - tensorrt_engine: BATCH_SIZE=4→1 - preprocessor: 移除 padding 逻辑,直接 batch=1 - 预处理延迟从 17ms → 5ms
205 lines
6.0 KiB
Python
205 lines
6.0 KiB
Python
"""
|
||
全局配置模块
|
||
定义数据库、Redis、MQTT、推理等各项配置参数
|
||
"""
|
||
|
||
import os
|
||
from dataclasses import dataclass, field
|
||
from typing import List, Optional
|
||
|
||
|
||
@dataclass
class DatabaseConfig:
    """Connection parameters for the cloud-side MySQL database.

    Every field has a default, so ``DatabaseConfig()`` is a valid
    local-development configuration.
    """

    # --- Server endpoint ---
    host: str = "localhost"
    port: int = 3306

    # --- Credentials and target schema ---
    username: str = "root"
    password: str = ""
    database: str = "edge_inference"

    # --- Pooling / debugging knobs ---
    pool_size: int = 10
    # NOTE(review): presumably seconds before a pooled connection is
    # recycled -- confirm against the code that builds the engine.
    pool_recycle: int = 3600
    # NOTE(review): presumably toggles SQL statement echoing -- confirm.
    echo: bool = False
|
||
|
||
|
||
@dataclass
class SQLiteConfig:
    """SQLite settings for local storage on the edge device."""

    # --- On-disk locations ---
    db_path: str = "./data/security_events.db"
    image_dir: str = "./data/captures"

    # --- Retention ---
    retention_days: int = 7

    # --- Write behaviour ---
    # NOTE(review): presumably enables SQLite write-ahead logging -- confirm
    # against the storage layer that consumes this flag.
    wal_mode: bool = True
    # NOTE(review): batch_size / flush_interval look like write-buffering
    # limits (row count / seconds) -- confirm with the writer implementation.
    batch_size: int = 100
    flush_interval: float = 5.0
|
||
|
||
|
||
@dataclass
class RedisConfig:
    """Connection parameters for Redis."""

    # --- Server endpoint and logical database index ---
    host: str = "localhost"
    port: int = 6379
    db: int = 0

    # --- Authentication (None means no password is supplied) ---
    password: Optional[str] = None

    # --- Client behaviour ---
    decode_responses: bool = True
    max_connections: int = 50
|
||
|
||
|
||
@dataclass
class MQTTConfig:
    """Connection parameters for the MQTT broker."""

    # --- Broker endpoint and client identity ---
    broker_host: str = "localhost"
    broker_port: int = 1883
    client_id: str = "edge_inference_service"

    # --- Authentication (None means the value is not supplied) ---
    username: Optional[str] = None
    password: Optional[str] = None

    # --- Session parameters ---
    keepalive: int = 60
    qos: int = 1

    # --- Reconnect policy ---
    # NOTE(review): the delay is presumably in seconds -- confirm against
    # the MQTT client wrapper that reads it.
    reconnect_delay: int = 5
    max_reconnect_attempts: int = 10
|
||
|
||
|
||
@dataclass
class VideoStreamConfig:
    """Settings for video stream capture."""

    # --- Sampling ---
    default_fps: int = 5

    # --- Reconnect policy ---
    # NOTE(review): base/max delay suggest a growing back-off bounded by
    # reconnect_max_delay -- confirm against the stream reader.
    reconnect_max_attempts: int = 5
    reconnect_base_delay: float = 1.0
    reconnect_max_delay: float = 60.0

    # --- Buffering ---
    frame_buffer_size: int = 30

    # --- Timeouts (presumably seconds -- confirm) ---
    connection_timeout: int = 10
    read_timeout: int = 30
|
||
|
||
|
||
@dataclass
class InferenceConfig:
    """Settings for the inference stage."""

    # --- Model artefact ---
    model_path: str = "./models/yolo11n.engine"

    # --- Input tensor geometry ---
    input_width: int = 480
    input_height: int = 480
    batch_size: int = 1

    # --- Post-processing thresholds ---
    conf_threshold: float = 0.5
    nms_threshold: float = 0.45

    # --- Execution device ---
    # NOTE(review): device_id is presumably the GPU index and fp16_mode a
    # half-precision switch -- confirm against the engine wrapper.
    device_id: int = 0
    fp16_mode: bool = True
|
||
|
||
|
||
@dataclass
class CameraConfig:
    """Configuration of a single camera."""

    # --- Required identity / source ---
    camera_id: str
    rtsp_url: str

    # --- Optional switches ---
    enabled: bool = True
    # A fresh list is created per instance, so cameras never share one.
    roi_ids: List[str] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class ROIConfig:
    """Configuration of one region of interest (ROI)."""

    # --- Identity ---
    roi_id: str
    camera_id: str

    # --- Geometry ---
    roi_type: str  # 'polygon' or 'rectangle'
    coordinates: List[List[float]]  # polygon vertices or rectangle coordinates

    # --- Detection algorithm ---
    algorithm_type: str  # 'leave_post', 'intrusion', etc.

    # --- Alerting ---
    alert_threshold: int = 3  # alert fires after N consecutive frames
    alert_cooldown: int = 300  # alert cooldown time, in seconds
|
||
|
||
|
||
class Settings:
    """Global settings singleton populated from environment variables.

    ``Settings()`` always returns the same object. The environment is read
    once, the first time the object is initialised successfully; later
    ``Settings()`` calls return the already-populated instance untouched.

    Attributes set by ``_load_env_vars``:
        database, sqlite, redis, mqtt, video_stream, inference: the
            per-subsystem config objects.
        log_level, log_dir, log_file_max_size, log_file_backup_count:
            logging options.
        working_hours: list of ``{"start": [h, m], "end": [h, m]}`` dicts.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        """Create the shared instance on first use and return it afterwards."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Load the configuration unless this instance is already loaded."""
        # __init__ runs on every ``Settings()`` call, so guard the load.
        # The flag is only set after a successful load, which means a load
        # that raised is retried on the next ``Settings()`` call.
        if not self._initialized:
            self._load_env_vars()
            self._initialized = True

    def _load_env_vars(self):
        """Build every config object from environment variables.

        Raises:
            ValueError: if a numeric variable (for example ``DB_PORT``)
                holds text that ``int()`` / ``float()`` cannot parse.
        """
        self.database = DatabaseConfig(
            host=os.getenv("DB_HOST", "localhost"),
            port=int(os.getenv("DB_PORT", "3306")),
            username=os.getenv("DB_USERNAME", "root"),
            password=os.getenv("DB_PASSWORD", ""),
            database=os.getenv("DB_DATABASE", "edge_inference"),
        )

        self.sqlite = SQLiteConfig(
            db_path=os.getenv("SQLITE_DB_PATH", "./data/security_events.db"),
            image_dir=os.getenv("SQLITE_IMAGE_DIR", "./data/captures"),
            retention_days=int(os.getenv("SQLITE_RETENTION_DAYS", "7")),
            # Only the exact string "1" enables WAL mode.
            wal_mode=os.getenv("SQLITE_WAL_MODE", "1") == "1",
        )

        self.redis = RedisConfig(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            password=os.getenv("REDIS_PASSWORD"),
        )

        self.mqtt = MQTTConfig(
            broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"),
            broker_port=int(os.getenv("MQTT_BROKER_PORT", "1883")),
            client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"),
            username=os.getenv("MQTT_USERNAME"),
            password=os.getenv("MQTT_PASSWORD"),
        )

        self.video_stream = VideoStreamConfig(
            default_fps=int(os.getenv("VIDEO_DEFAULT_FPS", "5")),
            reconnect_max_attempts=int(os.getenv("VIDEO_RECONNECT_ATTEMPTS", "5")),
        )

        self.inference = InferenceConfig(
            model_path=os.getenv("MODEL_PATH", "./models/yolo11n.engine"),
            input_width=int(os.getenv("INPUT_WIDTH", "480")),
            input_height=int(os.getenv("INPUT_HEIGHT", "480")),
            # The fallback is 1 so that an unset BATCH_SIZE yields the same
            # value as the default declared on InferenceConfig.batch_size
            # (the fallback used to be "4", which disagreed with it).
            batch_size=int(os.getenv("BATCH_SIZE", "1")),
            conf_threshold=float(os.getenv("CONF_THRESHOLD", "0.5")),
            nms_threshold=float(os.getenv("NMS_THRESHOLD", "0.45")),
        )

        self.log_level = os.getenv("LOG_LEVEL", "INFO")
        self.log_dir = os.getenv("LOG_DIR", "./logs")
        self.log_file_max_size = int(os.getenv("LOG_FILE_MAX_SIZE", "10485760"))
        self.log_file_backup_count = int(os.getenv("LOG_FILE_BACKUP_COUNT", "5"))

        self.working_hours = self._parse_working_hours()

    def _parse_working_hours(self) -> List[dict]:
        """Parse the ``WORKING_HOURS`` environment variable.

        The expected format is ``HH:MM-HH:MM`` periods separated by ``;``,
        for example ``"08:00-12:00;13:30-17:45"``.

        Returns:
            One ``{"start": [hour, minute], "end": [hour, minute]}`` dict
            per well-formed period, in input order. An unset or empty
            variable yields ``[]``. Malformed periods are skipped.
        """
        raw = os.getenv("WORKING_HOURS", "")
        if not raw:
            return []

        parsed = []
        for period in raw.split(";"):
            try:
                start, end = period.split("-")
                start_h, start_m = map(int, start.split(":"))
                end_h, end_m = map(int, end.split(":"))
            except ValueError:
                # Wrong number of "-" / ":" parts or a non-numeric field:
                # best-effort parsing, so drop this period and keep going.
                continue
            parsed.append({
                "start": [start_h, start_m],
                "end": [end_h, end_m],
            })

        return parsed

    @property
    def config_version(self) -> str:
        """Configuration version string.

        Read from ``CONFIG_VERSION`` on every access (not cached), falling
        back to ``"1.0.0"`` when the variable is unset.
        """
        return os.getenv("CONFIG_VERSION", "1.0.0")
|
||
|
||
|
||
def get_settings() -> Settings:
    """Return the global ``Settings`` object.

    This simply calls ``Settings()``; any sharing or caching of the
    instance is whatever the ``Settings`` class itself implements.
    """
    settings = Settings()
    return settings
|