diff --git a/config/settings.py b/config/settings.py index f654f39..6825f3d 100644 --- a/config/settings.py +++ b/config/settings.py @@ -49,6 +49,7 @@ class MQTTConfig: broker_host: str = "localhost" broker_port: int = 1883 client_id: str = "edge_inference_service" + device_id: str = "edge-001" username: Optional[str] = None password: Optional[str] = None keepalive: int = 60 @@ -82,6 +83,21 @@ class InferenceConfig: fp16_mode: bool = True +# COCO 数据集类别名称(YOLO 模型使用) +COCO_CLASS_NAMES = [ + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", + "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", + "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", + "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", + "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", + "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", + "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", + "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", + "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush" +] + + @dataclass class CameraConfig: """单个摄像头配置""" @@ -146,6 +162,7 @@ class Settings: broker_host=os.getenv("MQTT_BROKER_HOST", "localhost"), broker_port=int(os.getenv("MQTT_BROKER_PORT", "1883")), client_id=os.getenv("MQTT_CLIENT_ID", "edge_inference_service"), + device_id=os.getenv("EDGE_DEVICE_ID", "edge-001"), username=os.getenv("MQTT_USERNAME"), password=os.getenv("MQTT_PASSWORD"), ) @@ -170,6 +187,9 @@ class Settings: self.log_file_backup_count = int(os.getenv("LOG_FILE_BACKUP_COUNT", "5")) self.working_hours = self._parse_working_hours() + + # 使用 COCO 类别名称 + self.class_names = COCO_CLASS_NAMES def _parse_working_hours(self) -> List[dict]: """解析工作时间配置""" diff --git a/core/result_reporter.py b/core/result_reporter.py index 52b8b95..ed07c03 100644 --- a/core/result_reporter.py +++ b/core/result_reporter.py @@ -12,6 +12,7 @@ import json import logging import threading import time +import uuid from datetime import datetime from typing import Any, Dict, List, Optional, Callable from dataclasses import dataclass, field @@ -32,6 +33,8 @@ class AlertInfo: roi_id: str alert_type: str bind_id: Optional[str] = None + device_id: Optional[str] = None + algorithm: Optional[str] = None target_class: Optional[str] = None confidence: Optional[float] = None bbox: Optional[List[float]] = field(default_factory=list) @@ -39,14 +42,17 @@ class AlertInfo: screenshot: Optional[np.ndarray] = None timestamp: datetime = field(default_factory=datetime.now) duration_minutes: Optional[float] = None - + def to_dict(self) -> Dict[str, Any]: - """转换为字典""" + """转换为字典(发送到告警平台的 MQTT 消息格式)""" return { "alert_id": self.alert_id, "camera_id": self.camera_id, "roi_id": self.roi_id, + "bind_id": self.bind_id, + "device_id": self.device_id, "alert_type": self.alert_type, + "algorithm": self.algorithm, "target_class": self.target_class, "confidence": self.confidence, "bbox": self.bbox, @@ -110,19 +116,22 @@ class ResultReporter: """初始化MQTT客户端""" self._logger.info(f"正在连接 MQTT: {self._mqtt_broker}:{self._mqtt_port}") try: + # 给 client_id 添加随机后缀,防止冲突 + unique_client_id = f"{self._mqtt_client_id}_{uuid.uuid4().hex[:8]}" + # 兼容不同版本的 paho-mqtt try: # paho-mqtt 2.0+ 版本 self._client = mqtt.Client( - client_id=self._mqtt_client_id, - protocol=mqtt.MQTTv5, + client_id=unique_client_id, + protocol=mqtt.MQTTv311, callback_api_version=mqtt.CallbackAPIVersion.VERSION2 ) except (AttributeError, TypeError): # paho-mqtt 1.x 版本 self._client = mqtt.Client( - client_id=self._mqtt_client_id, - protocol=mqtt.MQTTv5 + client_id=unique_client_id, + protocol=mqtt.MQTTv311 ) self._client.on_connect = self._on_connect diff --git a/main.py b/main.py index de65e17..62ad60c 100644 --- a/main.py +++ b/main.py @@ -385,7 +385,9 @@ class EdgeInferenceService: camera_id=camera_id, roi_id=roi_id, bind_id=bind.bind_id, + device_id=self._settings.mqtt.device_id, alert_type=alert.get("alert_type", "detection"), + algorithm=algo_code, target_class=alert.get("class", bind.target_class or "unknown"), confidence=alert.get("confidence", 1.0), bbox=alert.get("bbox", []), @@ -472,7 +474,7 @@ class EdgeInferenceService: } if self._reporter: - self._reporter.report_heartbeat("edge_inference_device", status) + self._reporter.report_heartbeat(self._settings.mqtt.device_id, status) except Exception as e: self._logger.error(f"心跳上报失败: {e}")