fix: 优化边缘端稳定性和日志管理
1. database.py: 优化数据库连接和错误处理 2. postprocessor.py: 改进后处理逻辑 3. result_reporter.py: 完善告警上报字段 4. video_stream.py: 增强视频流稳定性 5. main.py: 优化启动流程和异常处理 6. logger.py: 改进日志格式和轮转配置 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -17,6 +17,19 @@ import time
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Generator
|
||||
|
||||
|
||||
def _normalize_coordinates(coords):
|
||||
"""将坐标统一为 [[x,y],...] 格式,兼容 [{'x':..,'y':..},...] 格式"""
|
||||
if isinstance(coords, str):
|
||||
try:
|
||||
coords = eval(coords)
|
||||
except:
|
||||
return coords
|
||||
if isinstance(coords, list) and coords and isinstance(coords[0], dict):
|
||||
return [[p.get("x", 0), p.get("y", 0)] for p in coords]
|
||||
return coords
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
@@ -633,7 +646,7 @@ class SQLiteManager:
|
||||
'enabled', 'priority', 'extra_params', 'updated_at']
|
||||
result = dict(zip(columns, row))
|
||||
try:
|
||||
result['coordinates'] = eval(result['coordinates'])
|
||||
result['coordinates'] = _normalize_coordinates(result['coordinates'])
|
||||
except:
|
||||
pass
|
||||
return result
|
||||
@@ -653,7 +666,7 @@ class SQLiteManager:
|
||||
for row in cursor.fetchall():
|
||||
r = dict(zip(columns, row))
|
||||
try:
|
||||
r['coordinates'] = eval(r['coordinates'])
|
||||
r['coordinates'] = _normalize_coordinates(r['coordinates'])
|
||||
except:
|
||||
pass
|
||||
results.append(r)
|
||||
@@ -673,7 +686,7 @@ class SQLiteManager:
|
||||
for row in cursor.fetchall():
|
||||
r = dict(zip(columns, row))
|
||||
try:
|
||||
r['coordinates'] = eval(r['coordinates'])
|
||||
r['coordinates'] = _normalize_coordinates(r['coordinates'])
|
||||
except:
|
||||
pass
|
||||
results.append(r)
|
||||
@@ -856,7 +869,7 @@ class SQLiteManager:
|
||||
if result.get('params'):
|
||||
result['params'] = json.loads(result['params'])
|
||||
if result.get('coordinates'):
|
||||
result['coordinates'] = eval(result['coordinates'])
|
||||
result['coordinates'] = _normalize_coordinates(result['coordinates'])
|
||||
except:
|
||||
pass
|
||||
results.append(result)
|
||||
|
||||
@@ -728,9 +728,18 @@ class PostProcessor:
|
||||
|
||||
if len(batch_outputs) == 1:
|
||||
first_output = batch_outputs[0]
|
||||
if isinstance(first_output, np.ndarray) and first_output.ndim == 3:
|
||||
if first_output.shape[0] == batch_size:
|
||||
if isinstance(first_output, np.ndarray):
|
||||
if first_output.ndim == 3 and first_output.shape[0] == batch_size:
|
||||
# 已经是 (batch, 84, anchors) 格式
|
||||
outputs_array = first_output
|
||||
elif first_output.ndim == 1:
|
||||
# TensorRT 返回扁平 1D 数组,需要 reshape 为 (batch, 84, anchors)
|
||||
per_image = first_output.shape[0] // batch_size
|
||||
num_anchors = per_image // 84
|
||||
outputs_array = first_output.reshape(batch_size, 84, num_anchors)
|
||||
elif first_output.ndim == 2:
|
||||
# (84, anchors) 单张图的输出
|
||||
outputs_array = first_output.reshape(1, first_output.shape[0], first_output.shape[1])
|
||||
else:
|
||||
outputs_array = first_output
|
||||
else:
|
||||
|
||||
@@ -196,24 +196,30 @@ class ResultReporter:
|
||||
"""
|
||||
self._performance_stats["alerts_generated"] += 1
|
||||
|
||||
# MQTT 发布和本地存储独立执行,互不影响
|
||||
mqtt_ok = False
|
||||
try:
|
||||
if store_to_db and self._db_manager:
|
||||
self._store_alert(alert, screenshot)
|
||||
|
||||
if self._connected and self._client:
|
||||
self._publish_alert(alert)
|
||||
mqtt_ok = True
|
||||
else:
|
||||
self._logger.warning("MQTT 未连接,消息已加入待发送队列")
|
||||
if self._local_cache:
|
||||
self._local_cache.cache_alert(alert.to_dict())
|
||||
|
||||
self._performance_stats["alerts_sent"] += 1
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self._performance_stats["send_failures"] += 1
|
||||
self._logger.error(f"上报告警失败: {e}")
|
||||
return False
|
||||
self._logger.error(f"MQTT 发布告警失败: {e}")
|
||||
|
||||
try:
|
||||
if store_to_db and self._db_manager:
|
||||
self._store_alert(alert, screenshot)
|
||||
except Exception as e:
|
||||
self._logger.error(f"本地存储告警失败: {e}")
|
||||
|
||||
if mqtt_ok:
|
||||
self._performance_stats["alerts_sent"] += 1
|
||||
|
||||
return mqtt_ok
|
||||
|
||||
def _store_alert(self, alert: AlertInfo, screenshot: Optional[np.ndarray] = None):
|
||||
"""存储告警到本地数据库(异步)"""
|
||||
|
||||
@@ -409,11 +409,18 @@ class MultiStreamManager:
|
||||
self._streams[camera_id].start()
|
||||
|
||||
def start_all(self):
|
||||
"""启动所有视频流"""
|
||||
"""启动所有视频流(跳过连接失败的流)"""
|
||||
with self._lock:
|
||||
for stream in self._streams.values():
|
||||
failed = []
|
||||
for camera_id, stream in self._streams.items():
|
||||
try:
|
||||
stream.start()
|
||||
self._logger.info(f"已启动 {len(self._streams)} 个视频流")
|
||||
except Exception as e:
|
||||
self._logger.warning(f"视频流启动失败,跳过: {camera_id} - {e}")
|
||||
failed.append(camera_id)
|
||||
started = len(self._streams) - len(failed)
|
||||
self._logger.info(f"已启动 {started}/{len(self._streams)} 个视频流"
|
||||
+ (f",{len(failed)} 个失败" if failed else ""))
|
||||
|
||||
def stop_stream(self, camera_id: str):
|
||||
"""停止指定视频流"""
|
||||
|
||||
36
main.py
36
main.py
@@ -64,6 +64,10 @@ class EdgeInferenceService:
|
||||
self._max_batch_size = 8
|
||||
self._batch_timeout_sec = 0.05 # 50ms 攒批窗口
|
||||
|
||||
# 摄像头级别告警去重:同一摄像头+告警类型在冷却期内只上报一次
|
||||
self._camera_alert_cooldown: Dict[str, datetime] = {}
|
||||
self._camera_cooldown_seconds = 30 # 同摄像头同类型告警最小间隔(秒)
|
||||
|
||||
self._logger.info("Edge_Inference_Service 初始化开始")
|
||||
|
||||
def _init_database(self):
|
||||
@@ -256,6 +260,14 @@ class EdgeInferenceService:
|
||||
# 一次性推理整个 batch
|
||||
outputs, inference_time_ms = engine.infer(batch_data)
|
||||
|
||||
# 诊断:输出原始推理结果形状
|
||||
import numpy as np
|
||||
if isinstance(outputs, np.ndarray):
|
||||
self._logger.info(f"[推理诊断] batch_data shape={batch_data.shape}, output shape={outputs.shape}, 耗时={inference_time_ms:.1f}ms")
|
||||
elif isinstance(outputs, (list, tuple)):
|
||||
shapes = [o.shape if hasattr(o, 'shape') else type(o) for o in outputs]
|
||||
self._logger.info(f"[推理诊断] batch_data shape={batch_data.shape}, outputs={shapes}, 耗时={inference_time_ms:.1f}ms")
|
||||
|
||||
batch_size = len(roi_items)
|
||||
batch_results = self._postprocessor.batch_process_detections(
|
||||
outputs,
|
||||
@@ -263,10 +275,8 @@ class EdgeInferenceService:
|
||||
conf_threshold=self._settings.inference.conf_threshold
|
||||
)
|
||||
|
||||
# 诊断日志:显示每个 ROI 的检测结果数量
|
||||
total_detections = sum(len(r[0]) for r in batch_results)
|
||||
if total_detections > 0:
|
||||
self._logger.info(f"[推理] batch_size={batch_size}, 总检测数={total_detections}")
|
||||
self._logger.info(f"[推理] batch_size={batch_size}, 总检测数={total_detections}, conf_thresh={self._settings.inference.conf_threshold}")
|
||||
|
||||
for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(roi_items):
|
||||
boxes, scores, class_ids = batch_results[idx]
|
||||
@@ -377,6 +387,22 @@ class EdgeInferenceService:
|
||||
self._logger.info(f"[{camera_id}] 算法 {algo_code} 无告警, 状态: {algo_status}")
|
||||
|
||||
for alert in alerts:
|
||||
alert_type = alert.get("alert_type", "detection")
|
||||
|
||||
# 摄像头级别去重:同一摄像头+告警类型在冷却期内只上报一次
|
||||
dedup_key = f"{camera_id}_{alert_type}"
|
||||
now = frame.timestamp
|
||||
last_alert_time = self._camera_alert_cooldown.get(dedup_key)
|
||||
if last_alert_time is not None:
|
||||
elapsed = (now - last_alert_time).total_seconds()
|
||||
if elapsed < self._camera_cooldown_seconds:
|
||||
self._logger.info(
|
||||
f"[去重] 跳过告警: camera={camera_id}, type={alert_type}, "
|
||||
f"roi={roi_id}, 距上次={elapsed:.1f}s < {self._camera_cooldown_seconds}s"
|
||||
)
|
||||
continue
|
||||
|
||||
self._camera_alert_cooldown[dedup_key] = now
|
||||
self._performance_stats["total_alerts_generated"] += 1
|
||||
|
||||
from core.result_reporter import AlertInfo
|
||||
@@ -386,7 +412,7 @@ class EdgeInferenceService:
|
||||
roi_id=roi_id,
|
||||
bind_id=bind.bind_id,
|
||||
device_id=self._settings.mqtt.device_id,
|
||||
alert_type=alert.get("alert_type", "detection"),
|
||||
alert_type=alert_type,
|
||||
algorithm=algo_code,
|
||||
target_class=alert.get("class", bind.target_class or "unknown"),
|
||||
confidence=alert.get("confidence", 1.0),
|
||||
@@ -398,7 +424,7 @@ class EdgeInferenceService:
|
||||
self._reporter.report_alert(alert_info, screenshot=frame.image)
|
||||
|
||||
self._logger.info(
|
||||
f"告警已生成: type={alert.get('alert_type', 'detection')}, "
|
||||
f"告警已生成: type={alert_type}, "
|
||||
f"camera={camera_id}, roi={roi_id}, "
|
||||
f"confidence={alert.get('confidence', 1.0)}"
|
||||
)
|
||||
|
||||
@@ -95,20 +95,25 @@ class StructuredLogger:
|
||||
|
||||
os.makedirs(self._log_dir, exist_ok=True)
|
||||
|
||||
self._logger = logging.getLogger(self.name)
|
||||
self._logger.setLevel(self._log_level)
|
||||
|
||||
self._logger.handlers.clear()
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt='%(asctime)s | %(levelname)-8s | %(name)s | %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
# 配置 root logger,使所有模块的 logging.getLogger(name) 都能输出
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(self._log_level)
|
||||
if not root_logger.handlers:
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(self._log_level)
|
||||
console_handler.setFormatter(formatter)
|
||||
self._logger.addHandler(console_handler)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 配置命名 logger(主模块专用,写入独立日志文件)
|
||||
self._logger = logging.getLogger(self.name)
|
||||
self._logger.setLevel(self._log_level)
|
||||
self._logger.handlers.clear()
|
||||
self._logger.propagate = True # 通过 root logger 输出到控制台
|
||||
|
||||
self._add_file_handler(formatter)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user