Compare commits

6 Commits

Author SHA1 Message Date
f5077a25a8 修复: 移除未实现的 debug_http_server 模块引用
debug_http_server.py 文件不存在导致启动报 ModuleNotFoundError,
该调试功能非必要,直接删除相关导入和方法调用。
2026-04-14 10:12:46 +08:00
9c73efe1eb 修复: 参数类型强制转换 + camelCase 防御性转换 2026-04-13 15:48:43 +08:00
bf5ddb3e7a 基础设施: 统一依赖版本 + 新增 Docker 部署配置
- requirements.txt: GPU 依赖从注释改为正式声明,统一版本
  PyTorch 2.1.2+cu121, TensorRT 8.6.1.6, ultralytics 8.3.5
  NumPy 1.24→1.26.4, OpenCV 4.8.0.74→76, 新增 onnx/Pillow 等
- Dockerfile: 基于 nvcr.io/nvidia/tensorrt:23.08-py3
  (CUDA 12.1 + cuDNN 8.9 + TRT 8.6)
- docker-compose.yml: GPU 访问、host 网络、卷挂载、日志限制
- .dockerignore: 排除模型/数据/日志等大文件

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-13 14:59:55 +08:00
56f39f1be7 修复: 全局参数线程安全 + copy保护 + 回调只清除受影响算法缓存 2026-04-13 10:21:19 +08:00
3266241064 适配: Edge 全局参数解析 + AlgorithmManager 三级参数合并 2026-04-09 17:04:11 +08:00
c6d8430867 新增: 非机动车违停检测算法(non_motor_vehicle_parking)+ 修复 illegal_parking 参数不一致 2026-04-09 10:34:55 +08:00
9 changed files with 658 additions and 65 deletions

29
.dockerignore Normal file
View File

@@ -0,0 +1,29 @@
.git
__pycache__
*.pyc
*.pyo
.idea
.vscode
.env
.env.*
!.env.example
# 模型和数据通过卷挂载
models/
data/
logs/
# 测试文件
tests/
test_*.py
pytest.ini
# 文档
docs/
*.md
!CLAUDE.md
# 临时文件
*.engine
*.onnx
*.pt

51
Dockerfile Normal file
View File

@@ -0,0 +1,51 @@
# ============================================================
# 基础镜像NVIDIA TensorRT 23.08
# 内含CUDA 12.1.1 | cuDNN 8.9.3 | TensorRT 8.6.1.6 | Python 3.10
# ============================================================
FROM nvcr.io/nvidia/tensorrt:23.08-py3
LABEL maintainer="AI Edge Architecture Team"
LABEL description="Edge AI Inference Service - YOLOv11n + TensorRT"
# 设置时区
ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# 系统依赖视频解码、OpenCV 运行时)
RUN apt-get update && apt-get install -y --no-install-recommends \
ffmpeg \
libsm6 \
libxext6 \
libgl1-mesa-glx \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# 先复制依赖文件,利用 Docker 层缓存
COPY requirements.txt .
# 安装 PyTorchCUDA 12.1 版本)+ 其余依赖
RUN pip install --no-cache-dir \
torch==2.1.2 torchvision==0.16.2 \
--index-url https://download.pytorch.org/whl/cu121 \
&& pip install --no-cache-dir -r requirements.txt
# 复制项目代码
COPY __init__.py .
COPY main.py .
COPY algorithms.py .
COPY build_engine.py .
COPY config/ ./config/
COPY core/ ./core/
COPY utils/ ./utils/
# 模型和数据通过卷挂载,不打入镜像
# -v /path/to/models:/app/models
# -v /path/to/data:/app/data
# 日志目录
RUN mkdir -p /app/logs /app/data
EXPOSE 9001
CMD ["python", "main.py"]

View File

