- Moved all project files and directories (config, core, models, etc.) from edge_inference_service/ to the repository root ai_edge/ - Updated model path in config/settings.py to reflect new structure - Revised usage paths in __init__.py documentation
660 lines
19 KiB
Python
660 lines
19 KiB
Python
"""
结果上报模块 (Result reporting module)

实现MQTT客户端、告警存储、重试机制
(Implements the MQTT client, alert storage, and the retry mechanism.)
"""
|
||
|
||
import json
|
||
import logging
|
||
import threading
|
||
import time
|
||
from datetime import datetime
|
||
from typing import Any, Callable, Dict, List, Optional
|
||
|
||
import numpy as np
|
||
import paho.mqtt.client as mqtt
|
||
from paho.mqtt.client import MQTTMessage
|
||
|
||
from config.settings import get_settings, MQTTConfig
|
||
from config.database import get_database_manager
|
||
from config.config_models import AlertInfo
|
||
from utils.logger import get_logger
|
||
from utils.common import generate_unique_id, ExponentialBackoff
|
||
|
||
# Module-level standard-library logger. NOTE(review): the classes in this module
# log through utils.logger.get_logger(...) instead, so this logger appears unused here.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class MQTTClient:
    """MQTT client that maintains a reliable link to the cloud business platform.

    Responsibilities:
      * connect to the broker with paho-mqtt and run its network loop;
      * buffer messages published while offline and re-send them after the
        next successful connection;
      * remember subscriptions and re-apply them on every (re)connect, because
        the session is created with ``clean_session=True``;
      * dispatch incoming messages to callbacks registered per topic filter.
    """

    def __init__(
        self,
        config: Optional[MQTTConfig] = None,
        max_pending_messages: Optional[int] = 1000,
    ):
        """Initialize the client without connecting.

        Args:
            config: MQTT configuration. Defaults to ``get_settings().mqtt``.
            max_pending_messages: Maximum number of messages kept in the
                offline queue. When exceeded, the oldest entries are dropped
                so that a long outage cannot exhaust memory. ``None`` disables
                the limit.
        """
        if config is None:
            settings = get_settings()
            config = settings.mqtt

        self.config = config
        self._client = None
        self._connected = False
        self._reconnect_thread: Optional[threading.Thread] = None
        self._stop_event = threading.Event()

        # Topic filter (may contain MQTT +/# wildcards) -> callback(payload_str).
        self._message_callbacks: Dict[str, Callable] = {}
        # Topic filter -> QoS. Re-applied in _on_connect after every (re)connect.
        self._subscriptions: Dict[str, int] = {}

        # Offline queue of already-serialized messages, oldest first.
        self._pending_messages: List[Dict] = []
        self._pending_lock = threading.Lock()
        self._max_pending_messages = max_pending_messages

        self._logger = get_logger("mqtt")
        # Serializes connect()/disconnect().
        self._lock = threading.Lock()

        self._performance_stats = {
            "messages_sent": 0,
            "messages_received": 0,
            "send_errors": 0,
            "reconnect_count": 0,
        }

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """JSON fallback that converts NumPy scalars and arrays to Python types."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @classmethod
    def _dumps(cls, payload: Any) -> str:
        """Serialize *payload* to JSON text, keeping non-ASCII characters as-is."""
        return json.dumps(payload, ensure_ascii=False, default=cls._json_default)

    # ------------------------------------------------------------------
    # paho callbacks (invoked on paho's network thread)
    # ------------------------------------------------------------------

    def _on_connect(self, client, userdata, flags, rc):
        """Handle CONNACK: mark connected, restore subscriptions, flush the queue."""
        if rc == 0:
            self._connected = True
            self._logger.log_connection_event(
                "connect", "MQTT", self.config.broker_host, True
            )
            self._logger.info(f"MQTT连接成功: {self.config.broker_host}:{self.config.broker_port}")

            self._restore_subscriptions(client)
            self._resend_pending_messages()
        else:
            self._connected = False
            self._logger.log_connection_event(
                "connect", "MQTT", self.config.broker_host, False,
                error_msg=f"错误码: {rc}"
            )

    def _restore_subscriptions(self, client) -> None:
        """Re-issue every recorded subscription on *client*."""
        for topic, qos in list(self._subscriptions.items()):
            try:
                client.subscribe(topic, qos)
                self._logger.info(f"已订阅主题: {topic}")
            except Exception as e:
                self._logger.error(f"恢复订阅失败: {topic}: {e}")

    def _on_disconnect(self, client, userdata, rc):
        """Handle a disconnect; start the reconnect loop if it was unexpected (rc != 0)."""
        self._connected = False
        self._logger.log_connection_event(
            "disconnect", "MQTT", self.config.broker_host, False
        )

        if rc != 0:
            self._performance_stats["reconnect_count"] += 1
            self._logger.warning(f"MQTT意外断开,尝试重连: {rc}")
            # Only the currently active client may trigger a reconnect. After
            # disconnect()/connect() detached it, self._client is None or a newer
            # client, and reconnecting the old one would be wrong.
            if client is self._client:
                self._start_reconnect()

    @staticmethod
    def _topic_matches(subscription: str, topic: str) -> bool:
        """Return True if *topic* matches the (possibly wildcard) *subscription*."""
        if subscription == topic:
            return True
        try:
            return mqtt.topic_matches_sub(subscription, topic)
        except (TypeError, ValueError):
            return False

    def _on_message(self, client, userdata, msg: MQTTMessage):
        """Decode an incoming message and dispatch it to every matching callback."""
        self._performance_stats["messages_received"] += 1

        topic = msg.topic
        # errors="replace": a malformed payload must not raise here, because an
        # exception escaping a paho callback may propagate into its network loop.
        payload = msg.payload.decode('utf-8', errors='replace')

        self._logger.debug(f"收到消息: {topic} -> {payload}")

        # Snapshot so concurrent register_message_callback() calls are safe.
        for subscription, callback in list(self._message_callbacks.items()):
            if not self._topic_matches(subscription, topic):
                continue
            try:
                callback(payload)
            except Exception as e:
                self._logger.error(f"消息回调执行失败: {e}")

    def _on_publish(self, client, userdata, mid):
        """Count a message once paho reports it as published."""
        self._performance_stats["messages_sent"] += 1
        self._logger.debug(f"消息已发布: mid={mid}")

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------

    def _create_client(self) -> mqtt.Client:
        """Build a paho client wired to this object's callbacks and credentials."""
        client = mqtt.Client(
            client_id=self.config.client_id,
            clean_session=True,
            userdata=None,
            protocol=mqtt.MQTTv311,
            transport="tcp"
        )

        client.on_connect = self._on_connect
        client.on_disconnect = self._on_disconnect
        client.on_message = self._on_message
        client.on_publish = self._on_publish

        if self.config.username and self.config.password:
            client.username_pw_set(
                self.config.username,
                self.config.password
            )

        client.reconnect_delay_set(
            min_delay=self.config.reconnect_delay,
            max_delay=self.config.reconnect_delay * 2
        )

        return client

    def _start_reconnect(self):
        """Start the background reconnect thread unless one is already running."""
        if self._reconnect_thread is not None and self._reconnect_thread.is_alive():
            return

        self._stop_event.clear()
        self._reconnect_thread = threading.Thread(
            target=self._reconnect_loop,
            name="MQTTReconnect",
            daemon=True
        )
        self._reconnect_thread.start()

    def _reconnect_loop(self):
        """Retry connecting with exponential back-off until connected, stopped or exhausted.

        NOTE(review): paho's own network loop (started by loop_start and
        configured via reconnect_delay_set) may also reconnect on its own; this
        loop re-checks the connection state before each attempt to avoid
        redundant connects.
        """
        backoff = ExponentialBackoff(
            base_delay=self.config.reconnect_delay,
            max_delay=self.config.reconnect_delay * 3,
            max_attempts=self.config.max_reconnect_attempts
        )

        while not self._stop_event.is_set():
            if self._connected:
                break

            if not backoff.next_attempt():
                self._logger.error("MQTT重连次数已达上限")
                break

            delay = backoff.get_delay()
            self._logger.info(f"等待 {delay:.1f}秒后尝试重连MQTT...")
            # Event.wait() instead of time.sleep() so disconnect()/close() can
            # interrupt the back-off immediately.
            if self._stop_event.wait(delay):
                break
            if self._connected:
                break

            client = self._client
            if client is None:
                break

            try:
                client.connect(
                    self.config.broker_host,
                    self.config.broker_port,
                    self.config.keepalive
                )
                client.loop_start()
            except Exception as e:
                self._logger.error(f"MQTT重连失败: {e}")

    def _shutdown_client(self, client) -> None:
        """Disconnect *client*, then stop its network loop; log rather than raise errors."""
        try:
            client.disconnect()
        except Exception as e:
            self._logger.warning(f"MQTT断开连接时出错: {e}")
        finally:
            client.loop_stop()

    def connect(self) -> bool:
        """Start connecting to the MQTT broker.

        Any previous client is disconnected and its network loop stopped first,
        so its socket is not leaked. The connection outcome is delivered
        asynchronously to _on_connect. If the attempt raises, the background
        reconnect loop is started so the client keeps retrying.

        Returns:
            True if the connection attempt was started, False if it raised.
        """
        with self._lock:
            try:
                old_client = self._client
                if old_client is not None:
                    # Detach before shutting down so the old client's callbacks
                    # cannot trigger a reconnect of the client being replaced.
                    self._client = None
                    self._shutdown_client(old_client)
                    self._connected = False

                self._client = self._create_client()

                self._client.connect(
                    self.config.broker_host,
                    self.config.broker_port,
                    self.config.keepalive
                )
                self._client.loop_start()

                self._logger.info(
                    f"MQTT正在连接: {self.config.broker_host}:{self.config.broker_port}"
                )
                return True

            except Exception as e:
                self._logger.error(f"MQTT连接失败: {e}")
                # Without this, a broker that is unreachable at start-up would
                # never be retried and every message would stay queued forever.
                if self._client is not None:
                    self._start_reconnect()
                return False

    def disconnect(self):
        """Disconnect from the broker and stop any background reconnect attempts."""
        with self._lock:
            if self._client:
                self._stop_event.set()
                client = self._client
                self._client = None
                # paho recommends disconnect() before loop_stop().
                self._shutdown_client(client)
                self._connected = False
                self._logger.info("MQTT连接已断开")

    def subscribe(self, topic: str, qos: int = 1):
        """Subscribe to a topic filter.

        The subscription is recorded and re-applied on every (re)connect. If
        the client is not connected yet, it takes effect on the next connect
        instead of being silently dropped.

        Args:
            topic: Topic filter (may contain + / # wildcards).
            qos: Quality-of-service level.
        """
        # Record first so a concurrent _on_connect sees it.
        self._subscriptions[topic] = qos
        if self._client and self._connected:
            self._client.subscribe(topic, qos)
            self._logger.info(f"已订阅主题: {topic}")
        else:
            self._logger.info(f"MQTT未连接,订阅将在连接后生效: {topic}")

    # ------------------------------------------------------------------
    # Publishing
    # ------------------------------------------------------------------

    def publish(
        self,
        topic: str,
        payload: Dict[str, Any],
        qos: Optional[int] = None,
        retain: bool = False
    ) -> bool:
        """Publish *payload* as JSON.

        Args:
            topic: Topic to publish to.
            payload: JSON-serializable payload; NumPy scalars and arrays are
                also accepted.
            qos: Quality-of-service level; defaults to ``config.qos``.
            retain: Whether the broker should retain the message.

        Returns:
            True if the message was handed to the connection. False if it was
            queued for delivery after the next connect, or could not be
            published at all.
        """
        if qos is None:
            qos = self.config.qos

        try:
            message = self._dumps(payload)
        except Exception as e:
            self._performance_stats["send_errors"] += 1
            self._logger.error(f"发布消息失败: {e}")
            return False

        if not (self._connected and self._client):
            self._queue_message(topic, message, qos, retain)
            return False

        try:
            self._publish_message(topic, message, qos, retain)
            return True
        except RuntimeError as e:
            # Transport-level failure: keep the message for the next (re)connect.
            self._performance_stats["send_errors"] += 1
            self._logger.error(f"发布消息失败: {e}")
            self._queue_message(topic, message, qos, retain)
            return False
        except Exception as e:
            # Invalid topic, QoS or payload: retrying would fail the same way.
            self._performance_stats["send_errors"] += 1
            self._logger.error(f"发布消息失败: {e}")
            return False

    def _publish_message(
        self,
        topic: str,
        message: str,
        qos: int,
        retain: bool
    ):
        """Hand one serialized message to paho and check the result code.

        Raises:
            RuntimeError: if there is no client, or paho reports a failure that
                would otherwise lose the message (callers may re-queue it).
            ValueError: propagated from paho for an invalid topic, QoS or payload.
        """
        client = self._client
        if client is None:
            raise RuntimeError("MQTT client is not initialized")

        info = client.publish(topic, message, qos, retain)
        rc = info.rc
        if rc == mqtt.MQTT_ERR_NO_CONN and qos > 0:
            # NOTE(review): paho keeps QoS>0 messages it could not send in its
            # outgoing store and retransmits them after its loop reconnects.
            # Re-queuing here would duplicate them. Confirm for the paho version
            # in use.
            self._logger.warning(f"MQTT未连接,消息由paho缓存待重连后发送: {topic}")
            return
        if rc != mqtt.MQTT_ERR_SUCCESS:
            raise RuntimeError(f"MQTT发布失败: {mqtt.error_string(rc)}")

        self._logger.debug(f"消息已发布: {topic}")

    def _trim_pending_locked(self) -> int:
        """Drop the oldest queued messages beyond the limit. Caller must hold _pending_lock.

        Returns:
            Number of messages dropped.
        """
        limit = self._max_pending_messages
        if limit is None:
            return 0
        overflow = len(self._pending_messages) - max(limit, 0)
        if overflow <= 0:
            return 0
        del self._pending_messages[:overflow]
        return overflow

    def _queue_message(
        self,
        topic: str,
        payload: str,
        qos: int,
        retain: bool
    ):
        """Append an already-serialized message to the offline queue.

        The JSON text is stored rather than the caller's dict, so later
        mutations of that dict cannot change the queued message, and the
        resend path can publish it directly.
        """
        with self._pending_lock:
            self._pending_messages.append({
                "topic": topic,
                "payload": payload,
                "qos": qos,
                "retain": retain,
                "timestamp": datetime.now().isoformat(),
            })
            dropped = self._trim_pending_locked()

        self._logger.warning(f"消息已加入待发送队列: {topic}")
        if dropped:
            self._logger.warning(f"待发送队列已满,已丢弃最早的{dropped}条消息")

    def _resend_pending_messages(self):
        """Re-publish queued messages.

        Messages that fail transiently (RuntimeError) stay queued. Messages
        that can never be published (e.g. invalid topic) are dropped with an
        error log.
        """
        with self._pending_lock:
            if not self._pending_messages:
                return

            self._logger.info(f"重发待处理消息: {len(self._pending_messages)}条")

            still_pending = []
            for msg in self._pending_messages:
                payload = msg["payload"]
                try:
                    # Entries normally hold JSON text; serialize defensively otherwise.
                    message = payload if isinstance(payload, str) else self._dumps(payload)
                    self._publish_message(
                        msg["topic"],
                        message,
                        msg.get("qos", self.config.qos),
                        msg.get("retain", False)
                    )
                except RuntimeError:
                    still_pending.append(msg)
                except Exception as e:
                    self._logger.error(f"待发送消息无法发布,已丢弃: {msg.get('topic')}: {e}")

            self._pending_messages = still_pending

    # ------------------------------------------------------------------
    # Misc
    # ------------------------------------------------------------------

    def register_message_callback(self, topic: str, callback: Callable):
        """Register *callback(payload_str)* for messages matching the topic filter *topic*."""
        self._message_callbacks[topic] = callback
        self._logger.info(f"已注册消息回调: {topic}")

    def get_performance_stats(self) -> Dict[str, int]:
        """Return a copy of the performance counters."""
        return dict(self._performance_stats)

    def get_status(self) -> Dict[str, Any]:
        """Return connection state, broker address, queue length and counters."""
        return {
            "connected": self._connected,
            "broker_host": self.config.broker_host,
            "broker_port": self.config.broker_port,
            "pending_messages": len(self._pending_messages),
            "stats": self.get_performance_stats(),
        }

    def close(self):
        """Disconnect and release the client."""
        self.disconnect()
        self._logger.info("MQTT客户端已关闭")


