feat(project): move edge_inference_service contents to root and update paths

- Moved all project files and directories (config, core, models, etc.) from
  edge_inference_service/ to the repository root ai_edge/
- Updated model path in config/settings.py to reflect new structure
- Revised usage paths in __init__.py documentation
This commit is contained in:
2026-01-29 18:43:19 +08:00
parent 6733344b11
commit b0ddb6ee1a
50 changed files with 407 additions and 627 deletions

659
core/result_reporter.py Normal file
View File

@@ -0,0 +1,659 @@
"""
结果上报模块
实现MQTT客户端、告警存储、重试机制
"""
import json
import logging
import threading
import time
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
import numpy as np
import paho.mqtt.client as mqtt
from paho.mqtt.client import MQTTMessage
from config.settings import get_settings, MQTTConfig
from config.database import get_database_manager
from config.config_models import AlertInfo
from utils.logger import get_logger
from utils.common import generate_unique_id, ExponentialBackoff
logger = logging.getLogger(__name__)
class MQTTClient:
"""MQTT客户端类
实现与云端业务平台的可靠连接
"""
def __init__(self, config: Optional[MQTTConfig] = None):
"""
初始化MQTT客户端
Args:
config: MQTT配置
"""
if config is None:
settings = get_settings()
config = settings.mqtt
self.config = config
self._client = None
self._connected = False
self._reconnect_thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._message_callbacks: Dict[str, Callable] = {}
self._pending_messages: List[Dict] = []
self._pending_lock = threading.Lock()
self._logger = get_logger("mqtt")
self._lock = threading.Lock()
self._performance_stats = {
"messages_sent": 0,
"messages_received": 0,
"send_errors": 0,
"reconnect_count": 0,
}
def _on_connect(self, client, userdata, flags, rc):
"""连接回调"""
if rc == 0:
self._connected = True
self._logger.log_connection_event(
"connect", "MQTT", self.config.broker_host, True
)
self._logger.info(f"MQTT连接成功: {self.config.broker_host}:{self.config.broker_port}")
self._resend_pending_messages()
else:
self._connected = False
self._logger.log_connection_event(
"connect", "MQTT", self.config.broker_host, False,
error_msg=f"错误码: {rc}"
)
def _on_disconnect(self, client, userdata, rc):
"""断开连接回调"""
self._connected = False
self._logger.log_connection_event(
"disconnect", "MQTT", self.config.broker_host, False
)
if rc != 0:
self._performance_stats["reconnect_count"] += 1
self._logger.warning(f"MQTT意外断开尝试重连: {rc}")
self._start_reconnect()
def _on_message(self, client, userdata, msg: MQTTMessage):
"""消息回调"""
self._performance_stats["messages_received"] += 1
topic = msg.topic
payload = msg.payload.decode('utf-8')
self._logger.debug(f"收到消息: {topic} -> {payload}")
if topic in self._message_callbacks:
try:
self._message_callbacks[topic](payload)
except Exception as e:
self._logger.error(f"消息回调执行失败: {e}")
def _on_publish(self, client, userdata, mid):
"""发布回调"""
self._performance_stats["messages_sent"] += 1
self._logger.debug(f"消息已发布: mid={mid}")
def _create_client(self) -> mqtt.Client:
"""创建MQTT客户端"""
client = mqtt.Client(
client_id=self.config.client_id,
clean_session=True,
userdata=None,
protocol=mqtt.MQTTv311,
transport="tcp"
)
client.on_connect = self._on_connect
client.on_disconnect = self._on_disconnect
client.on_message = self._on_message
client.on_publish = self._on_publish
if self.config.username and self.config.password:
client.username_pw_set(
self.config.username,
self.config.password
)
client.reconnect_delay_set(
min_delay=self.config.reconnect_delay,
max_delay=self.config.reconnect_delay * 2
)
return client
def _start_reconnect(self):
"""启动重连线程"""
if self._reconnect_thread is not None and self._reconnect_thread.is_alive():
return
self._stop_event.clear()
self._reconnect_thread = threading.Thread(
target=self._reconnect_loop,
name="MQTTReconnect",
daemon=True
)
self._reconnect_thread.start()
def _reconnect_loop(self):
"""重连循环"""
backoff = ExponentialBackoff(
base_delay=self.config.reconnect_delay,
max_delay=self.config.reconnect_delay * 3,
max_attempts=self.config.max_reconnect_attempts
)
while not self._stop_event.is_set():
if self._connected:
break
if not backoff.next_attempt():
self._logger.error("MQTT重连次数已达上限")
break
delay = backoff.get_delay()
self._logger.info(f"等待 {delay:.1f}秒后尝试重连MQTT...")
time.sleep(delay)
try:
self._client.connect(
self.config.broker_host,
self.config.broker_port,
self.config.keepalive
)
self._client.loop_start()
except Exception as e:
self._logger.error(f"MQTT重连失败: {e}")
def _resend_pending_messages(self):
"""重发待发送消息"""
with self._pending_lock:
if not self._pending_messages:
return
self._logger.info(f"重发待处理消息: {len(self._pending_messages)}")
failed_messages = []
for msg in self._pending_messages:
try:
self._publish_message(
msg["topic"],
msg["payload"],
msg.get("qos", self.config.qos),
msg.get("retain", False)
)
except Exception:
failed_messages.append(msg)
self._pending_messages = failed_messages
def connect(self) -> bool:
"""
连接到MQTT代理
Returns:
是否连接成功
"""
with self._lock:
try:
if self._client is not None:
self._client.loop_stop()
self._client = self._create_client()
self._client.connect(
self.config.broker_host,
self.config.broker_port,
self.config.keepalive
)
self._client.loop_start()
self._logger.info(
f"MQTT正在连接: {self.config.broker_host}:{self.config.broker_port}"
)
return True
except Exception as e:
self._logger.error(f"MQTT连接失败: {e}")
return False
def disconnect(self):
"""断开MQTT连接"""
with self._lock:
if self._client:
self._stop_event.set()
self._client.loop_stop()
self._client.disconnect()
self._client = None
self._connected = False
self._logger.info("MQTT连接已断开")
def subscribe(self, topic: str, qos: int = 1):
"""
订阅主题
Args:
topic: 主题
qos: 服务质量等级
"""
if self._client and self._connected:
self._client.subscribe(topic, qos)
self._logger.info(f"已订阅主题: {topic}")
def publish(
self,
topic: str,
payload: Dict[str, Any],
qos: Optional[int] = None,
retain: bool = False
) -> bool:
"""
发布消息
Args:
topic: 主题
payload: 消息负载
qos: 服务质量等级
retain: 保留消息
Returns:
是否发布成功
"""
if qos is None:
qos = self.config.qos
try:
message = json.dumps(payload, ensure_ascii=False)
if self._connected and self._client:
self._publish_message(topic, message, qos, retain)
return True
else:
self._queue_message(topic, payload, qos, retain)
return False
except Exception as e:
self._performance_stats["send_errors"] += 1
self._logger.error(f"发布消息失败: {e}")
return False
def _publish_message(
self,
topic: str,
message: str,
qos: int,
retain: bool
):
"""内部发布方法"""
self._client.publish(topic, message, qos, retain)
self._logger.debug(f"消息已发布: {topic}")
def _queue_message(
self,
topic: str,
payload: Dict[str, Any],
qos: int,
retain: bool
):
"""将消息加入待发送队列"""
with self._pending_lock:
self._pending_messages.append({
"topic": topic,
"payload": payload,
"qos": qos,
"retain": retain,
"timestamp": datetime.now().isoformat(),
})
self._logger.warning(f"消息已加入待发送队列: {topic}")
def register_message_callback(self, topic: str, callback: Callable):
"""注册消息回调"""
self._message_callbacks[topic] = callback
self._logger.info(f"已注册消息回调: {topic}")
def get_performance_stats(self) -> Dict[str, int]:
"""获取性能统计"""
return dict(self._performance_stats)
def get_status(self) -> Dict[str, Any]:
"""获取连接状态"""
return {
"connected": self._connected,
"broker_host": self.config.broker_host,
"broker_port": self.config.broker_port,
"pending_messages": len(self._pending_messages),
"stats": self.get_performance_stats(),
}
def close(self):
"""关闭客户端"""
self.disconnect()
self._logger.info("MQTT客户端已关闭")
class AlertReporter:
"""告警上报器类
整合MQTT上报、数据库存储、重试机制
"""
def __init__(self):
self._mqtt_client = None
self._db_manager = None
self._logger = get_logger("reporter")
self._lock = threading.Lock()
self._topic_prefix = "edge/alerts"
self._performance_stats = {
"alerts_generated": 0,
"alerts_sent": 0,
"alerts_stored": 0,
"send_failures": 0,
}
def _init_mqtt(self):
"""初始化MQTT客户端"""
if self._mqtt_client is None:
self._mqtt_client = MQTTClient()
self._mqtt_client.connect()
def _init_database(self):
"""初始化数据库连接"""
if self._db_manager is None:
self._db_manager = get_database_manager()
def report_alert(
self,
alert: AlertInfo,
screenshot: Optional[np.ndarray] = None,
store_to_db: bool = True
) -> bool:
"""
上报告警
Args:
alert: 告警信息
screenshot: 截图 (可选)
store_to_db: 是否存储到数据库
Returns:
是否上报成功
"""
with self._lock:
self._performance_stats["alerts_generated"] += 1
try:
self._init_mqtt()
self._init_database()
alert_data = alert.to_dict()
if screenshot is not None:
from utils.common import image_to_base64
alert_data["screenshot"] = image_to_base64(screenshot)
topic = f"{self._topic_prefix}/{alert.camera_id}/{alert.roi_id}"
success = self._mqtt_client.publish(topic, alert_data)
if store_to_db:
db_success = self._store_to_database(alert.to_dict())
if db_success:
with self._lock:
self._performance_stats["alerts_stored"] += 1
if success:
with self._lock:
self._performance_stats["alerts_sent"] += 1
self._logger.log_alert(
alert.alert_type,
alert.camera_id,
alert.roi_id,
alert.confidence
)
else:
with self._lock:
self._performance_stats["send_failures"] += 1
return success
except Exception as e:
with self._lock:
self._performance_stats["send_failures"] += 1
self._logger.error(f"上报告警失败: {e}")
return False
def _store_to_database(self, alert_data: Dict[str, Any]) -> bool:
"""存储告警到数据库"""
try:
return self._db_manager.save_alert_record(alert_data)
except Exception as e:
self._logger.error(f"存储告警到数据库失败: {e}")
return False
def batch_report(
self,
alerts: List[AlertInfo],
screenshots: Optional[Dict[str, np.ndarray]] = None
) -> Dict[str, int]:
"""
批量上报告警
Args:
alerts: 告警列表
screenshots: 截图字典 (alert_id -> screenshot)
Returns:
上报结果统计
"""
results = {
"total": len(alerts),
"success": 0,
"failed": 0,
}
screenshots = screenshots or {}
for alert in alerts:
screenshot = screenshots.get(alert.alert_id)
if self.report_alert(alert, screenshot):
results["success"] += 1
else:
results["failed"] += 1
return results
def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool:
"""
上报心跳
Args:
device_id: 设备ID
status: 状态信息
Returns:
是否上报成功
"""
try:
self._init_mqtt()
heartbeat_data = {
"device_id": device_id,
"timestamp": datetime.now().isoformat(),
"status": status,
}
topic = f"edge/heartbeat/{device_id}"
return self._mqtt_client.publish(topic, heartbeat_data)
except Exception as e:
self._logger.error(f"上报心跳失败: {e}")
return False
def get_performance_stats(self) -> Dict[str, int]:
"""获取性能统计"""
with self._lock:
return dict(self._performance_stats)
def get_status(self) -> Dict[str, Any]:
"""获取状态"""
return {
"mqtt": self._mqtt_client.get_status() if self._mqtt_client else None,
"stats": self.get_performance_stats(),
}
def close(self):
"""关闭资源"""
if self._mqtt_client:
self._mqtt_client.close()
self._mqtt_client = None
class ResultReporter:
"""结果上报主类
统一管理告警上报
"""
def __init__(self):
self._alert_reporter = AlertReporter()
self._logger = get_logger("reporter")
def report_detection_alert(
self,
camera_id: str,
roi_id: str,
alert_type: str,
detection: Dict[str, Any],
screenshot: Optional[np.ndarray] = None
) -> bool:
"""
上报检测告警
Args:
camera_id: 摄像头ID
roi_id: ROI ID
alert_type: 告警类型
detection: 检测结果
screenshot: 截图
Returns:
是否上报成功
"""
alert = AlertInfo(
alert_id=generate_unique_id("alert"),
camera_id=camera_id,
roi_id=roi_id,
alert_type=alert_type,
target_class=detection.get("class_name"),
confidence=detection.get("confidence"),
bbox=detection.get("bbox"),
message=detection.get("message", f"检测到{alert_type}"),
timestamp=datetime.now().isoformat(),
)
return self._alert_reporter.report_alert(alert, screenshot)
def report_leave_post_alert(
self,
camera_id: str,
roi_id: str,
duration_seconds: int,
screenshot: Optional[np.ndarray] = None
) -> bool:
"""
上报离岗告警
Args:
camera_id: 摄像头ID
roi_id: ROI ID
duration_seconds: 离岗时长(秒)
screenshot: 截图
Returns:
是否上报成功
"""
detection = {
"class_name": "person",
"confidence": None,
"message": f"离岗告警: 已离开{duration_seconds}",
}
return self.report_detection_alert(
camera_id=camera_id,
roi_id=roi_id,
alert_type="leave_post",
detection=detection,
screenshot=screenshot
)
def report_intrusion_alert(
self,
camera_id: str,
roi_id: str,
confidence: float,
bbox: List[float],
screenshot: Optional[np.ndarray] = None
) -> bool:
"""
上报入侵告警
Args:
camera_id: 摄像头ID
roi_id: ROI ID
confidence: 置信度
bbox: 检测框
screenshot: 截图
Returns:
是否上报成功
"""
detection = {
"class_name": "intrusion",
"confidence": confidence,
"bbox": bbox,
"message": "检测到周界入侵",
}
return self.report_detection_alert(
camera_id=camera_id,
roi_id=roi_id,
alert_type="intrusion",
detection=detection,
screenshot=screenshot
)
def report_heartbeat(self, device_id: str, status: Dict[str, Any]) -> bool:
"""上报心跳"""
return self._alert_reporter.report_heartbeat(device_id, status)
def get_status(self) -> Dict[str, Any]:
"""获取状态"""
return self._alert_reporter.get_status()
def close(self):
"""关闭资源"""
self._alert_reporter.close()
def get_result_reporter() -> ResultReporter:
"""获取结果上报器单例"""
return ResultReporter()