diff --git a/config/database.py b/config/database.py index daa2fc1..2698433 100644 --- a/config/database.py +++ b/config/database.py @@ -17,6 +17,19 @@ import time import logging from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Generator + + +def _normalize_coordinates(coords): + """将坐标统一为 [[x,y],...] 格式,兼容 [{'x':..,'y':..},...] 格式""" + if isinstance(coords, str): + try: + coords = eval(coords) + except: + return coords + if isinstance(coords, list) and coords and isinstance(coords[0], dict): + return [[p.get("x", 0), p.get("y", 0)] for p in coords] + return coords + from dataclasses import dataclass, field from pathlib import Path @@ -633,7 +646,7 @@ class SQLiteManager: 'enabled', 'priority', 'extra_params', 'updated_at'] result = dict(zip(columns, row)) try: - result['coordinates'] = eval(result['coordinates']) + result['coordinates'] = _normalize_coordinates(result['coordinates']) except: pass return result @@ -653,7 +666,7 @@ class SQLiteManager: for row in cursor.fetchall(): r = dict(zip(columns, row)) try: - r['coordinates'] = eval(r['coordinates']) + r['coordinates'] = _normalize_coordinates(r['coordinates']) except: pass results.append(r) @@ -661,7 +674,7 @@ class SQLiteManager: except Exception as e: logger.error(f"获取ROI配置失败: {e}") return [] - + def get_all_roi_configs(self) -> List[Dict[str, Any]]: """获取所有ROI配置""" try: @@ -673,7 +686,7 @@ class SQLiteManager: for row in cursor.fetchall(): r = dict(zip(columns, row)) try: - r['coordinates'] = eval(r['coordinates']) + r['coordinates'] = _normalize_coordinates(r['coordinates']) except: pass results.append(r) @@ -856,7 +869,7 @@ class SQLiteManager: if result.get('params'): result['params'] = json.loads(result['params']) if result.get('coordinates'): - result['coordinates'] = eval(result['coordinates']) + result['coordinates'] = _normalize_coordinates(result['coordinates']) except: pass results.append(result) diff --git a/core/postprocessor.py b/core/postprocessor.py index 7c73fe3..87a87f2 100644 --- a/core/postprocessor.py +++ b/core/postprocessor.py @@ -728,9 +728,18 @@ class PostProcessor: if len(batch_outputs) == 1: first_output = batch_outputs[0] - if isinstance(first_output, np.ndarray) and first_output.ndim == 3: - if first_output.shape[0] == batch_size: + if isinstance(first_output, np.ndarray): + if first_output.ndim == 3 and first_output.shape[0] == batch_size: + # 已经是 (batch, 84, anchors) 格式 outputs_array = first_output + elif first_output.ndim == 1: + # TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 84, anchors) + per_image = first_output.shape[0] // batch_size + num_anchors = per_image // 84 + outputs_array = first_output.reshape(batch_size, 84, num_anchors) + elif first_output.ndim == 2: + # (84, anchors) 单张图的输出 + outputs_array = first_output.reshape(1, first_output.shape[0], first_output.shape[1]) else: outputs_array = first_output else: diff --git a/core/result_reporter.py b/core/result_reporter.py index ed07c03..1a5ef64 100644 --- a/core/result_reporter.py +++ b/core/result_reporter.py @@ -195,25 +195,31 @@ class ResultReporter: 是否上报成功 """ self._performance_stats["alerts_generated"] += 1 - + + # MQTT 发布和本地存储独立执行,互不影响 + mqtt_ok = False try: - if store_to_db and self._db_manager: - self._store_alert(alert, screenshot) - if self._connected and self._client: self._publish_alert(alert) + mqtt_ok = True else: self._logger.warning("MQTT 未连接,消息已加入待发送队列") if self._local_cache: self._local_cache.cache_alert(alert.to_dict()) - - self._performance_stats["alerts_sent"] += 1 - return True - except Exception as e: self._performance_stats["send_failures"] += 1 - self._logger.error(f"上报告警失败: {e}") - return False + self._logger.error(f"MQTT 发布告警失败: {e}") + + try: + if store_to_db and self._db_manager: + self._store_alert(alert, screenshot) + except Exception as e: + self._logger.error(f"本地存储告警失败: {e}") + + if mqtt_ok: + self._performance_stats["alerts_sent"] += 1 + + return mqtt_ok def _store_alert(self, alert: AlertInfo, screenshot: Optional[np.ndarray] = None): """存储告警到本地数据库(异步)""" diff --git a/core/video_stream.py b/core/video_stream.py index 9059a81..a451598 100644 --- a/core/video_stream.py +++ b/core/video_stream.py @@ -409,11 +409,18 @@ class MultiStreamManager: self._streams[camera_id].start() def start_all(self): - """启动所有视频流""" + """启动所有视频流(跳过连接失败的流)""" with self._lock: - for stream in self._streams.values(): - stream.start() - self._logger.info(f"已启动 {len(self._streams)} 个视频流") + failed = [] + for camera_id, stream in self._streams.items(): + try: + stream.start() + except Exception as e: + self._logger.warning(f"视频流启动失败,跳过: {camera_id} - {e}") + failed.append(camera_id) + started = len(self._streams) - len(failed) + self._logger.info(f"已启动 {started}/{len(self._streams)} 个视频流" + + (f",{len(failed)} 个失败" if failed else "")) def stop_stream(self, camera_id: str): """停止指定视频流""" diff --git a/main.py b/main.py index 62ad60c..aa53f65 100644 --- a/main.py +++ b/main.py @@ -64,6 +64,10 @@ class EdgeInferenceService: self._max_batch_size = 8 self._batch_timeout_sec = 0.05 # 50ms 攒批窗口 + # 摄像头级别告警去重:同一摄像头+告警类型在冷却期内只上报一次 + self._camera_alert_cooldown: Dict[str, datetime] = {} + self._camera_cooldown_seconds = 30 # 同摄像头同类型告警最小间隔(秒) + self._logger.info("Edge_Inference_Service 初始化开始") def _init_database(self): @@ -256,6 +260,14 @@ class EdgeInferenceService: # 一次性推理整个 batch outputs, inference_time_ms = engine.infer(batch_data) + # 诊断:输出原始推理结果形状 + import numpy as np + if isinstance(outputs, np.ndarray): + self._logger.info(f"[推理诊断] batch_data shape={batch_data.shape}, output shape={outputs.shape}, 耗时={inference_time_ms:.1f}ms") + elif isinstance(outputs, (list, tuple)): + shapes = [o.shape if hasattr(o, 'shape') else type(o) for o in outputs] + self._logger.info(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms") + batch_size = len(roi_items) batch_results = self._postprocessor.batch_process_detections( outputs, @@ -263,10 +275,8 @@ class EdgeInferenceService: conf_threshold=self._settings.inference.conf_threshold ) - # 诊断日志:显示每个 ROI 的检测结果数量 total_detections = sum(len(r[0]) for r in batch_results) - if total_detections > 0: - self._logger.info(f"[推理] batch_size={batch_size}, 总检测数={total_detections}") + self._logger.info(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}") for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(roi_items): boxes, scores, class_ids = batch_results[idx] @@ -377,8 +387,24 @@ class EdgeInferenceService: self._logger.info(f"[{camera_id}] 算法 {algo_code} 无告警, 状态: {algo_status}") for alert in alerts: + alert_type = alert.get("alert_type", "detection") + + # 摄像头级别去重:同一摄像头+告警类型在冷却期内只上报一次 + dedup_key = f"{camera_id}_{alert_type}" + now = frame.timestamp + last_alert_time = self._camera_alert_cooldown.get(dedup_key) + if last_alert_time is not None: + elapsed = (now - last_alert_time).total_seconds() + if elapsed < self._camera_cooldown_seconds: + self._logger.info( + f"[去重] 跳过告警: camera={camera_id}, type={alert_type}, " + f"roi={roi_id}, 距上次={elapsed:.1f}s < {self._camera_cooldown_seconds}s" + ) + continue + + self._camera_alert_cooldown[dedup_key] = now self._performance_stats["total_alerts_generated"] += 1 - + from core.result_reporter import AlertInfo alert_info = AlertInfo( alert_id=f"{roi_id}_{bind.bind_id}_{int(frame.timestamp.timestamp())}", @@ -386,7 +412,7 @@ class EdgeInferenceService: roi_id=roi_id, bind_id=bind.bind_id, device_id=self._settings.mqtt.device_id, - alert_type=alert.get("alert_type", "detection"), + alert_type=alert_type, algorithm=algo_code, target_class=alert.get("class", bind.target_class or "unknown"), confidence=alert.get("confidence", 1.0), @@ -396,9 +422,9 @@ class EdgeInferenceService: duration_minutes=alert.get("duration_minutes"), ) self._reporter.report_alert(alert_info, screenshot=frame.image) - + self._logger.info( - f"告警已生成: type={alert.get('alert_type', 'detection')}, " + f"告警已生成: type={alert_type}, " f"camera={camera_id}, roi={roi_id}, " f"confidence={alert.get('confidence', 1.0)}" ) diff --git a/utils/logger.py b/utils/logger.py index 5e0c2c9..ada4e97 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -87,29 +87,34 @@ class StructuredLogger: def _init_logger(self): """初始化日志配置""" settings = get_settings() - + self._log_level = getattr(logging, settings.log_level.upper(), logging.INFO) self._log_dir = settings.log_dir self._max_size = settings.log_file_max_size self._backup_count = settings.log_file_backup_count - + os.makedirs(self._log_dir, exist_ok=True) - - self._logger = logging.getLogger(self.name) - self._logger.setLevel(self._log_level) - - self._logger.handlers.clear() - + formatter = logging.Formatter( fmt='%(asctime)s | %(levelname)-8s | %(name)s | %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) - - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(self._log_level) - console_handler.setFormatter(formatter) - self._logger.addHandler(console_handler) - + + # 配置 root logger,使所有模块的 logging.getLogger(name) 都能输出 + root_logger = logging.getLogger() + root_logger.setLevel(self._log_level) + if not root_logger.handlers: + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(self._log_level) + console_handler.setFormatter(formatter) + root_logger.addHandler(console_handler) + + # 配置命名 logger(主模块专用,写入独立日志文件) + self._logger = logging.getLogger(self.name) + self._logger.setLevel(self._log_level) + self._logger.handlers.clear() + self._logger.propagate = True # 通过 root logger 输出到控制台 + self._add_file_handler(formatter) def _add_file_handler(self, formatter: logging.Formatter):