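# Change notes (commit message captured with this file):
# - AlertInfo.to_dict() now also emits the bind_id, device_id and algorithm fields
# - AlertInfo gains the device_id and algorithm attributes
# - MQTTConfig gains a device_id setting (environment variable EDGE_DEVICE_ID)
# - main.py passes device_id and algorithm when creating AlertInfo
# - Heartbeat reporting uses the configured device_id instead of a hard-coded string
# - MQTT protocol downgraded from MQTTv5 to MQTTv311 for better broker compatibility
# - MQTT client_id gets a random suffix to prevent client-id collisions
# Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>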
"""
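Result reporting module.

Publishes alerts over MQTT and stores them locally in SQLite.

Storage strategy:
- MQTT: push alerts to the cloud in real time.
- SQLite: asynchronous local storage (edge-side resilience to network outages).
- While the network is down, alerts are cached locally; once the connection
  is back they can be re-published via ResultReporter.sync_pending_alerts().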
"""
|
||
|
||
import json
|
||
import logging
|
||
import threading
|
||
import time
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional, Callable
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import paho.mqtt.client as mqtt
|
||
from paho.mqtt.client import MQTTMessage
|
||
|
||
# Module-level logger. Note: ResultReporter and AlertReporter below log through
# their own named loggers ("result_reporter" / "alert_reporter"), not this one.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class AlertInfo:
    """A single alert produced by the edge pipeline.

    ``to_dict`` renders the message body published over MQTT; the
    ``screenshot`` array travels with the alert but is not part of that body.
    """

    alert_id: str
    camera_id: str
    roi_id: str
    alert_type: str
    bind_id: Optional[str] = None
    device_id: Optional[str] = None
    algorithm: Optional[str] = None
    target_class: Optional[str] = None
    confidence: Optional[float] = None
    # A fresh empty list per instance when not supplied.
    bbox: Optional[List[float]] = field(default_factory=list)
    message: Optional[str] = None
    # Raw frame; not included in to_dict().
    screenshot: Optional[np.ndarray] = None
    # Defaults to the local wall-clock time at construction (naive datetime).
    timestamp: datetime = field(default_factory=datetime.now)
    duration_minutes: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the MQTT message body sent to the alert platform.

        Keys appear in a fixed order; ``timestamp`` is serialized with
        ``datetime.isoformat()`` and ``screenshot`` is omitted.
        """
        # Attributes copied verbatim, in the order they appear in the payload.
        plain_keys = (
            "alert_id", "camera_id", "roi_id", "bind_id", "device_id",
            "alert_type", "algorithm", "target_class", "confidence",
            "bbox", "message",
        )
        payload: Dict[str, Any] = {key: getattr(self, key) for key in plain_keys}
        payload["timestamp"] = self.timestamp.isoformat()
        payload["duration_minutes"] = self.duration_minutes
        return payload
|
||
|
||
|
||
class ResultReporter:
    """Alert / heartbeat reporter with MQTT delivery and a local fallback.

    * ``report_alert`` optionally queues the alert into the SQLite manager and
      publishes it as JSON to ``{topic_prefix}/{camera_id}/{roi_id}``.
    * When MQTT is not connected (or the publish fails), the alert dict is put
      into the local cache; ``sync_pending_alerts`` re-publishes it later.
    * ``report_heartbeat`` publishes to ``{topic_prefix}/heartbeat/{device_id}``.

    Call :meth:`initialize` before use and :meth:`cleanup` / :meth:`close`
    when done.
    """

    def __init__(
        self,
        mqtt_broker: str = "localhost",
        mqtt_port: int = 1883,
        topic_prefix: str = "edge/alert",
        mqtt_client_id: str = "edge_reporter"
    ):
        """Store connection settings; no I/O happens until ``initialize()``.

        Args:
            mqtt_broker: MQTT broker host.
            mqtt_port: MQTT broker port.
            topic_prefix: Prefix for every published topic.
            mqtt_client_id: Base client id; a random suffix is appended when
                the paho client is created.
        """
        self._mqtt_broker = mqtt_broker
        self._mqtt_port = mqtt_port
        self._topic_prefix = topic_prefix
        self._mqtt_client_id = mqtt_client_id

        self._client: Optional[mqtt.Client] = None
        # Updated from the paho callbacks, which run on the loop_start() thread.
        self._connected = False
        self._reconnect_count = 0
        # Protects the _performance_stats counters.
        self._lock = threading.Lock()

        self._performance_stats = {
            "alerts_generated": 0,
            "alerts_sent": 0,
            "alerts_stored": 0,
            "send_failures": 0,
        }

        self._logger = logging.getLogger("result_reporter")

        # Populated by initialize(); None means "feature not available".
        self._db_manager = None
        self._image_storage = None
        self._local_cache = None

        self._logger.info("ResultReporter 初始化完成")

    def initialize(self):
        """Create the storage managers, then start the MQTT client."""
        from config.database import get_sqlite_manager
        from core.storage_manager import get_image_storage, get_local_cache

        self._logger.info("初始化存储管理器...")

        self._db_manager = get_sqlite_manager()
        self._image_storage = get_image_storage()
        self._local_cache = get_local_cache()

        self._logger.info("存储管理器初始化完成")

        self._init_mqtt()

    def _bump_stat(self, key: str) -> None:
        """Increment one performance counter under the lock."""
        with self._lock:
            self._performance_stats[key] += 1

    @staticmethod
    def _reason_value(reason_code):
        """Normalize a paho result code (1.x int or 2.x ReasonCode) to a value.

        ``None`` (callback invoked without a code) is treated as success (0).
        """
        if reason_code is None:
            return 0
        if isinstance(reason_code, int):
            return reason_code
        return getattr(reason_code, "value", reason_code)

    @staticmethod
    def _create_client(client_id: str):
        """Build a paho client that works with both paho-mqtt 1.x and 2.x."""
        try:
            # paho-mqtt 2.0+: the callback API version must be chosen explicitly.
            return mqtt.Client(
                client_id=client_id,
                protocol=mqtt.MQTTv311,
                callback_api_version=mqtt.CallbackAPIVersion.VERSION2,
            )
        except (AttributeError, TypeError):
            # paho-mqtt 1.x: no CallbackAPIVersion / callback_api_version kwarg.
            return mqtt.Client(client_id=client_id, protocol=mqtt.MQTTv311)

    def _init_mqtt(self):
        """Create the paho client and start its background network loop.

        ``connect_async`` + ``loop_start`` is used instead of a blocking
        ``connect``: if the broker is unreachable at start-up, a blocking
        connect raises and the reporter would be left without a client for
        good. With the async form paho's loop thread keeps retrying, and
        ``_on_connect`` sets ``_connected`` once a connection succeeds; until
        then alerts go to the local cache.
        """
        self._logger.info(f"正在连接 MQTT: {self._mqtt_broker}:{self._mqtt_port}")
        try:
            # Random suffix so instances sharing a base id do not collide.
            unique_client_id = f"{self._mqtt_client_id}_{uuid.uuid4().hex[:8]}"
            client = self._create_client(unique_client_id)

            client.on_connect = self._on_connect
            client.on_disconnect = self._on_disconnect
            client.on_publish = self._on_publish

            client.connect_async(self._mqtt_broker, self._mqtt_port, 60)
            client.loop_start()
            self._client = client

            self._logger.info(f"MQTT 客户端初始化完成: {self._mqtt_broker}:{self._mqtt_port}")
        except Exception as e:
            self._logger.exception(f"MQTT 初始化失败: {e}")
            self._client = None

    def _on_connect(self, client, userdata, flags, reason_code, properties=None):
        """paho connect callback (1.x and 2.x callback APIs).

        paho 1.x passes an int ``rc``; 2.x (VERSION2) passes a ``ReasonCode``.
        Either way a value of 0 means the connection was accepted.
        """
        if self._reason_value(reason_code) != 0:
            self._logger.warning(f"MQTT 连接失败: {reason_code}")
            return

        self._connected = True
        self._reconnect_count = 0
        self._logger.info("MQTT 连接成功")

        if self._local_cache:
            pending = self._local_cache.get_pending_count()
            if pending > 0:
                self._logger.info(f"有 {pending} 条待同步的缓存告警")

    def _on_disconnect(self, client, userdata, *args):
        """paho disconnect callback, tolerant of every paho signature.

        * paho 1.x, MQTTv311:       ``(client, userdata, rc)``
        * paho 1.x, MQTTv5:         ``(client, userdata, rc, properties)``
        * paho 2.x, VERSION2 API:   ``(client, userdata, disconnect_flags,
          reason_code, properties)``

        A fixed 4-parameter signature raises TypeError under the 2.x form.
        """
        if len(args) >= 3:
            reason_code = args[1]
        elif args:
            reason_code = args[0]
        else:
            reason_code = None
        self._connected = False
        self._logger.warning(f"MQTT 连接断开: {reason_code}")

    def _on_publish(self, client, userdata, mid, reason_code=None, properties=None):
        """paho publish callback (1.x passes only ``mid``; 2.x adds a code)."""
        if self._reason_value(reason_code) == 0:
            self._logger.debug(f"MQTT 消息发布成功: {mid}")

    def report_alert(
        self,
        alert: "AlertInfo",
        store_to_db: bool = True,
        screenshot: Optional[np.ndarray] = None
    ) -> bool:
        """Report one alert.

        Args:
            alert: The alert to report.
            store_to_db: Whether to queue the alert into the local database.
            screenshot: Optional captured frame saved alongside the record.

        Returns:
            True when the alert was published, handed to the local cache, or
            (with no cache configured) only stored locally; False when an
            error prevented reporting. ``alerts_sent`` is incremented only for
            alerts that were actually published.
        """
        self._bump_stat("alerts_generated")

        try:
            if store_to_db and self._db_manager:
                self._store_alert(alert, screenshot)

            if self._connected and self._client:
                try:
                    self._publish_alert(alert)
                except Exception as e:
                    # The link can drop between the _connected check and the
                    # publish; fall back to the cache when there is one.
                    if not self._local_cache:
                        raise
                    self._logger.warning(f"MQTT 发布失败,转入本地缓存: {e}")
                    self._local_cache.cache_alert(alert.to_dict())
                else:
                    self._bump_stat("alerts_sent")
                return True

            if self._local_cache:
                self._logger.warning("MQTT 未连接,消息已加入待发送队列")
                self._local_cache.cache_alert(alert.to_dict())
            else:
                self._logger.warning("MQTT 未连接且未配置本地缓存,告警未推送")
            return True

        except Exception as e:
            self._bump_stat("send_failures")
            self._logger.error(f"上报告警失败: {e}")
            return False

    def _store_alert(self, alert: "AlertInfo", screenshot: Optional[np.ndarray] = None):
        """Save the screenshot (if any) and queue an AlertRecord for the DB."""
        # AlertRecord is not a module-level name in this file; the
        # function-local import inside initialize() does not make it visible
        # here, so it must be imported in this scope.
        from config.database import AlertRecord

        image_path = None
        if screenshot is not None and self._image_storage:
            image_path = self._image_storage.save_capture(
                image=screenshot,
                camera_id=alert.camera_id,
                alert_id=alert.alert_id,
                timestamp=alert.timestamp
            )

        record = AlertRecord(
            alert_id=alert.alert_id,
            camera_id=alert.camera_id,
            roi_id=alert.roi_id,
            alert_type=alert.alert_type,
            target_class=alert.target_class,
            confidence=alert.confidence,
            bbox=alert.bbox,
            message=alert.message,
            image_path=image_path,
            status="pending",
            created_at=alert.timestamp,
            duration_minutes=alert.duration_minutes,
        )

        if self._db_manager:
            self._db_manager.queue_alert(record)
            self._bump_stat("alerts_stored")

    def _publish_alert(self, alert: "AlertInfo"):
        """Publish one alert as JSON to ``{prefix}/{camera_id}/{roi_id}``.

        Raises:
            RuntimeError: if paho reports a non-success result code.
        """
        topic = f"{self._topic_prefix}/{alert.camera_id}/{alert.roi_id}"
        payload = json.dumps(alert.to_dict(), ensure_ascii=False)
        result = self._client.publish(topic, payload)

        if result.rc != mqtt.MQTT_ERR_SUCCESS:
            raise RuntimeError(f"MQTT 发布失败: {result.rc}")

        self._logger.info(
            f"告警已发布: type={alert.alert_type}, "
            f"camera={alert.camera_id}, roi={alert.roi_id}, "
            f"confidence={alert.confidence}"
        )

    def report_heartbeat(
        self,
        device_id: str,
        status: Dict[str, Any]
    ) -> bool:
        """Publish a heartbeat to ``{prefix}/heartbeat/{device_id}``.

        Returns:
            True if paho accepted the publish, False otherwise (not connected,
            publish error, or an exception while building/sending).
        """
        try:
            heartbeat_data = {
                "device_id": device_id,
                "status": status,
                "timestamp": datetime.now().isoformat(),
            }
            topic = f"{self._topic_prefix}/heartbeat/{device_id}"

            if not (self._client and self._connected):
                self._logger.warning(f"心跳上报失败(MQTT未连接): {device_id}")
                return False

            result = self._client.publish(topic, json.dumps(heartbeat_data, ensure_ascii=False))
            if result.rc == mqtt.MQTT_ERR_SUCCESS:
                self._logger.debug(f"心跳上报成功: {device_id}")
                return True

            self._logger.warning(f"心跳上报失败(发布错误 {result.rc}): {device_id}")
            return False
        except Exception as e:
            self._logger.error(f"心跳上报异常: {e}")
            return False

    def get_pending_alerts(self) -> List[Dict[str, Any]]:
        """Return the alerts waiting in the local cache (empty without a cache)."""
        if self._local_cache:
            return self._local_cache.get_pending_alerts()
        return []

    def sync_pending_alerts(self) -> int:
        """Re-publish cached alerts; return how many were published.

        An alert is removed from the cache only after paho accepted its
        publish. The cache bookkeeping key ``_cache_id`` is stripped so the
        payload matches what ``report_alert`` publishes live.
        """
        if not self._connected or not self._client:
            return 0

        synced = 0
        for alert in self.get_pending_alerts():
            try:
                cache_id = alert.get("_cache_id", "")
                payload = {k: v for k, v in alert.items() if k != "_cache_id"}
                topic = f"{self._topic_prefix}/{alert['camera_id']}/{alert['roi_id']}"

                result = self._client.publish(topic, json.dumps(payload, ensure_ascii=False))
                if result.rc != mqtt.MQTT_ERR_SUCCESS:
                    # Keep it cached so a later sync can retry.
                    self._logger.error(f"同步告警失败: MQTT 发布失败 {result.rc}")
                    continue

                if self._local_cache:
                    self._local_cache.remove_cached(cache_id)
                synced += 1
            except Exception as e:
                self._logger.error(f"同步告警失败: {e}")

        if synced > 0:
            self._logger.info(f"已同步 {synced} 条告警到云端")

        return synced

    def get_statistics(self) -> Dict[str, Any]:
        """Return a snapshot of counters plus storage/cache/connection info."""
        with self._lock:
            stats = dict(self._performance_stats)

        if self._db_manager:
            db_stats = self._db_manager.get_statistics()
            stats.update({
                "db_total_alerts": db_stats.get("total_alerts", 0),
                "db_pending_alerts": db_stats.get("pending_alerts", 0),
            })

        if self._image_storage:
            img_stats = self._image_storage.get_statistics()
            stats["image_saved"] = img_stats.get("saved_count", 0)

        if self._local_cache:
            stats["pending_sync"] = self._local_cache.get_pending_count()

        stats["mqtt_connected"] = self._connected

        return stats

    def _stop_mqtt(self):
        """Disconnect and stop paho's network loop; safe to call repeatedly."""
        client, self._client = self._client, None
        self._connected = False
        if client is None:
            return
        try:
            client.disconnect()
        finally:
            client.loop_stop()

    def cleanup(self):
        """Close storage handles and stop the MQTT client.

        Released resources are reset to None, so calling this more than once
        (e.g. ``cleanup()`` followed by ``close()``) does not close them twice.
        """
        self._logger.info("ResultReporter 资源清理")

        if self._image_storage:
            self._image_storage.close()
            self._image_storage = None

        if self._db_manager:
            self._db_manager.close()
            self._db_manager = None

        self._stop_mqtt()

    def close(self):
        """Alias of :meth:`cleanup` that also logs completion."""
        self.cleanup()
        self._logger.info("ResultReporter 清理完成")
|
||
|
||
|
||
class AlertReporter:
    """Simplified alert reporter: records alerts locally, no MQTT publishing.

    ``topic_prefix`` is stored for interface parity with ResultReporter but is
    not used by any method of this class.
    """

    def __init__(self, topic_prefix: str = "edge/alert"):
        self._topic_prefix = topic_prefix
        # Protects the _performance_stats counters.
        self._lock = threading.Lock()
        self._performance_stats = {
            "alerts_generated": 0,
            "alerts_sent": 0,
            "alerts_stored": 0,
            "send_failures": 0,
        }
        self._logger = logging.getLogger("alert_reporter")
        # Populated by initialize(); set here so report_alert() works (without
        # storage) even if initialize() has not been called yet.
        self._db_manager = None
        self._image_storage = None

    def initialize(self):
        """Look up the SQLite manager and image storage."""
        from config.database import get_sqlite_manager
        from core.storage_manager import get_image_storage

        self._db_manager = get_sqlite_manager()
        self._image_storage = get_image_storage()
        self._logger.info("AlertReporter 初始化完成")

    def report_alert(
        self,
        alert: "AlertInfo",
        store_to_db: bool = True,
        screenshot: Optional[np.ndarray] = None
    ) -> bool:
        """Record one alert.

        Args:
            alert: The alert to record.
            store_to_db: Whether to queue the alert into the local database
                (ignored when no database manager is available).
            screenshot: Optional captured frame saved alongside the record.

        Returns:
            True on success, False if storing raised an exception.
        """
        with self._lock:
            self._performance_stats["alerts_generated"] += 1

        try:
            if store_to_db and self._db_manager:
                self._store_alert(alert, screenshot)

            self._logger.info(
                f"告警已记录: {alert.alert_type} - {alert.camera_id}/{alert.roi_id}"
            )

            with self._lock:
                self._performance_stats["alerts_sent"] += 1

            return True

        except Exception as e:
            with self._lock:
                self._performance_stats["send_failures"] += 1
            self._logger.error(f"上报告警失败: {e}")
            return False

    def _store_alert(self, alert: "AlertInfo", screenshot: Optional[np.ndarray] = None):
        """Save the screenshot (if any) and queue an AlertRecord for the DB."""
        from config.database import AlertRecord

        if self._db_manager is None:
            return

        image_path = None
        if screenshot is not None and self._image_storage:
            image_path = self._image_storage.save_capture(
                image=screenshot,
                camera_id=alert.camera_id,
                alert_id=alert.alert_id,
                timestamp=alert.timestamp
            )

        record = AlertRecord(
            alert_id=alert.alert_id,
            camera_id=alert.camera_id,
            roi_id=alert.roi_id,
            alert_type=alert.alert_type,
            target_class=alert.target_class,
            confidence=alert.confidence,
            bbox=alert.bbox,
            message=alert.message,
            image_path=image_path,
            status="pending",
            created_at=alert.timestamp,
            # Stored by ResultReporter as well; keep both reporters in sync.
            duration_minutes=alert.duration_minutes,
        )

        self._db_manager.queue_alert(record)

        with self._lock:
            self._performance_stats["alerts_stored"] += 1

    def get_statistics(self) -> Dict[str, Any]:
        """Return a snapshot copy of the performance counters."""
        with self._lock:
            return dict(self._performance_stats)
|
||
|
||
|
||
def create_alert_reporter(topic_prefix: str = "edge/alert") -> AlertReporter:
    """Build an :class:`AlertReporter` for *topic_prefix*.

    The factory does not call ``initialize()``; the caller does that before
    storage-backed reporting is needed.
    """
    reporter = AlertReporter(topic_prefix=topic_prefix)
    return reporter
|