@@ -1283,12 +1283,304 @@ class VehicleCongestionAlgorithm(BaseAlgorithm):
return state_info
class NonMotorVehicleParkingAlgorithm(BaseAlgorithm):
"""
非机动车违停检测算法(状态机版本 v1.0
状态机:
IDLE → CONFIRMING_VEHICLE → PARKED_COUNTDOWN → ALARMED → CONFIRMING_CLEAR → IDLE
业务流程:
1. 检测到非机动车进入禁停区 → 车辆确认期confirm_vehicle_sec默认10秒ratio>=0.6
2. 确认有车 → 违停倒计时parking_countdown_sec默认180秒/3分钟
3. 倒计时结束仍有车 → 触发告警ALARMED状态
4. 车辆离开 → 消失确认期confirm_clear_sec默认60秒ratio<0.2
5. 确认车辆离开 → 发送resolve事件 → 回到空闲状态
使用滑动窗口10秒抗抖动检测自行车和摩托车。
"""
# 状态定义
STATE_IDLE = "IDLE"
STATE_CONFIRMING_VEHICLE = "CONFIRMING_VEHICLE"
STATE_PARKED_COUNTDOWN = "PARKED_COUNTDOWN"
STATE_ALARMED = "ALARMED"
STATE_CONFIRMING_CLEAR = "CONFIRMING_CLEAR"
# 告警级别常量(默认值,可通过 params 覆盖)
DEFAULT_ALARM_LEVEL = 2 # 普通
# 滑动窗口参数
WINDOW_SIZE_SEC = 10
# 阈值常量(与 IllegalParkingAlgorithm 一致)
RATIO_CONFIRMING_DROP = 0.3
RATIO_CONFIRM_VEHICLE = 0.6
RATIO_PARKED_LEAVE = 0.2
RATIO_ALARMED_CLEAR = 0.15
RATIO_CLEAR_RETURN = 0.5
RATIO_CLEAR_CONFIRM = 0.2
def __init__(
self,
confirm_vehicle_sec: int = 10,
parking_countdown_sec: int = 180,
confirm_clear_sec: int = 60,
cooldown_sec: int = 900,
target_classes: Optional[List[str]] = None,
alarm_level: Optional[int] = None,
):
super().__init__()
self.confirm_vehicle_sec = confirm_vehicle_sec
self.parking_countdown_sec = parking_countdown_sec
self.confirm_clear_sec = confirm_clear_sec
self.cooldown_sec = cooldown_sec
self.target_classes = target_classes or ["bicycle", "motorcycle"]
self._alarm_level = alarm_level if alarm_level is not None else self.DEFAULT_ALARM_LEVEL
# 状态变量
self.state: str = self.STATE_IDLE
self.state_start_time: Optional[datetime] = None
# 滑动窗口:存储 (timestamp, has_vehicle: bool)
self._detection_window: deque = deque(maxlen=1000)
# 告警追踪
self._parking_start_time: Optional[datetime] = None
# 冷却期管理
self.alert_cooldowns: Dict[str, datetime] = {}
def _check_target_classes(self, detection: Dict) -> bool:
"""检查检测目标是否属于非机动车类别"""
det_class = detection.get("class", "")
return det_class in self.target_classes
def _update_window(self, current_time: datetime, has_vehicle: bool):
"""更新滑动窗口"""
self._detection_window.append((current_time, has_vehicle))
cutoff = current_time - timedelta(seconds=self.WINDOW_SIZE_SEC)
while self._detection_window and self._detection_window[0][0] < cutoff:
self._detection_window.popleft()
def _get_window_ratio(self) -> float:
"""获取滑动窗口内的检测命中率"""
if not self._detection_window:
return 0.0
hits = sum(1 for _, has in self._detection_window if has)
return hits / len(self._detection_window)
def _scan_tracks(self, tracks: List[Dict], roi_id: str) -> Tuple[bool, int, List[float], float]:
"""
一次遍历 tracks返回 (has_target, count, latest_bbox, max_confidence)。
过滤 target_classes。
"""
has_target = False
count = 0
latest_bbox: List[float] = []
max_confidence = 0.0
for det in tracks:
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
has_target = True
count += 1
if not latest_bbox:
latest_bbox = det.get("bbox", [])
conf = det.get("confidence", 0.0)
if conf > max_confidence:
max_confidence = conf
return has_target, count, latest_bbox, max_confidence
def _get_latest_bbox(self, tracks: List[Dict], roi_id: str) -> List[float]:
for det in tracks:
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
return det.get("bbox", [])
return []
def _get_max_confidence(self, tracks: List[Dict], roi_id: str) -> float:
"""获取ROI内非机动车的最高置信度"""
max_conf = 0.0
for det in tracks:
if self._check_detection_in_roi(det, roi_id) and self._check_target_classes(det):
max_conf = max(max_conf, det.get("confidence", 0.0))
return max_conf
def process(
self,
roi_id: str,
camera_id: str,
tracks: List[Dict],
current_time: Optional[datetime] = None,
) -> List[Dict]:
"""处理单帧检测结果"""
current_time = current_time or datetime.now()
alerts = []
# 一次遍历获取所有信息
roi_has_vehicle, vehicle_count, scan_bbox, scan_confidence = self._scan_tracks(tracks, roi_id)
# 更新滑动窗口
self._update_window(current_time, roi_has_vehicle)
# 计算一次比率,后续分支复用
ratio = self._get_window_ratio()
# === 状态机处理 ===
if self.state == self.STATE_IDLE:
if roi_has_vehicle:
self.state = self.STATE_CONFIRMING_VEHICLE
self.state_start_time = current_time
logger.debug(f"ROI {roi_id}: IDLE → CONFIRMING_VEHICLE (非机动车)")
elif self.state == self.STATE_CONFIRMING_VEHICLE:
if self.state_start_time is None:
self.state = self.STATE_IDLE
return alerts
elapsed = (current_time - self.state_start_time).total_seconds()
if ratio < self.RATIO_CONFIRMING_DROP:
self.state = self.STATE_IDLE
self.state_start_time = None
logger.debug(f"ROI {roi_id}: CONFIRMING_VEHICLE → IDLE (ratio={ratio:.2f}<{self.RATIO_CONFIRMING_DROP})")
elif elapsed >= self.confirm_vehicle_sec and ratio >= self.RATIO_CONFIRM_VEHICLE:
self._parking_start_time = self.state_start_time
self.state = self.STATE_PARKED_COUNTDOWN
self.state_start_time = current_time
logger.info(f"ROI {roi_id}: CONFIRMING_VEHICLE → PARKED_COUNTDOWN (非机动车, ratio={ratio:.2f})")
elif self.state == self.STATE_PARKED_COUNTDOWN:
if self.state_start_time is None:
self.state = self.STATE_IDLE
return alerts
elapsed = (current_time - self.state_start_time).total_seconds()
if ratio < self.RATIO_PARKED_LEAVE:
self.state = self.STATE_IDLE
self.state_start_time = None
self._parking_start_time = None
logger.debug(f"ROI {roi_id}: PARKED_COUNTDOWN → IDLE (非机动车离开, ratio={ratio:.2f})")
elif elapsed >= self.parking_countdown_sec:
cooldown_key = f"{camera_id}_{roi_id}"
if cooldown_key not in self.alert_cooldowns or \
(current_time - self.alert_cooldowns[cooldown_key]).total_seconds() > self.cooldown_sec:
alerts.append({
"roi_id": roi_id,
"camera_id": camera_id,
"bbox": scan_bbox,
"alert_type": "non_motor_vehicle_parking",
"alarm_level": self._alarm_level,
"confidence": scan_confidence,
"message": f"检测到非机动车违停(已停留{int(elapsed / 60)}分钟)",
"first_frame_time": self._parking_start_time.strftime('%Y-%m-%d %H:%M:%S') if self._parking_start_time else None,
"duration_minutes": elapsed / 60,
})
self.alert_cooldowns[cooldown_key] = current_time
self.state = self.STATE_ALARMED
logger.warning(f"ROI {roi_id}: PARKED_COUNTDOWN → ALARMED (非机动车违停告警触发)")
else:
self.state = self.STATE_IDLE
self.state_start_time = None
self._parking_start_time = None
logger.debug(f"ROI {roi_id}: PARKED_COUNTDOWN → IDLE (冷却期内)")
elif self.state == self.STATE_ALARMED:
if ratio < self.RATIO_ALARMED_CLEAR:
self.state = self.STATE_CONFIRMING_CLEAR
self.state_start_time = current_time
logger.debug(f"ROI {roi_id}: ALARMED → CONFIRMING_CLEAR (ratio={ratio:.2f}<{self.RATIO_ALARMED_CLEAR})")
elif self.state == self.STATE_CONFIRMING_CLEAR:
if self.state_start_time is None:
self.state = self.STATE_IDLE
return alerts
elapsed = (current_time - self.state_start_time).total_seconds()
if ratio >= self.RATIO_CLEAR_RETURN:
self.state = self.STATE_ALARMED
self.state_start_time = None
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → ALARMED (非机动车仍在)")
elif elapsed >= self.confirm_clear_sec and ratio < self.RATIO_CLEAR_CONFIRM:
if self._last_alarm_id and self._parking_start_time:
duration_ms = int((current_time - self._parking_start_time).total_seconds() * 1000)
alerts.append({
"alert_type": "alarm_resolve",
"resolve_alarm_id": self._last_alarm_id,
"duration_ms": duration_ms,
"last_frame_time": current_time.strftime('%Y-%m-%d %H:%M:%S'),
"resolve_type": "vehicle_left",
})
logger.info(f"ROI {roi_id}: 非机动车违停告警已解决(车辆离开)")
self.state = self.STATE_IDLE
self.state_start_time = None
self._last_alarm_id = None
self._parking_start_time = None
self.alert_cooldowns.clear()
logger.debug(f"ROI {roi_id}: CONFIRMING_CLEAR → IDLE")
return alerts
def reset(self):
"""重置算法状态"""
self.state = self.STATE_IDLE
self.state_start_time = None
self._last_alarm_id = None
self._parking_start_time = None
self._detection_window.clear()
self.alert_cooldowns.clear()
def get_state(self, current_time: Optional[datetime] = None) -> Dict[str, Any]:
"""获取当前状态"""
current_time = current_time or datetime.now()
window_ratio = self._get_window_ratio()
state_info = {
"state": self.state,
"state_start_time": self.state_start_time.isoformat() if self.state_start_time else None,
"window_ratio": window_ratio,
}
if self.state in (self.STATE_ALARMED, self.STATE_PARKED_COUNTDOWN) and self._parking_start_time:
state_info["parking_duration_sec"] = (current_time - self._parking_start_time).total_seconds()
state_info["alarm_id"] = self._last_alarm_id
return state_info
class AlgorithmManager:
# 参数类型定义,用于三级合并后的类型强制转换
_PARAM_TYPES = {
"leave_post": {
"confirm_on_duty_sec": int, "confirm_off_duty_sec": int,
"confirm_leave_sec": int, "leave_countdown_sec": int, "cooldown_sec": int,
},
"intrusion": {
"cooldown_seconds": int, "confirm_seconds": int,
"confirm_intrusion_seconds": int, "confirm_clear_seconds": int,
},
"illegal_parking": {
"confirm_vehicle_sec": int, "parking_countdown_sec": int,
"confirm_clear_sec": int, "cooldown_sec": int,
},
"vehicle_congestion": {
"count_threshold": int, "confirm_congestion_sec": int,
"confirm_clear_sec": int, "cooldown_sec": int,
},
"non_motor_vehicle_parking": {
"confirm_vehicle_sec": int, "parking_countdown_sec": int,
"confirm_clear_sec": int, "cooldown_sec": int,
},
}
def __init__(self, working_hours: Optional[List[Dict]] = None):
self.algorithms: Dict[str, Dict[str, Any]] = {}
self.working_hours = working_hours or []
self._update_lock = threading.Lock()
self._registered_keys: set = set() # 已注册的 (roi_id, bind_id, algo_type) 缓存
self._global_params: Dict[str, Dict] = {} # 全局参数 {algo_code: params_dict}
# Bug fix: 默认参数与算法构造函数一致
self.default_params = {
@@ -1318,12 +1610,54 @@ class AlgorithmManager:
"cooldown_sec": 1800, # Bug fix: 与算法构造函数默认值一致1800非600
"target_classes": ["car", "truck", "bus", "motorcycle"],
},
"non_motor_vehicle_parking": {
"confirm_vehicle_sec": 10,
"parking_countdown_sec": 180,
"confirm_clear_sec": 60,
"cooldown_sec": 900,
"target_classes": ["bicycle", "motorcycle"],
},
}
self._pubsub = None
self._pubsub_thread = None
self._running = False
def update_global_params(self, global_params_map: Dict[str, Dict]):
"""更新全局参数
Args:
global_params_map: {algo_code: params_dict} 格式的全局参数
"""
with self._update_lock:
self._global_params = global_params_map or {}
logger.info(f"全局参数已更新: {list(self._global_params.keys())}")
def _coerce_param_types(self, algorithm_type: str, params: dict) -> dict:
"""强制转换参数类型,防止字符串型数字导致算法异常"""
type_map = self._PARAM_TYPES.get(algorithm_type, {})
for key, expected_type in type_map.items():
if key in params and params[key] is not None:
try:
if not isinstance(params[key], expected_type):
params[key] = expected_type(params[key])
except (ValueError, TypeError):
logger.warning(f"参数类型转换失败: {algorithm_type}.{key}={params[key]!r}, 删除使用默认值")
del params[key]
return params
def get_min_alarm_duration(self, algorithm_type: str) -> Optional[int]:
"""从全局参数获取最小告警持续时间(秒)
Args:
algorithm_type: 算法类型(如 leave_post, intrusion
Returns:
最小告警持续时间秒数,未配置返回 None
"""
gp = self._global_params.get(algorithm_type, {})
return gp.get("min_alarm_duration_sec")
def start_config_subscription(self):
"""启动配置变更订阅"""
try:
@@ -1421,6 +1755,9 @@ class AlgorithmManager:
else:
params = {}
# 强制转换参数类型(防止字符串型数字)
params = self._coerce_param_types(algo_code, params)
if roi_id not in self.algorithms:
self.algorithms[roi_id] = {}
@@ -1537,6 +1874,36 @@ class AlgorithmManager:
dissipation_ratio=algo_params["dissipation_ratio"],
)
logger.info(f"已从Redis加载拥堵算法: {key}")
elif algo_code == "non_motor_vehicle_parking":
configured_alarm_level = params.get("alarm_level")
algo_params = {
"confirm_vehicle_sec": params.get("confirm_vehicle_sec", 10),
"parking_countdown_sec": params.get("parking_countdown_sec", 180),
"confirm_clear_sec": params.get("confirm_clear_sec", 60),
"cooldown_sec": params.get("cooldown_sec", 900),
"target_classes": params.get("target_classes", ["bicycle", "motorcycle"]),
}
if key in self.algorithms.get(roi_id, {}) and "non_motor_vehicle_parking" in self.algorithms[roi_id].get(key, {}):
algo = self.algorithms[roi_id][key]["non_motor_vehicle_parking"]
algo.confirm_vehicle_sec = algo_params["confirm_vehicle_sec"]
algo.parking_countdown_sec = algo_params["parking_countdown_sec"]
algo.confirm_clear_sec = algo_params["confirm_clear_sec"]
algo.cooldown_sec = algo_params["cooldown_sec"]
algo.target_classes = algo_params["target_classes"]
if configured_alarm_level is not None:
algo._alarm_level = configured_alarm_level
logger.info(f"已热更新非机动车违停算法参数: {key}")
else:
self.algorithms[roi_id][key] = {}
self.algorithms[roi_id][key]["non_motor_vehicle_parking"] = NonMotorVehicleParkingAlgorithm(
confirm_vehicle_sec=algo_params["confirm_vehicle_sec"],
parking_countdown_sec=algo_params["parking_countdown_sec"],
confirm_clear_sec=algo_params["confirm_clear_sec"],
cooldown_sec=algo_params["cooldown_sec"],
target_classes=algo_params["target_classes"],
alarm_level=configured_alarm_level,
)
logger.info(f"已从Redis加载非机动车违停算法: {key}")
return True
except Exception as e:
@@ -1591,6 +1958,9 @@ class AlgorithmManager:
params = json.loads(params_str) if isinstance(params_str, str) else params_str
algo_code = bind_config.get("algo_code")
# 强制转换参数类型(防止字符串型数字)
params = self._coerce_param_types(algo_code, params)
# 获取现有算法实例
existing_algo = self.algorithms[roi_id][key].get(algo_code)
@@ -1661,6 +2031,19 @@ class AlgorithmManager:
logger.info(f"[{roi_id}_{bind_id}] 更新拥堵检测参数")
elif algo_code == "non_motor_vehicle_parking":
existing_algo.confirm_vehicle_sec = params.get("confirm_vehicle_sec", 10)
existing_algo.parking_countdown_sec = params.get("parking_countdown_sec", 180)
existing_algo.confirm_clear_sec = params.get("confirm_clear_sec", 60)
existing_algo.cooldown_sec = params.get("cooldown_sec", 900)
if "target_classes" in params:
existing_algo.target_classes = params["target_classes"]
alarm_level = params.get("alarm_level")
if alarm_level is not None:
existing_algo._alarm_level = alarm_level
logger.info(f"[{roi_id}_{bind_id}] 更新非机动车违停检测参数")
return True
except Exception as e:
@@ -1757,9 +2140,17 @@ class AlgorithmManager:
self.algorithms[roi_id][key] = {}
algo_params = self.default_params.get(algorithm_type, {}).copy()
# 三级合并:默认参数 → 全局参数 → 绑定级参数
with self._update_lock:
global_p = self._global_params.get(algorithm_type, {}).copy()
if global_p:
algo_params.update(global_p)
if params:
algo_params.update(params)
# 强制转换参数类型(防止字符串型数字)
algo_params = self._coerce_param_types(algorithm_type, algo_params)
# 从 params 中提取告警等级(前端配置下发)
configured_alarm_level = algo_params.get("alarm_level")
@@ -1800,6 +2191,15 @@ class AlgorithmManager:
alarm_level=configured_alarm_level,
dissipation_ratio=algo_params.get("dissipation_ratio", 0.5),
)
elif algorithm_type == "non_motor_vehicle_parking":
self.algorithms[roi_id][key]["non_motor_vehicle_parking"] = NonMotorVehicleParkingAlgorithm(
confirm_vehicle_sec=algo_params.get("confirm_vehicle_sec", 10),
parking_countdown_sec=algo_params.get("parking_countdown_sec", 180),
confirm_clear_sec=algo_params.get("confirm_clear_sec", 60),
cooldown_sec=algo_params.get("cooldown_sec", 900),
target_classes=algo_params.get("target_classes", ["bicycle", "motorcycle"]),
alarm_level=configured_alarm_level,
)
self._registered_keys.add(cache_key)
@@ -1892,7 +2292,7 @@ class AlgorithmManager:
"state": getattr(algo, "state", "WAITING"),
"alarm_sent": getattr(algo, "alarm_sent", False),
}
elif algo_type in ("illegal_parking", "vehicle_congestion"):
elif algo_type in ("illegal_parking", "vehicle_congestion", "non_motor_vehicle_parking"):
status[f"{algo_type}_{bind_id}"] = algo.get_state()
else:
status[f"{algo_type}_{bind_id}"] = {

View File

@@ -22,6 +22,7 @@ class AlgorithmType(str, Enum):
INTRUSION = "intrusion"
ILLEGAL_PARKING = "illegal_parking"
VEHICLE_CONGESTION = "vehicle_congestion"
NON_MOTOR_VEHICLE_PARKING = "non_motor_vehicle_parking"
CROWD_DETECTION = "crowd_detection"
FACE_RECOGNITION = "face_recognition"

View File

@@ -259,8 +259,17 @@ class SQLiteManager:
except Exception:
pass # 列已存在,忽略
# 算法全局参数表
cursor.execute("""
CREATE TABLE IF NOT EXISTS algo_global_params (
algo_code TEXT PRIMARY KEY,
params TEXT NOT NULL DEFAULT '{}',
updated_at TEXT
)
""")
self._init_default_algorithms()
def _init_default_algorithms(self):
"""初始化默认算法配置"""
try:
@@ -948,6 +957,39 @@ class SQLiteManager:
logger.error(f"获取所有绑定ID失败: {e}")
return []
def save_global_params(self, algo_code: str, params_dict: Dict[str, Any]) -> bool:
"""保存算法全局参数INSERT OR REPLACE"""
try:
cursor = self._conn.cursor()
now = datetime.now().isoformat()
cursor.execute("""
INSERT OR REPLACE INTO algo_global_params (algo_code, params, updated_at)
VALUES (?, ?, ?)
""", (algo_code, json.dumps(params_dict, ensure_ascii=False), now))
self._conn.commit()
return True
except Exception as e:
logger.error(f"保存算法全局参数失败: {e}")
return False
def get_all_global_params(self) -> Dict[str, Dict[str, Any]]:
"""获取所有算法全局参数,返回 {algo_code: params_dict}"""
result: Dict[str, Dict[str, Any]] = {}
try:
cursor = self._conn.cursor()
cursor.execute("SELECT algo_code, params FROM algo_global_params")
for row in cursor.fetchall():
algo_code = row[0]
params_str = row[1]
try:
result[algo_code] = json.loads(params_str) if params_str else {}
except (json.JSONDecodeError, TypeError):
result[algo_code] = {}
return result
except Exception as e:
logger.error(f"获取算法全局参数失败: {e}")
return result
def log_config_update(
self,
config_type: str,

View File

@@ -15,6 +15,7 @@ import json
import logging
import os
import platform
import re
import socket
# 禁用系统代理Clash 等代理工具会干扰 Redis TCP 长连接)
@@ -53,6 +54,12 @@ def _build_keepalive_options():
return opts
def _camel_to_snake(name: str) -> str:
"""将 camelCase 转换为 snake_case"""
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
# ==================== Redis Key 常量 ====================
# 云端 Redis Keys
@@ -643,6 +650,21 @@ class ConfigSyncManager:
# 清理 SQLite 中不在本次推送列表中的旧数据
self._cleanup_stale_records(incoming_camera_ids, incoming_roi_ids, incoming_bind_ids)
# 同步全局参数
global_params = config_data.get("global_params") or config_data.get("globalParams") or {}
if global_params and isinstance(global_params, dict):
for algo_code, params_dict in global_params.items():
if isinstance(params_dict, dict):
# 防御性转换camelCase → snake_case
params_dict = {_camel_to_snake(k): v for k, v in params_dict.items()}
self._db_manager.save_global_params(algo_code, params_dict)
logger.info(f"全局参数同步完成: {list(global_params.keys())}")
# 通知全局参数更新回调
self._notify_callbacks("global_params_update", {
"global_params": global_params,
})
except Exception as e:
logger.error(f"配置同步到 SQLite 失败: {e}")

46
docker-compose.yml Normal file
View File

@@ -0,0 +1,46 @@
version: "3.8"
services:
edge-inference:
build: .
image: edge-inference:latest
container_name: edge-inference
restart: always
# GPU 访问
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
# 环境变量
env_file:
- .env
# 卷挂载
volumes:
- ./models:/app/models # TensorRT 引擎文件
- ./data:/app/data # SQLite + 截图缓存
- ./logs:/app/logs # 运行日志
- ./.env:/app/.env # 环境配置
# 网络(需要访问摄像头 RTSP + 云端 API + Redis
network_mode: host
# 健康检查
healthcheck:
test: ["CMD", "python", "-c", "import os; assert os.path.exists('/app/main.py')"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
# 日志限制
logging:
driver: json-file
options:
max-size: "50m"
max-file: "5"

60
main.py
View File

@@ -18,7 +18,6 @@ for _key in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY", "all_prox
from config.settings import get_settings, Settings
from core.config_sync import get_config_sync_manager, ConfigSyncManager
from core.debug_http_server import start_debug_http_server
from core.video_stream import MultiStreamManager, VideoFrame
from core.preprocessor import ImagePreprocessor
from core.tensorrt_engine import TensorRTEngine, EngineManager
@@ -56,8 +55,6 @@ class EdgeInferenceService:
self._screenshot_handler: Optional[ScreenshotHandler] = None
self._algorithm_manager: Optional[AlgorithmManager] = None
self._debug_reload_thread: Optional[threading.Thread] = None
self._debug_http_server = None
self._debug_http_thread: Optional[threading.Thread] = None
self._heartbeat_thread: Optional[threading.Thread] = None
self._scheduler_thread: Optional[threading.Thread] = None
@@ -132,6 +129,18 @@ class EdgeInferenceService:
daemon=True
).start()
self._config_manager.register_callback("config_update", _on_config_update)
def _on_global_params_update(topic, data):
if self._algorithm_manager:
global_params = data.get("global_params", {})
self._algorithm_manager.update_global_params(global_params)
# 只清除受影响算法的注册缓存,避免无关算法状态丢失
affected_algos = set(global_params.keys())
keys_to_remove = [k for k in self._algorithm_manager._registered_keys if k[2] in affected_algos]
for key in keys_to_remove:
self._algorithm_manager._registered_keys.discard(key)
self._logger.info(f"全局参数回调已触发,清除 {len(keys_to_remove)} 个受影响算法的注册缓存")
self._config_manager.register_callback("global_params_update", _on_global_params_update)
self._logger.info("配置管理器初始化成功")
except Exception as e:
self._logger.error(f"配置管理器初始化失败: {e}")
@@ -198,6 +207,18 @@ class EdgeInferenceService:
try:
self._algorithm_manager = AlgorithmManager()
self._algorithm_manager.start_config_subscription()
# 启动时从 SQLite 加载已有全局参数
try:
from config.database import get_sqlite_manager
db = get_sqlite_manager()
saved_global_params = db.get_all_global_params()
if saved_global_params:
self._algorithm_manager.update_global_params(saved_global_params)
self._logger.info(f"从 SQLite 加载全局参数: {list(saved_global_params.keys())}")
except Exception as e:
self._logger.warning(f"从 SQLite 加载全局参数失败: {e}")
self._logger.info("算法管理器初始化成功")
except Exception as e:
self._logger.error(f"算法管理器初始化失败: {e}")
@@ -281,32 +302,6 @@ class EdgeInferenceService:
)
self._debug_reload_thread.start()
def _start_debug_http_server(self):
"""本地调试:启动 HTTP 同步接口"""
if self._settings.config_sync_mode != "LOCAL":
return
if not getattr(self._settings, "debug", None) or not self._settings.debug.enabled:
return
if self._debug_http_server is not None:
return
host = self._settings.debug.host
port = self._settings.debug.port
self._debug_http_server = start_debug_http_server(host, port)
def worker():
try:
self._debug_http_server.serve_forever()
except Exception as e:
self._logger.warning(f"[DEBUG] HTTP 服务器异常: {e}")
self._debug_http_thread = threading.Thread(
target=worker,
name="DebugHttpServer",
daemon=True,
)
self._debug_http_thread.start()
def _start_heartbeat(self):
"""启动心跳守护线程,每 30 秒向云端上报设备状态"""
def worker():
@@ -379,7 +374,6 @@ class EdgeInferenceService:
self._init_algorithm_manager()
self._init_screenshot_handler()
self._start_debug_reload_watcher()
self._start_debug_http_server()
self._start_heartbeat()
self._performance_stats["start_time"] = datetime.now()
@@ -1098,12 +1092,6 @@ class EdgeInferenceService:
if self._reporter:
self._reporter.close()
if self._debug_http_server:
try:
self._debug_http_server.shutdown()
except Exception:
pass
self._performance_stats["uptime_seconds"] = (
(datetime.now() - self._performance_stats["start_time"]).total_seconds()
)

View File

@@ -1,64 +1,78 @@
# Edge_Inference_Service 依赖清单
# 安装命令: pip install -r requirements.txt
# 备注:所有版本均选择最稳定版本,经过大量验证
# 环境要求: Python 3.10 | CUDA 12.1 | cuDNN 8.9 | TensorRT 8.6.1
# Docker 基础镜像: nvcr.io/nvidia/tensorrt:23.08-py3
# ============================================================
# 核心依赖(必需
# GPU 推理依赖TensorRT 8.6 + CUDA 12.1
# ============================================================
# 视频处理 - OpenCV 4.8.0最稳定的4.x版本
opencv-python==4.8.0.74
# PyTorch - CUDA 12.1 下最稳定版本
# 安装命令: pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121
--extra-index-url https://download.pytorch.org/whl/cu121
torch==2.1.2
torchvision==0.16.2
# 数值计算 - NumPy 1.24.0Python 3.8-3.11完美兼容
numpy==1.24.0
# TensorRT Python 绑定NGC 镜像已预装,裸机需手动安装)
tensorrt==8.6.1.6
pycuda==2023.1.1
# YOLO11 目标检测框架
ultralytics==8.3.5
# ONNX 模型转换与优化
onnx==1.16.0
onnxsim==0.4.36
onnxruntime-gpu==1.17.1
# ============================================================
# 核心依赖
# ============================================================
# 视频处理
opencv-python==4.8.0.76
# 数值计算(锁定 1.x避开 NumPy 2.0 破坏性变更)
numpy==1.26.4
# 图像处理
Pillow==10.2.0
# ============================================================
# 数据库依赖
# ============================================================
# ORM框架 - SQLAlchemy 2.0.23,长期支持稳定版
# ORM 框架
sqlalchemy==2.0.23
# MySQL驱动 - PyMySQL 1.1.0,成熟稳定版本
# MySQL 驱动
pymysql==1.1.0
# ============================================================
# 消息队列与缓存
# ============================================================
# MQTT客户端 - Paho-MQTT 1.6.11.x最终稳定版
# MQTT 客户端1.x 最终稳定版
paho-mqtt==1.6.1
# Redis客户端 - Redis 4.6.04.x最终稳定版
# Redis 客户端
redis==4.6.0
# 腾讯云COS SDK - 用于截图上传
# 腾讯云 COS SDK截图上传
cos-python-sdk-v5>=1.9.30
# ============================================================
# 工具库
# ============================================================
# YAML解析 - PyYAML 6.0.1,安全稳定版
pyyaml==6.0.1
requests==2.31.0
psutil==5.9.8
python-dotenv==1.0.1
# ============================================================
# 测试框架
# 测试依赖
# ============================================================
# 单元测试 - PyTest 7.4.47.x最终稳定版
pytest==7.4.4
# 覆盖率报告 - PyTest-Cov 4.1.0,成熟稳定版
pytest-cov==4.1.0
# ============================================================
# 可选依赖(按需安装)
# ============================================================
# GPU推理框架需要CUDA 12.1环境)
# tensorrt==8.6.1.6
# pycuda==2023.1.1
# YOLOv8目标检测按需安装
# ultralytics==8.0.228