class AlertReporter:
    """Alert reporter.

    Combines MQTT publishing, database persistence, and the offline queue and
    retry behaviour provided by MQTTClient.
    """

    def __init__(self):
        self._mqtt_client = None
        self._db_manager = None
        self._logger = get_logger("reporter")
        # Guards the performance counters.
        self._lock = threading.Lock()
        # Guards lazy creation and teardown of the MQTT client and DB manager.
        # Without it, concurrent first reports could create two MQTTClient
        # instances that compete for the same configured client_id.
        self._init_lock = threading.Lock()

        self._topic_prefix = "edge/alerts"

        self._performance_stats = {
            "alerts_generated": 0,
            "alerts_sent": 0,
            "alerts_stored": 0,
            "send_failures": 0,
        }

    def _increment_stat(self, key: str) -> None:
        """Increment one performance counter under the stats lock."""
        with self._lock:
            self._performance_stats[key] += 1

    def _init_mqtt(self):
        """Create and connect the MQTT client on first use, then return it."""
        with self._init_lock:
            if self._mqtt_client is None:
                self._mqtt_client = MQTTClient()
                self._mqtt_client.connect()
            return self._mqtt_client

    def _init_database(self):
        """Obtain the database manager on first use, then return it."""
        with self._init_lock:
            if self._db_manager is None:
                self._db_manager = get_database_manager()
            return self._db_manager

    def _encode_screenshot(self, screenshot: np.ndarray) -> Optional[str]:
        """Base64-encode *screenshot*; on failure log a warning and return None.

        An encoding problem should not prevent the alert itself from being
        published and stored.
        """
        try:
            from utils.common import image_to_base64
            return image_to_base64(screenshot)
        except Exception as e:
            self._logger.warning(f"截图编码失败,告警将不带截图上报: {e}")
            return None

    def report_alert(
        self,
        alert: AlertInfo,
        screenshot: Optional[np.ndarray] = None,
        store_to_db: bool = True
    ) -> bool:
        """Publish an alert over MQTT and optionally store it in the database.

        Args:
            alert: Alert to report.
            screenshot: Optional frame, attached to the MQTT payload as base64.
            store_to_db: Whether to also persist the alert.

        Returns:
            True if the MQTT publish succeeded. False if it failed, was only
            queued for later delivery, or an exception occurred. The database
            outcome does not affect the return value.
        """
        self._increment_stat("alerts_generated")

        try:
            mqtt_client = self._init_mqtt()
            self._init_database()

            alert_data = alert.to_dict()

            if screenshot is not None:
                encoded = self._encode_screenshot(screenshot)
                if encoded is not None:
                    alert_data["screenshot"] = encoded

            topic = f"{self._topic_prefix}/{alert.camera_id}/{alert.roi_id}"
            success = mqtt_client.publish(topic, alert_data)

            # The DB record comes from a second to_dict() call, presumably so the
            # base64 screenshot is not persisted (assumes to_dict() returns a
            # new dict each time; TODO confirm).
            if store_to_db and self._store_to_database(alert.to_dict()):
                self._increment_stat("alerts_stored")

            if success:
                self._increment_stat("alerts_sent")
                self._logger.log_alert(
                    alert.alert_type,
                    alert.camera_id,
                    alert.roi_id,
                    alert.confidence
                )
            else:
                self._increment_stat("send_failures")

            return success

        except Exception as e:
            self._increment_stat("send_failures")
            self._logger.error(f"上报告警失败: {e}")
            return False

    def _store_to_database(self, alert_data: Dict[str, Any]) -> bool:
        """Persist *alert_data*; return False (and log) on any error."""
        try:
            return self._db_manager.save_alert_record(alert_data)
        except Exception as e:
            self._logger.error(f"存储告警到数据库失败: {e}")
            return False

    def batch_report(
        self,
        alerts: List[AlertInfo],
        screenshots: Optional[Dict[str, np.ndarray]] = None
    ) -> Dict[str, int]:
        """Report several alerts one by one.

        Args:
            alerts: Alerts to report.
            screenshots: Optional mapping alert_id -> screenshot.

        Returns:
            Counts with keys "total", "success" and "failed".
        """
        results = {
            "total": len(alerts),
            "success": 0,
            "failed": 0,
        }

        screenshots = screenshots or {}

        for alert in alerts:
            screenshot = screenshots.get(alert.alert_id)
            if self.report_alert(alert, screenshot):
                results["success"] += 1
            else:
                results["failed"] += 1

        return results

    def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool:
        """Publish a heartbeat to ``edge/heartbeat/<device_id>``.

        Args:
            device_id: Device identifier.
            status: Status information included in the payload.

        Returns:
            True if the publish succeeded, otherwise False.
        """
        try:
            mqtt_client = self._init_mqtt()

            heartbeat_data = {
                "device_id": device_id,
                "timestamp": datetime.now().isoformat(),
                "status": status,
            }

            topic = f"edge/heartbeat/{device_id}"
            return mqtt_client.publish(topic, heartbeat_data)

        except Exception as e:
            self._logger.error(f"上报心跳失败: {e}")
            return False

    def get_performance_stats(self) -> Dict[str, int]:
        """Return a copy of the performance counters."""
        with self._lock:
            return dict(self._performance_stats)

    def get_status(self) -> Dict[str, Any]:
        """Return the MQTT client status (or None if not created) and counters."""
        mqtt_client = self._mqtt_client
        return {
            "mqtt": mqtt_client.get_status() if mqtt_client else None,
            "stats": self.get_performance_stats(),
        }

    def close(self):
        """Close the MQTT client. A later report will create a new one."""
        with self._init_lock:
            mqtt_client, self._mqtt_client = self._mqtt_client, None
        if mqtt_client:
            mqtt_client.close()


