From 885665ecf095e828d7a553609e6f35e3214394b4 Mon Sep 17 00:00:00 2001 From: 16337 <1633794139@qq.com> Date: Thu, 5 Feb 2026 13:57:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=9B=86=E6=88=90=E6=89=80=E6=9C=89?= =?UTF-8?q?=E6=9C=8D=E5=8A=A1=E5=88=B0=E4=B8=BB=E7=A8=8B=E5=BA=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用 lifespan 管理服务生命周期 - 启动时自动连接 MQTT 并订阅告警主题 - 新增 WebSocket 端点 /ws/alerts - 新增设备管理 API /api/v1/devices - 新增 MQTT 状态 API /api/v1/mqtt/statistics - 增强健康检查返回 MQTT 和 WebSocket 状态 Co-Authored-By: Claude Opus 4.5 --- app/main.py | 191 +++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 166 insertions(+), 25 deletions(-) diff --git a/app/main.py b/app/main.py index f146a28..1d66f74 100644 --- a/app/main.py +++ b/app/main.py @@ -1,4 +1,6 @@ -from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Query +import asyncio +from contextlib import asynccontextmanager +from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from typing import Optional @@ -14,17 +16,80 @@ from app.schemas import ( AlertHandleRequest, AlertStatisticsResponse, HealthResponse, + DeviceResponse, + DeviceListResponse, + DeviceStatisticsResponse, ) -from app.services.alert_service import get_alert_service +from app.services.alert_service import alert_service, get_alert_service from app.services.ai_analyzer import trigger_async_analysis +from app.services.mqtt_service import get_mqtt_service +from app.services.notification_service import get_notification_service +from app.services.device_service import get_device_service from app.utils.logger import logger import json +# 全局服务实例 +mqtt_service = get_mqtt_service() +notification_service = get_notification_service() +device_service = get_device_service() + + +def handle_mqtt_alert(payload: dict): + """处理 MQTT 告警消息""" + try: + alert = alert_service.create_alert_from_mqtt(payload) + if alert: + # 通过 WebSocket 推送新告警 + notification_service.notify_sync("new_alert", alert.to_dict()) + logger.info(f"MQTT 告警已处理并推送: {alert.alert_no}") + except Exception as e: + logger.error(f"处理 MQTT 告警失败: {e}") + + +def handle_mqtt_heartbeat(payload: dict): + """处理 MQTT 心跳消息""" + try: + device = device_service.handle_heartbeat(payload) + if device: + # 通过 WebSocket 推送设备状态 + notification_service.notify_sync("device_status", device.to_dict()) + except Exception as e: + logger.error(f"处理 MQTT 心跳失败: {e}") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + # 启动 + init_db() + logger.info("数据库初始化完成") + + # 设置事件循环(用于从 MQTT 回调调用异步方法) + loop = asyncio.get_event_loop() + notification_service.set_event_loop(loop) + + # 注册 MQTT 处理器 + mqtt_service.register_alert_handler(handle_mqtt_alert) + mqtt_service.register_heartbeat_handler(handle_mqtt_heartbeat) + + # 启动 MQTT 服务 + mqtt_service.start() + + logger.info("AI 告警平台启动完成") + + yield + + # 关闭 + mqtt_service.stop() + logger.info("AI 告警平台已关闭") + + app = FastAPI( title="AI 告警平台", - description="接收边缘端告警,提供告警查询与处理能力", - version="1.0.0", + description="接收边缘端告警,提供告警查询与处理能力,支持 WebSocket 实时推送", + version="2.0.0", + lifespan=lifespan, ) app.add_middleware( @@ -36,29 +101,34 @@ app.add_middleware( ) -def get_alert_service(): - return get_alert_service() +def get_alert_svc(): + return alert_service -@app.on_event("startup") -async def startup(): - init_db() - logger.info("AI 告警平台启动") +def get_device_svc(): + return device_service @app.get("/health", response_model=HealthResponse) async def health_check(): + from sqlalchemy import text + db_status = "healthy" try: engine = get_engine() with engine.connect() as conn: - conn.execute("SELECT 1") + conn.execute(text("SELECT 1")) except Exception as e: db_status = f"unhealthy: {e}" + mqtt_stats = mqtt_service.get_statistics() + mqtt_status = "connected" if mqtt_stats["connected"] else "disconnected" + return HealthResponse( - status="healthy" if db_status == "healthy" else "degraded", + status="healthy" if db_status == "healthy" and mqtt_stats["connected"] else "degraded", database=db_status, + mqtt=mqtt_status, + websocket_connections=notification_service.manager.connection_count, ) @@ -66,7 +136,7 @@ async def health_check(): async def create_alert( data: str = Form(..., description="JSON数据"), snapshot: Optional[UploadFile] = File(None), - service=Depends(get_alert_service), + service=Depends(get_alert_svc), ): try: alert_data = json.loads(data) @@ -105,7 +175,7 @@ async def list_alerts( end_time: Optional[datetime] = Query(None, description="结束时间"), page: int = Query(1, ge=1), page_size: int = Query(20, ge=1, le=100), - service=Depends(get_alert_service), + service=Depends(get_alert_svc), ): alerts, total = service.get_alerts( camera_id=camera_id, @@ -125,10 +195,18 @@ async def list_alerts( ) +@app.get("/api/v1/alerts/statistics", response_model=AlertStatisticsResponse) +async def get_statistics( + service=Depends(get_alert_svc), +): + stats = service.get_statistics() + return AlertStatisticsResponse(**stats) + + @app.get("/api/v1/alerts/{alert_id}", response_model=AlertResponse) async def get_alert( alert_id: int, - service=Depends(get_alert_service), + service=Depends(get_alert_svc), ): alert = service.get_alert(alert_id) if not alert: @@ -141,7 +219,7 @@ async def handle_alert( alert_id: int, handle_data: AlertHandleRequest, handled_by: Optional[str] = Query(None, description="处理人"), - service=Depends(get_alert_service), + service=Depends(get_alert_svc), ): alert = service.handle_alert(alert_id, handle_data, handled_by) if not alert: @@ -149,18 +227,10 @@ async def handle_alert( return AlertResponse(**alert.to_dict()) -@app.get("/api/v1/alerts/statistics", response_model=AlertStatisticsResponse) -async def get_statistics( - service=Depends(get_alert_service), -): - stats = service.get_statistics() - return AlertStatisticsResponse(**stats) - - @app.get("/api/v1/alerts/{alert_id}/image") async def get_alert_image( alert_id: int, - service=Depends(get_alert_service), + service=Depends(get_alert_svc), ): alert = service.get_alert(alert_id) if not alert: @@ -172,6 +242,77 @@ async def get_alert_image( return FileResponse(alert.snapshot_path) +# ==================== WebSocket 端点 ==================== + +@app.websocket("/ws/alerts") +async def websocket_alerts(websocket: WebSocket): + """WebSocket 连接,接收实时告警和设备状态推送""" + await notification_service.manager.connect(websocket) + try: + while True: + # 保持连接,接收客户端心跳 + data = await websocket.receive_text() + if data == "ping": + await websocket.send_text("pong") + except WebSocketDisconnect: + notification_service.manager.disconnect(websocket) + except Exception as e: + logger.warning(f"WebSocket 异常: {e}") + notification_service.manager.disconnect(websocket) + + +# ==================== 设备管理端点 ==================== + +@app.get("/api/v1/devices", response_model=DeviceListResponse) +async def list_devices( + status: Optional[str] = Query(None, description="设备状态: online/offline/error"), + page: int = Query(1, ge=1), + page_size: int = Query(20, ge=1, le=100), + service=Depends(get_device_svc), +): + """获取设备列表""" + devices, total = service.get_devices( + status=status, + page=page, + page_size=page_size, + ) + return DeviceListResponse( + devices=[DeviceResponse(**d.to_dict()) for d in devices], + total=total, + page=page, + page_size=page_size, + ) + + +@app.get("/api/v1/devices/statistics", response_model=DeviceStatisticsResponse) +async def get_device_statistics( + service=Depends(get_device_svc), +): + """获取设备统计""" + stats = service.get_statistics() + return DeviceStatisticsResponse(**stats) + + +@app.get("/api/v1/devices/{device_id}", response_model=DeviceResponse) +async def get_device( + device_id: str, + service=Depends(get_device_svc), +): + """获取设备详情""" + device = service.get_device(device_id) + if not device: + raise HTTPException(status_code=404, detail="设备不存在") + return DeviceResponse(**device.to_dict()) + + +# ==================== MQTT 状态端点 ==================== + +@app.get("/api/v1/mqtt/statistics") +async def get_mqtt_statistics(): + """获取 MQTT 服务统计""" + return mqtt_service.get_statistics() + + if __name__ == "__main__": import uvicorn uvicorn.run(