feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from edge_inference_service/ to the repository root ai_edge/ - Updated model path in config/settings.py to reflect new structure - Revised usage paths in __init__.py documentation
This commit is contained in:
659
core/result_reporter.py
Normal file
659
core/result_reporter.py
Normal file
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
结果上报模块
|
||||
实现MQTT客户端、告警存储、重试机制
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import paho.mqtt.client as mqtt
|
||||
from paho.mqtt.client import MQTTMessage
|
||||
|
||||
from config.settings import get_settings, MQTTConfig
|
||||
from config.database import get_database_manager
|
||||
from config.config_models import AlertInfo
|
||||
from utils.logger import get_logger
|
||||
from utils.common import generate_unique_id, ExponentialBackoff
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MQTTClient:
|
||||
"""MQTT客户端类
|
||||
|
||||
实现与云端业务平台的可靠连接
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[MQTTConfig] = None):
|
||||
"""
|
||||
初始化MQTT客户端
|
||||
|
||||
Args:
|
||||
config: MQTT配置
|
||||
"""
|
||||
if config is None:
|
||||
settings = get_settings()
|
||||
config = settings.mqtt
|
||||
|
||||
self.config = config
|
||||
self._client = None
|
||||
self._connected = False
|
||||
self._reconnect_thread: Optional[threading.Thread] = None
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
self._message_callbacks: Dict[str, Callable] = {}
|
||||
self._pending_messages: List[Dict] = []
|
||||
self._pending_lock = threading.Lock()
|
||||
|
||||
self._logger = get_logger("mqtt")
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._performance_stats = {
|
||||
"messages_sent": 0,
|
||||
"messages_received": 0,
|
||||
"send_errors": 0,
|
||||
"reconnect_count": 0,
|
||||
}
|
||||
|
||||
def _on_connect(self, client, userdata, flags, rc):
|
||||
"""连接回调"""
|
||||
if rc == 0:
|
||||
self._connected = True
|
||||
self._logger.log_connection_event(
|
||||
"connect", "MQTT", self.config.broker_host, True
|
||||
)
|
||||
self._logger.info(f"MQTT连接成功: {self.config.broker_host}:{self.config.broker_port}")
|
||||
|
||||
self._resend_pending_messages()
|
||||
else:
|
||||
self._connected = False
|
||||
self._logger.log_connection_event(
|
||||
"connect", "MQTT", self.config.broker_host, False,
|
||||
error_msg=f"错误码: {rc}"
|
||||
)
|
||||
|
||||
def _on_disconnect(self, client, userdata, rc):
|
||||
"""断开连接回调"""
|
||||
self._connected = False
|
||||
self._logger.log_connection_event(
|
||||
"disconnect", "MQTT", self.config.broker_host, False
|
||||
)
|
||||
|
||||
if rc != 0:
|
||||
self._performance_stats["reconnect_count"] += 1
|
||||
self._logger.warning(f"MQTT意外断开,尝试重连: {rc}")
|
||||
self._start_reconnect()
|
||||
|
||||
def _on_message(self, client, userdata, msg: MQTTMessage):
|
||||
"""消息回调"""
|
||||
self._performance_stats["messages_received"] += 1
|
||||
|
||||
topic = msg.topic
|
||||
payload = msg.payload.decode('utf-8')
|
||||
|
||||
self._logger.debug(f"收到消息: {topic} -> {payload}")
|
||||
|
||||
if topic in self._message_callbacks:
|
||||
try:
|
||||
self._message_callbacks[topic](payload)
|
||||
except Exception as e:
|
||||
self._logger.error(f"消息回调执行失败: {e}")
|
||||
|
||||
def _on_publish(self, client, userdata, mid):
|
||||
"""发布回调"""
|
||||
self._performance_stats["messages_sent"] += 1
|
||||
self._logger.debug(f"消息已发布: mid={mid}")
|
||||
|
||||
def _create_client(self) -> mqtt.Client:
|
||||
"""创建MQTT客户端"""
|
||||
client = mqtt.Client(
|
||||
client_id=self.config.client_id,
|
||||
clean_session=True,
|
||||
userdata=None,
|
||||
protocol=mqtt.MQTTv311,
|
||||
transport="tcp"
|
||||
)
|
||||
|
||||
client.on_connect = self._on_connect
|
||||
client.on_disconnect = self._on_disconnect
|
||||
client.on_message = self._on_message
|
||||
client.on_publish = self._on_publish
|
||||
|
||||
if self.config.username and self.config.password:
|
||||
client.username_pw_set(
|
||||
self.config.username,
|
||||
self.config.password
|
||||
)
|
||||
|
||||
client.reconnect_delay_set(
|
||||
min_delay=self.config.reconnect_delay,
|
||||
max_delay=self.config.reconnect_delay * 2
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _start_reconnect(self):
|
||||
"""启动重连线程"""
|
||||
if self._reconnect_thread is not None and self._reconnect_thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._reconnect_thread = threading.Thread(
|
||||
target=self._reconnect_loop,
|
||||
name="MQTTReconnect",
|
||||
daemon=True
|
||||
)
|
||||
self._reconnect_thread.start()
|
||||
|
||||
def _reconnect_loop(self):
|
||||
"""重连循环"""
|
||||
backoff = ExponentialBackoff(
|
||||
base_delay=self.config.reconnect_delay,
|
||||
max_delay=self.config.reconnect_delay * 3,
|
||||
max_attempts=self.config.max_reconnect_attempts
|
||||
)
|
||||
|
||||
while not self._stop_event.is_set():
|
||||
if self._connected:
|
||||
break
|
||||
|
||||
if not backoff.next_attempt():
|
||||
self._logger.error("MQTT重连次数已达上限")
|
||||
break
|
||||
|
||||
delay = backoff.get_delay()
|
||||
self._logger.info(f"等待 {delay:.1f}秒后尝试重连MQTT...")
|
||||
time.sleep(delay)
|
||||
|
||||
try:
|
||||
self._client.connect(
|
||||
self.config.broker_host,
|
||||
self.config.broker_port,
|
||||
self.config.keepalive
|
||||
)
|
||||
self._client.loop_start()
|
||||
except Exception as e:
|
||||
self._logger.error(f"MQTT重连失败: {e}")
|
||||
|
||||
def _resend_pending_messages(self):
|
||||
"""重发待发送消息"""
|
||||
with self._pending_lock:
|
||||
if not self._pending_messages:
|
||||
return
|
||||
|
||||
self._logger.info(f"重发待处理消息: {len(self._pending_messages)}条")
|
||||
|
||||
failed_messages = []
|
||||
|
||||
for msg in self._pending_messages:
|
||||
try:
|
||||
self._publish_message(
|
||||
msg["topic"],
|
||||
msg["payload"],
|
||||
msg.get("qos", self.config.qos),
|
||||
msg.get("retain", False)
|
||||
)
|
||||
except Exception:
|
||||
failed_messages.append(msg)
|
||||
|
||||
self._pending_messages = failed_messages
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""
|
||||
连接到MQTT代理
|
||||
|
||||
Returns:
|
||||
是否连接成功
|
||||
"""
|
||||
with self._lock:
|
||||
try:
|
||||
if self._client is not None:
|
||||
self._client.loop_stop()
|
||||
|
||||
self._client = self._create_client()
|
||||
|
||||
self._client.connect(
|
||||
self.config.broker_host,
|
||||
self.config.broker_port,
|
||||
self.config.keepalive
|
||||
)
|
||||
self._client.loop_start()
|
||||
|
||||
self._logger.info(
|
||||
f"MQTT正在连接: {self.config.broker_host}:{self.config.broker_port}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"MQTT连接失败: {e}")
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""断开MQTT连接"""
|
||||
with self._lock:
|
||||
if self._client:
|
||||
self._stop_event.set()
|
||||
self._client.loop_stop()
|
||||
self._client.disconnect()
|
||||
self._client = None
|
||||
self._connected = False
|
||||
self._logger.info("MQTT连接已断开")
|
||||
|
||||
def subscribe(self, topic: str, qos: int = 1):
|
||||
"""
|
||||
订阅主题
|
||||
|
||||
Args:
|
||||
topic: 主题
|
||||
qos: 服务质量等级
|
||||
"""
|
||||
if self._client and self._connected:
|
||||
self._client.subscribe(topic, qos)
|
||||
self._logger.info(f"已订阅主题: {topic}")
|
||||
|
||||
def publish(
|
||||
self,
|
||||
topic: str,
|
||||
payload: Dict[str, Any],
|
||||
qos: Optional[int] = None,
|
||||
retain: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
发布消息
|
||||
|
||||
Args:
|
||||
topic: 主题
|
||||
payload: 消息负载
|
||||
qos: 服务质量等级
|
||||
retain: 保留消息
|
||||
|
||||
Returns:
|
||||
是否发布成功
|
||||
"""
|
||||
if qos is None:
|
||||
qos = self.config.qos
|
||||
|
||||
try:
|
||||
message = json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
if self._connected and self._client:
|
||||
self._publish_message(topic, message, qos, retain)
|
||||
return True
|
||||
else:
|
||||
self._queue_message(topic, payload, qos, retain)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self._performance_stats["send_errors"] += 1
|
||||
self._logger.error(f"发布消息失败: {e}")
|
||||
return False
|
||||
|
||||
def _publish_message(
|
||||
self,
|
||||
topic: str,
|
||||
message: str,
|
||||
qos: int,
|
||||
retain: bool
|
||||
):
|
||||
"""内部发布方法"""
|
||||
self._client.publish(topic, message, qos, retain)
|
||||
self._logger.debug(f"消息已发布: {topic}")
|
||||
|
||||
def _queue_message(
|
||||
self,
|
||||
topic: str,
|
||||
payload: Dict[str, Any],
|
||||
qos: int,
|
||||
retain: bool
|
||||
):
|
||||
"""将消息加入待发送队列"""
|
||||
with self._pending_lock:
|
||||
self._pending_messages.append({
|
||||
"topic": topic,
|
||||
"payload": payload,
|
||||
"qos": qos,
|
||||
"retain": retain,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
|
||||
self._logger.warning(f"消息已加入待发送队列: {topic}")
|
||||
|
||||
def register_message_callback(self, topic: str, callback: Callable):
|
||||
"""注册消息回调"""
|
||||
self._message_callbacks[topic] = callback
|
||||
self._logger.info(f"已注册消息回调: {topic}")
|
||||
|
||||
def get_performance_stats(self) -> Dict[str, int]:
|
||||
"""获取性能统计"""
|
||||
return dict(self._performance_stats)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""获取连接状态"""
|
||||
return {
|
||||
"connected": self._connected,
|
||||
"broker_host": self.config.broker_host,
|
||||
"broker_port": self.config.broker_port,
|
||||
"pending_messages": len(self._pending_messages),
|
||||
"stats": self.get_performance_stats(),
|
||||
}
|
||||
|
||||
def close(self):
|
||||
"""关闭客户端"""
|
||||
self.disconnect()
|
||||
self._logger.info("MQTT客户端已关闭")
|
||||
|
||||
|
||||
class AlertReporter:
|
||||
"""告警上报器类
|
||||
|
||||
整合MQTT上报、数据库存储、重试机制
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._mqtt_client = None
|
||||
self._db_manager = None
|
||||
self._logger = get_logger("reporter")
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._topic_prefix = "edge/alerts"
|
||||
|
||||
self._performance_stats = {
|
||||
"alerts_generated": 0,
|
||||
"alerts_sent": 0,
|
||||
"alerts_stored": 0,
|
||||
"send_failures": 0,
|
||||
}
|
||||
|
||||
def _init_mqtt(self):
|
||||
"""初始化MQTT客户端"""
|
||||
if self._mqtt_client is None:
|
||||
self._mqtt_client = MQTTClient()
|
||||
self._mqtt_client.connect()
|
||||
|
||||
def _init_database(self):
|
||||
"""初始化数据库连接"""
|
||||
if self._db_manager is None:
|
||||
self._db_manager = get_database_manager()
|
||||
|
||||
def report_alert(
|
||||
self,
|
||||
alert: AlertInfo,
|
||||
screenshot: Optional[np.ndarray] = None,
|
||||
store_to_db: bool = True
|
||||
) -> bool:
|
||||
"""
|
||||
上报告警
|
||||
|
||||
Args:
|
||||
alert: 告警信息
|
||||
screenshot: 截图 (可选)
|
||||
store_to_db: 是否存储到数据库
|
||||
|
||||
Returns:
|
||||
是否上报成功
|
||||
"""
|
||||
with self._lock:
|
||||
self._performance_stats["alerts_generated"] += 1
|
||||
|
||||
try:
|
||||
self._init_mqtt()
|
||||
self._init_database()
|
||||
|
||||
alert_data = alert.to_dict()
|
||||
|
||||
if screenshot is not None:
|
||||
from utils.common import image_to_base64
|
||||
alert_data["screenshot"] = image_to_base64(screenshot)
|
||||
|
||||
topic = f"{self._topic_prefix}/{alert.camera_id}/{alert.roi_id}"
|
||||
|
||||
success = self._mqtt_client.publish(topic, alert_data)
|
||||
|
||||
if store_to_db:
|
||||
db_success = self._store_to_database(alert.to_dict())
|
||||
if db_success:
|
||||
with self._lock:
|
||||
self._performance_stats["alerts_stored"] += 1
|
||||
|
||||
if success:
|
||||
with self._lock:
|
||||
self._performance_stats["alerts_sent"] += 1
|
||||
self._logger.log_alert(
|
||||
alert.alert_type,
|
||||
alert.camera_id,
|
||||
alert.roi_id,
|
||||
alert.confidence
|
||||
)
|
||||
else:
|
||||
with self._lock:
|
||||
self._performance_stats["send_failures"] += 1
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
with self._lock:
|
||||
self._performance_stats["send_failures"] += 1
|
||||
self._logger.error(f"上报告警失败: {e}")
|
||||
return False
|
||||
|
||||
def _store_to_database(self, alert_data: Dict[str, Any]) -> bool:
|
||||
"""存储告警到数据库"""
|
||||
try:
|
||||
return self._db_manager.save_alert_record(alert_data)
|
||||
except Exception as e:
|
||||
self._logger.error(f"存储告警到数据库失败: {e}")
|
||||
return False
|
||||
|
||||
def batch_report(
|
||||
self,
|
||||
alerts: List[AlertInfo],
|
||||
screenshots: Optional[Dict[str, np.ndarray]] = None
|
||||
) -> Dict[str, int]:
|
||||
"""
|
||||
批量上报告警
|
||||
|
||||
Args:
|
||||
alerts: 告警列表
|
||||
screenshots: 截图字典 (alert_id -> screenshot)
|
||||
|
||||
Returns:
|
||||
上报结果统计
|
||||
"""
|
||||
results = {
|
||||
"total": len(alerts),
|
||||
"success": 0,
|
||||
"failed": 0,
|
||||
}
|
||||
|
||||
screenshots = screenshots or {}
|
||||
|
||||
for alert in alerts:
|
||||
screenshot = screenshots.get(alert.alert_id)
|
||||
if self.report_alert(alert, screenshot):
|
||||
results["success"] += 1
|
||||
else:
|
||||
results["failed"] += 1
|
||||
|
||||
return results
|
||||
|
||||
def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
上报心跳
|
||||
|
||||
Args:
|
||||
device_id: 设备ID
|
||||
status: 状态信息
|
||||
|
||||
Returns:
|
||||
是否上报成功
|
||||
"""
|
||||
try:
|
||||
self._init_mqtt()
|
||||
|
||||
heartbeat_data = {
|
||||
"device_id": device_id,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"status": status,
|
||||
}
|
||||
|
||||
topic = f"edge/heartbeat/{device_id}"
|
||||
return self._mqtt_client.publish(topic, heartbeat_data)
|
||||
|
||||
except Exception as e:
|
||||
self._logger.error(f"上报心跳失败: {e}")
|
||||
return False
|
||||
|
||||
def get_performance_stats(self) -> Dict[str, int]:
|
||||
"""获取性能统计"""
|
||||
with self._lock:
|
||||
return dict(self._performance_stats)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""获取状态"""
|
||||
return {
|
||||
"mqtt": self._mqtt_client.get_status() if self._mqtt_client else None,
|
||||
"stats": self.get_performance_stats(),
|
||||
}
|
||||
|
||||
def close(self):
|
||||
"""关闭资源"""
|
||||
if self._mqtt_client:
|
||||
self._mqtt_client.close()
|
||||
self._mqtt_client = None
|
||||
|
||||
|
||||
class ResultReporter:
|
||||
"""结果上报主类
|
||||
|
||||
统一管理告警上报
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._alert_reporter = AlertReporter()
|
||||
self._logger = get_logger("reporter")
|
||||
|
||||
def report_detection_alert(
|
||||
self,
|
||||
camera_id: str,
|
||||
roi_id: str,
|
||||
alert_type: str,
|
||||
detection: Dict[str, Any],
|
||||
screenshot: Optional[np.ndarray] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上报检测告警
|
||||
|
||||
Args:
|
||||
camera_id: 摄像头ID
|
||||
roi_id: ROI ID
|
||||
alert_type: 告警类型
|
||||
detection: 检测结果
|
||||
screenshot: 截图
|
||||
|
||||
Returns:
|
||||
是否上报成功
|
||||
"""
|
||||
alert = AlertInfo(
|
||||
alert_id=generate_unique_id("alert"),
|
||||
camera_id=camera_id,
|
||||
roi_id=roi_id,
|
||||
alert_type=alert_type,
|
||||
target_class=detection.get("class_name"),
|
||||
confidence=detection.get("confidence"),
|
||||
bbox=detection.get("bbox"),
|
||||
message=detection.get("message", f"检测到{alert_type}"),
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
return self._alert_reporter.report_alert(alert, screenshot)
|
||||
|
||||
def report_leave_post_alert(
|
||||
self,
|
||||
camera_id: str,
|
||||
roi_id: str,
|
||||
duration_seconds: int,
|
||||
screenshot: Optional[np.ndarray] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上报离岗告警
|
||||
|
||||
Args:
|
||||
camera_id: 摄像头ID
|
||||
roi_id: ROI ID
|
||||
duration_seconds: 离岗时长(秒)
|
||||
screenshot: 截图
|
||||
|
||||
Returns:
|
||||
是否上报成功
|
||||
"""
|
||||
detection = {
|
||||
"class_name": "person",
|
||||
"confidence": None,
|
||||
"message": f"离岗告警: 已离开{duration_seconds}秒",
|
||||
}
|
||||
|
||||
return self.report_detection_alert(
|
||||
camera_id=camera_id,
|
||||
roi_id=roi_id,
|
||||
alert_type="leave_post",
|
||||
detection=detection,
|
||||
screenshot=screenshot
|
||||
)
|
||||
|
||||
def report_intrusion_alert(
|
||||
self,
|
||||
camera_id: str,
|
||||
roi_id: str,
|
||||
confidence: float,
|
||||
bbox: List[float],
|
||||
screenshot: Optional[np.ndarray] = None
|
||||
) -> bool:
|
||||
"""
|
||||
上报入侵告警
|
||||
|
||||
Args:
|
||||
camera_id: 摄像头ID
|
||||
roi_id: ROI ID
|
||||
confidence: 置信度
|
||||
bbox: 检测框
|
||||
screenshot: 截图
|
||||
|
||||
Returns:
|
||||
是否上报成功
|
||||
"""
|
||||
detection = {
|
||||
"class_name": "intrusion",
|
||||
"confidence": confidence,
|
||||
"bbox": bbox,
|
||||
"message": "检测到周界入侵",
|
||||
}
|
||||
|
||||
return self.report_detection_alert(
|
||||
camera_id=camera_id,
|
||||
roi_id=roi_id,
|
||||
alert_type="intrusion",
|
||||
detection=detection,
|
||||
screenshot=screenshot
|
||||
)
|
||||
|
||||
def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool:
|
||||
"""上报心跳"""
|
||||
return self._alert_reporter.report_heartbeat(device_id, status)
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""获取状态"""
|
||||
return self._alert_reporter.get_status()
|
||||
|
||||
def close(self):
|
||||
"""关闭资源"""
|
||||
self._alert_reporter.close()
|
||||
|
||||
|
||||
def get_result_reporter() -> ResultReporter:
|
||||
"""获取结果上报器单例"""
|
||||
return ResultReporter()
|
||||
Reference in New Issue
Block a user