class ResultReporter:
    """Top-level result reporter.

    Turns detection results into AlertInfo records and forwards them, together
    with heartbeats, to the AlertReporter owned by this instance.
    """

    def __init__(self):
        self._alert_reporter = AlertReporter()
        self._logger = get_logger("reporter")

    def report_detection_alert(
        self,
        camera_id: str,
        roi_id: str,
        alert_type: str,
        detection: Dict[str, Any],
        screenshot: Optional[np.ndarray] = None
    ) -> bool:
        """Build an alert from a detection dict and report it.

        Args:
            camera_id: Camera identifier.
            roi_id: ROI identifier.
            alert_type: Alert type; also used in the default message.
            detection: Detection data. The optional keys ``class_name``,
                ``confidence`` and ``bbox`` map to None when absent;
                ``message`` defaults to "检测到<alert_type>".
            screenshot: Optional frame forwarded to the alert reporter.

        Returns:
            The result of ``AlertReporter.report_alert``.
        """
        lookup = detection.get
        default_message = f"检测到{alert_type}"
        alert_fields = {
            "alert_id": generate_unique_id("alert"),
            "camera_id": camera_id,
            "roi_id": roi_id,
            "alert_type": alert_type,
            "target_class": lookup("class_name"),
            "confidence": lookup("confidence"),
            "bbox": lookup("bbox"),
            "message": lookup("message", default_message),
            "timestamp": datetime.now().isoformat(),
        }
        new_alert = AlertInfo(**alert_fields)
        return self._alert_reporter.report_alert(new_alert, screenshot)

    def report_leave_post_alert(
        self,
        camera_id: str,
        roi_id: str,
        duration_seconds: int,
        screenshot: Optional[np.ndarray] = None
    ) -> bool:
        """Report a "left post" alert for a person absent for *duration_seconds*.

        Args:
            camera_id: Camera identifier.
            roi_id: ROI identifier.
            duration_seconds: How long the post has been unattended, in seconds.
            screenshot: Optional frame.

        Returns:
            Whether the alert was reported successfully.
        """
        leave_detection = dict(
            class_name="person",
            confidence=None,
            message=f"离岗告警: 已离开{duration_seconds}秒",
        )
        return self.report_detection_alert(
            camera_id=camera_id,
            roi_id=roi_id,
            alert_type="leave_post",
            detection=leave_detection,
            screenshot=screenshot,
        )

    def report_intrusion_alert(
        self,
        camera_id: str,
        roi_id: str,
        confidence: float,
        bbox: List[float],
        screenshot: Optional[np.ndarray] = None
    ) -> bool:
        """Report a perimeter-intrusion alert.

        Args:
            camera_id: Camera identifier.
            roi_id: ROI identifier.
            confidence: Detection confidence.
            bbox: Detection bounding box.
            screenshot: Optional frame.

        Returns:
            Whether the alert was reported successfully.
        """
        intrusion_detection = dict(
            class_name="intrusion",
            confidence=confidence,
            bbox=bbox,
            message="检测到周界入侵",
        )
        return self.report_detection_alert(
            camera_id=camera_id,
            roi_id=roi_id,
            alert_type="intrusion",
            detection=intrusion_detection,
            screenshot=screenshot,
        )

    def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool:
        """Forward a heartbeat to the underlying alert reporter."""
        reporter = self._alert_reporter
        return reporter.report_heartbeat(device_id, status)

    def get_status(self) -> Dict[str, Any]:
        """Return the underlying alert reporter's status."""
        reporter = self._alert_reporter
        return reporter.get_status()

    def close(self):
        """Release the underlying alert reporter's resources."""
        reporter = self._alert_reporter
        reporter.close()


# Process-wide ResultReporter and the lock guarding its lazy creation.
_result_reporter_instance: Optional["ResultReporter"] = None
_result_reporter_lock = threading.Lock()


def get_result_reporter() -> ResultReporter:
    """Return the process-wide ResultReporter, creating it on first use.

    Reusing one instance avoids creating several AlertReporter/MQTT clients
    that would all connect with the same configured client_id. Creation is
    guarded by a lock so concurrent first calls still yield a single instance.
    """
    global _result_reporter_instance
    if _result_reporter_instance is None:
        with _result_reporter_lock:
            if _result_reporter_instance is None:
                _result_reporter_instance = ResultReporter()
    return _result_reporter_instance