- ResultReporter AlertInfo 添加 duration_minutes - main.py 使用 report_alert 替代 report_detection_alert - _store_alert 保存 duration_minutes
451 lines
14 KiB
Python
451 lines
14 KiB
Python
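"""
Result reporting module.

Supports MQTT publishing and local SQLite storage.

Storage strategy:
- MQTT: push alerts to the cloud in real time
- SQLite: asynchronous local storage (edge-side resilience against network outages)
- While offline, alerts are cached locally and synchronized automatically once the connection is restored
"""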
|
||
|
||
import json
|
||
import logging
|
||
import threading
|
||
import time
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional, Callable
|
||
from dataclasses import dataclass, field
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import paho.mqtt.client as mqtt
|
||
from paho.mqtt.client import MQTTMessage
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class AlertInfo:
    """Describe a single alert raised for a camera / ROI pair.

    The ``screenshot`` field is carried on the object but is deliberately
    left out of :meth:`to_dict`, because the dictionary is serialized to JSON
    by the reporters in this module.
    """

    # Identity of the alert and where it was raised.
    alert_id: str
    camera_id: str
    roi_id: str
    alert_type: str
    # Optional detection details.
    target_class: Optional[str] = None
    confidence: Optional[float] = None
    bbox: Optional[List[float]] = field(default_factory=list)
    message: Optional[str] = None
    screenshot: Optional[np.ndarray] = None
    # Defaults to the local time at which the instance is created.
    timestamp: datetime = field(default_factory=datetime.now)
    duration_minutes: Optional[float] = None

    @staticmethod
    def _to_native(value: Any) -> Any:
        """Convert NumPy scalars/arrays (also inside lists) to plain Python values.

        ``json.dumps`` cannot serialize ``np.float32``, ``np.int64`` or
        ``np.ndarray``; every other value is returned unchanged.
        """
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (list, tuple)):
            return [AlertInfo._to_native(item) for item in value]
        return value

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable dictionary describing the alert.

        Returns:
            A dict with every field except ``screenshot``. ``timestamp`` is
            rendered with ``datetime.isoformat()``; ``confidence``, ``bbox``
            and ``duration_minutes`` are converted to native Python numbers
            when they hold NumPy values.
        """
        return {
            "alert_id": self.alert_id,
            "camera_id": self.camera_id,
            "roi_id": self.roi_id,
            "alert_type": self.alert_type,
            "target_class": self.target_class,
            "confidence": self._to_native(self.confidence),
            "bbox": self._to_native(self.bbox),
            "message": self.message,
            "timestamp": self.timestamp.isoformat(),
            "duration_minutes": self._to_native(self.duration_minutes),
        }
|
||
|
||
|
||
class ResultReporter:
    """Report alerts over MQTT and persist them through local storage helpers.

    Alerts are queued to the SQLite manager (when available) and published to
    ``<topic_prefix>/<camera_id>/<roi_id>``. While MQTT is not connected,
    alerts are handed to the local cache instead and can be re-sent later with
    :meth:`sync_pending_alerts`.
    """

    def __init__(
        self,
        mqtt_broker: str = "localhost",
        mqtt_port: int = 1883,
        topic_prefix: str = "edge/alert",
        mqtt_client_id: str = "edge_reporter"
    ):
        """Store the connection settings; no I/O happens until ``initialize()``.

        Args:
            mqtt_broker: Host name or address of the MQTT broker.
            mqtt_port: TCP port of the MQTT broker.
            topic_prefix: Prefix used for every published topic.
            mqtt_client_id: Client id announced to the broker.
        """
        self._mqtt_broker = mqtt_broker
        self._mqtt_port = mqtt_port
        self._topic_prefix = topic_prefix
        self._mqtt_client_id = mqtt_client_id

        self._client: Optional[mqtt.Client] = None
        self._connected = False
        self._reconnect_count = 0
        # Guards the statistics counters.
        self._lock = threading.Lock()

        self._performance_stats = {
            "alerts_generated": 0,
            "alerts_sent": 0,
            "alerts_stored": 0,
            "send_failures": 0,
        }

        self._logger = logging.getLogger("result_reporter")

        # Populated by initialize().
        self._db_manager = None
        self._image_storage = None
        self._local_cache = None

        self._logger.info("ResultReporter 初始化完成")

    def _bump_stat(self, key: str) -> None:
        """Increment one statistics counter while holding the lock."""
        with self._lock:
            self._performance_stats[key] += 1

    def initialize(self):
        """Acquire the storage helpers and start the MQTT client."""
        # Imported lazily, as in the rest of this module.
        from config.database import get_sqlite_manager
        from core.storage_manager import get_image_storage, get_local_cache

        self._logger.info("初始化存储管理器...")

        self._db_manager = get_sqlite_manager()
        self._image_storage = get_image_storage()
        self._local_cache = get_local_cache()

        self._logger.info("存储管理器初始化完成")

        self._init_mqtt()

    def _init_mqtt(self):
        """Create the MQTT v5 client, connect and start its network loop.

        Any failure is logged and leaves ``self._client`` as ``None``.
        NOTE(review): no retry happens after a failed initial connect, so a
        broker that is down at start-up disables MQTT until ``initialize()``
        is called again - confirm whether that is intended.
        """
        try:
            self._client = mqtt.Client(
                client_id=self._mqtt_client_id,
                protocol=mqtt.MQTTv5,
                callback_api_version=mqtt.CallbackAPIVersion.VERSION2
            )

            self._client.on_connect = self._on_connect
            self._client.on_disconnect = self._on_disconnect
            self._client.on_publish = self._on_publish

            self._client.connect(self._mqtt_broker, self._mqtt_port, 60)
            self._client.loop_start()

            self._logger.info(f"MQTT 客户端初始化: {self._mqtt_broker}:{self._mqtt_port}")

        except Exception as e:
            self._logger.warning(f"MQTT 初始化失败: {e}")
            self._client = None

    def _on_connect(self, client, userdata, flags, reason_code, properties):
        """MQTT connect callback: record the state and log pending cached alerts."""
        if reason_code == 0:
            self._connected = True
            self._reconnect_count = 0
            self._logger.info("MQTT 连接成功")

            # NOTE(review): pending alerts are only counted here, not re-sent;
            # sync_pending_alerts() has to be invoked by the caller.
            if self._local_cache:
                pending = self._local_cache.get_pending_count()
                if pending > 0:
                    self._logger.info(f"有 {pending} 条待同步的缓存告警")
        else:
            self._logger.warning(f"MQTT 连接失败: {reason_code}")

    def _on_disconnect(self, client, userdata, *args):
        """MQTT disconnect callback: mark the reporter as disconnected.

        The client is created with ``CallbackAPIVersion.VERSION2``, which
        passes ``(disconnect_flags, reason_code, properties)`` after
        ``userdata``. The older ``(reason_code, properties)`` and ``(rc,)``
        shapes are still accepted so existing direct callers keep working.
        """
        if len(args) >= 3:
            reason_code = args[1]
        elif args:
            reason_code = args[0]
        else:
            reason_code = None

        self._connected = False
        self._logger.warning(f"MQTT 连接断开: {reason_code}")

    def _on_publish(self, client, userdata, mid, reason_code, properties):
        """MQTT publish callback: log successfully published message ids."""
        if reason_code == 0:
            self._logger.debug(f"MQTT 消息发布成功: {mid}")

    def report_alert(
        self,
        alert: "AlertInfo",
        store_to_db: bool = True,
        screenshot: Optional[np.ndarray] = None
    ) -> bool:
        """Store an alert locally and publish it, or cache it while offline.

        Args:
            alert: The alert to report.
            store_to_db: Whether to queue the alert to the local database.
            screenshot: Optional captured frame saved alongside the alert.

        Returns:
            True when no step raised, False otherwise. ``alerts_sent`` is only
            incremented when the alert was actually published over MQTT.
        """
        self._bump_stat("alerts_generated")

        try:
            if store_to_db and self._db_manager:
                self._store_alert(alert, screenshot)

            if self._connected and self._client:
                self._publish_alert(alert)
                self._bump_stat("alerts_sent")
            elif self._local_cache:
                self._logger.warning("MQTT 未连接,消息已加入待发送队列")
                self._local_cache.cache_alert(alert.to_dict())
            else:
                # Nothing was queued, so do not claim that it was.
                self._logger.warning("MQTT 未连接,且本地缓存不可用,告警未加入待发送队列")

            return True

        except Exception as e:
            self._bump_stat("send_failures")
            self._logger.error(f"上报告警失败: {e}")
            return False

    def _store_alert(self, alert: "AlertInfo", screenshot: Optional[np.ndarray] = None):
        """Save the optional screenshot and queue the alert record for storage."""
        # AlertRecord is not imported at module level; import it where used.
        from config.database import AlertRecord

        image_path = None

        if screenshot is not None and self._image_storage:
            image_path = self._image_storage.save_capture(
                image=screenshot,
                camera_id=alert.camera_id,
                alert_id=alert.alert_id,
                timestamp=alert.timestamp
            )

        record = AlertRecord(
            alert_id=alert.alert_id,
            camera_id=alert.camera_id,
            roi_id=alert.roi_id,
            alert_type=alert.alert_type,
            target_class=alert.target_class,
            confidence=alert.confidence,
            bbox=alert.bbox,
            message=alert.message,
            image_path=image_path,
            status="pending",
            created_at=alert.timestamp,
            duration_minutes=alert.duration_minutes,
        )

        if self._db_manager:
            self._db_manager.queue_alert(record)
            self._bump_stat("alerts_stored")

    def _publish_alert(self, alert: "AlertInfo"):
        """Publish one alert as JSON.

        Raises:
            RuntimeError: If the client reports a non-success return code.
        """
        payload = json.dumps(alert.to_dict(), ensure_ascii=False)
        topic = f"{self._topic_prefix}/{alert.camera_id}/{alert.roi_id}"

        result = self._client.publish(topic, payload)
        if result.rc != mqtt.MQTT_ERR_SUCCESS:
            raise RuntimeError(f"MQTT 发布失败: {result.rc}")

        # A plain logging.Logger has no log_alert(); use it only when the
        # configured logger class provides it.
        log_alert = getattr(self._logger, "log_alert", None)
        if callable(log_alert):
            log_alert(
                alert.alert_type,
                alert.camera_id,
                alert.roi_id,
                alert.confidence
            )
        else:
            self._logger.info(
                f"告警已上报: {alert.alert_type} - {alert.camera_id}/{alert.roi_id}"
            )

    def report_heartbeat(
        self,
        device_id: str,
        status: Dict[str, Any]
    ) -> bool:
        """Publish a heartbeat message for ``device_id``.

        Args:
            device_id: Identifier appended to the heartbeat topic.
            status: JSON-serializable status payload.

        Returns:
            True when the message was handed to the client successfully,
            False when MQTT is unavailable, the publish failed or an
            exception was raised.
        """
        try:
            heartbeat_data = {
                "device_id": device_id,
                "status": status,
                "timestamp": datetime.now().isoformat(),
            }

            topic = f"{self._topic_prefix}/heartbeat/{device_id}"

            if self._client and self._connected:
                result = self._client.publish(topic, json.dumps(heartbeat_data, ensure_ascii=False))
                if result.rc == mqtt.MQTT_ERR_SUCCESS:
                    self._logger.debug(f"心跳上报成功: {device_id}")
                    return True

            self._logger.warning(f"心跳上报失败(MQTT未连接): {device_id}")
            return False

        except Exception as e:
            self._logger.error(f"心跳上报异常: {e}")
            return False

    def get_pending_alerts(self) -> List[Dict[str, Any]]:
        """Return the alerts held by the local cache, or an empty list."""
        if self._local_cache:
            return self._local_cache.get_pending_alerts()
        return []

    def sync_pending_alerts(self) -> int:
        """Re-publish cached alerts and drop those that were sent.

        A cached alert is removed only when the client reports success, so a
        failed publish leaves it in the cache for the next attempt.

        Returns:
            Number of alerts published and removed from the cache.
        """
        if not self._connected or not self._client:
            return 0

        synced = 0

        for alert in self.get_pending_alerts():
            try:
                result = self._client.publish(
                    f"{self._topic_prefix}/{alert['camera_id']}/{alert['roi_id']}",
                    json.dumps(alert, ensure_ascii=False)
                )
                if result.rc != mqtt.MQTT_ERR_SUCCESS:
                    self._logger.error(f"同步告警失败: {result.rc}")
                    continue

                if self._local_cache:
                    self._local_cache.remove_cached(alert.get('_cache_id', ''))
                synced += 1
            except Exception as e:
                self._logger.error(f"同步告警失败: {e}")

        if synced > 0:
            self._logger.info(f"已同步 {synced} 条告警到云端")

        return synced

    def get_statistics(self) -> Dict[str, Any]:
        """Return the counters merged with storage, cache and connection state."""
        with self._lock:
            stats = self._performance_stats.copy()

        if self._db_manager:
            db_stats = self._db_manager.get_statistics()
            stats.update({
                "db_total_alerts": db_stats.get("total_alerts", 0),
                "db_pending_alerts": db_stats.get("pending_alerts", 0),
            })

        if self._image_storage:
            img_stats = self._image_storage.get_statistics()
            stats["image_saved"] = img_stats.get("saved_count", 0)

        if self._local_cache:
            stats["pending_sync"] = self._local_cache.get_pending_count()

        stats["mqtt_connected"] = self._connected

        return stats

    def cleanup(self):
        """Close the storage helpers and shut the MQTT client down once."""
        self._logger.info("ResultReporter 资源清理")

        if self._image_storage:
            self._image_storage.close()

        if self._db_manager:
            self._db_manager.close()

        # Detach the client first so a repeated call cannot stop it twice.
        client, self._client = self._client, None
        if client:
            client.loop_stop()
            client.disconnect()
        self._connected = False

    def close(self):
        """Close the reporter (alias of ``cleanup()`` plus a completion log)."""
        self.cleanup()
        self._logger.info("ResultReporter 清理完成")
|
||
|
||
|
||
class AlertReporter:
    """Simplified alert reporter that only records alerts locally."""

    def __init__(self, topic_prefix: str = "edge/alert"):
        """Set up counters and logging; storage is attached by ``initialize()``.

        Args:
            topic_prefix: Topic prefix kept on the instance.
        """
        self._topic_prefix = topic_prefix
        self._lock = threading.Lock()
        self._performance_stats = {
            "alerts_generated": 0,
            "alerts_sent": 0,
            "alerts_stored": 0,
            "send_failures": 0,
        }
        self._logger = logging.getLogger("alert_reporter")
        # Defined up front so report_alert() works before initialize().
        self._db_manager = None
        self._image_storage = None

    def initialize(self):
        """Acquire the SQLite manager and the image storage helper."""
        from config.database import get_sqlite_manager
        from core.storage_manager import get_image_storage

        self._db_manager = get_sqlite_manager()
        self._image_storage = get_image_storage()
        self._logger.info("AlertReporter 初始化完成")

    def report_alert(
        self,
        alert: "AlertInfo",
        store_to_db: bool = True,
        screenshot: Optional[np.ndarray] = None
    ) -> bool:
        """Record an alert, storing it when a database manager is available.

        Args:
            alert: The alert to report.
            store_to_db: Whether to queue the alert to the local database.
            screenshot: Optional captured frame saved alongside the alert.

        Returns:
            True when no step raised, False otherwise.
        """
        with self._lock:
            self._performance_stats["alerts_generated"] += 1

        try:
            if store_to_db and self._db_manager:
                self._store_alert(alert, screenshot)

            self._logger.info(
                f"告警已记录: {alert.alert_type} - {alert.camera_id}/{alert.roi_id}"
            )

            with self._lock:
                self._performance_stats["alerts_sent"] += 1

            return True

        except Exception as e:
            with self._lock:
                self._performance_stats["send_failures"] += 1
            self._logger.error(f"上报告警失败: {e}")
            return False

    def _store_alert(self, alert: "AlertInfo", screenshot: Optional[np.ndarray] = None):
        """Save the optional screenshot and queue the alert record for storage."""
        from config.database import AlertRecord, get_sqlite_manager

        image_path = None

        if screenshot is not None and self._image_storage:
            image_path = self._image_storage.save_capture(
                image=screenshot,
                camera_id=alert.camera_id,
                alert_id=alert.alert_id,
                timestamp=alert.timestamp
            )

        record = AlertRecord(
            alert_id=alert.alert_id,
            camera_id=alert.camera_id,
            roi_id=alert.roi_id,
            alert_type=alert.alert_type,
            target_class=alert.target_class,
            confidence=alert.confidence,
            bbox=alert.bbox,
            message=alert.message,
            image_path=image_path,
            status="pending",
            created_at=alert.timestamp,
            # Kept in step with ResultReporter._store_alert.
            duration_minutes=alert.duration_minutes,
        )

        # Prefer the manager obtained in initialize(); fall back to the accessor.
        db_manager = self._db_manager or get_sqlite_manager()
        db_manager.queue_alert(record)

        with self._lock:
            self._performance_stats["alerts_stored"] += 1

    def get_statistics(self) -> Dict[str, Any]:
        """Return a snapshot copy of the counters."""
        with self._lock:
            return dict(self._performance_stats)
|
||
|
||
|
||
def create_alert_reporter(topic_prefix: str = "edge/alert") -> AlertReporter:
    """Build an ``AlertReporter`` for the given topic prefix.

    The returned reporter is not initialized; ``initialize()`` is not called
    here.
    """
    reporter = AlertReporter(topic_prefix=topic_prefix)
    return reporter
|