feat: 集成所有服务到主程序

- 使用 lifespan 管理服务生命周期
- 启动时自动连接 MQTT 并订阅告警主题
- 新增 WebSocket 端点 /ws/alerts
- 新增设备管理 API /api/v1/devices
- 新增 MQTT 状态 API /api/v1/mqtt/statistics
- 增强健康检查返回 MQTT 和 WebSocket 状态

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-02-05 13:57:25 +08:00
parent cd21d65b85
commit 885665ecf0

View File

@@ -1,4 +1,6 @@
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Query import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from typing import Optional from typing import Optional
@@ -14,17 +16,80 @@ from app.schemas import (
AlertHandleRequest, AlertHandleRequest,
AlertStatisticsResponse, AlertStatisticsResponse,
HealthResponse, HealthResponse,
DeviceResponse,
DeviceListResponse,
DeviceStatisticsResponse,
) )
from app.services.alert_service import get_alert_service from app.services.alert_service import alert_service, get_alert_service
from app.services.ai_analyzer import trigger_async_analysis from app.services.ai_analyzer import trigger_async_analysis
from app.services.mqtt_service import get_mqtt_service
from app.services.notification_service import get_notification_service
from app.services.device_service import get_device_service
from app.utils.logger import logger from app.utils.logger import logger
import json import json
# 全局服务实例
mqtt_service = get_mqtt_service()
notification_service = get_notification_service()
device_service = get_device_service()
def handle_mqtt_alert(payload: dict):
"""处理 MQTT 告警消息"""
try:
alert = alert_service.create_alert_from_mqtt(payload)
if alert:
# 通过 WebSocket 推送新告警
notification_service.notify_sync("new_alert", alert.to_dict())
logger.info(f"MQTT 告警已处理并推送: {alert.alert_no}")
except Exception as e:
logger.error(f"处理 MQTT 告警失败: {e}")
def handle_mqtt_heartbeat(payload: dict):
"""处理 MQTT 心跳消息"""
try:
device = device_service.handle_heartbeat(payload)
if device:
# 通过 WebSocket 推送设备状态
notification_service.notify_sync("device_status", device.to_dict())
except Exception as e:
logger.error(f"处理 MQTT 心跳失败: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""应用生命周期管理"""
# 启动
init_db()
logger.info("数据库初始化完成")
# 设置事件循环(用于从 MQTT 回调调用异步方法)
loop = asyncio.get_event_loop()
notification_service.set_event_loop(loop)
# 注册 MQTT 处理器
mqtt_service.register_alert_handler(handle_mqtt_alert)
mqtt_service.register_heartbeat_handler(handle_mqtt_heartbeat)
# 启动 MQTT 服务
mqtt_service.start()
logger.info("AI 告警平台启动完成")
yield
# 关闭
mqtt_service.stop()
logger.info("AI 告警平台已关闭")
app = FastAPI( app = FastAPI(
title="AI 告警平台", title="AI 告警平台",
description="接收边缘端告警,提供告警查询与处理能力", description="接收边缘端告警,提供告警查询与处理能力,支持 WebSocket 实时推送",
version="1.0.0", version="2.0.0",
lifespan=lifespan,
) )
app.add_middleware( app.add_middleware(
@@ -36,29 +101,34 @@ app.add_middleware(
) )
def get_alert_service(): def get_alert_svc():
return get_alert_service() return alert_service
@app.on_event("startup") def get_device_svc():
async def startup(): return device_service
init_db()
logger.info("AI 告警平台启动")
@app.get("/health", response_model=HealthResponse) @app.get("/health", response_model=HealthResponse)
async def health_check(): async def health_check():
from sqlalchemy import text
db_status = "healthy" db_status = "healthy"
try: try:
engine = get_engine() engine = get_engine()
with engine.connect() as conn: with engine.connect() as conn:
conn.execute("SELECT 1") conn.execute(text("SELECT 1"))
except Exception as e: except Exception as e:
db_status = f"unhealthy: {e}" db_status = f"unhealthy: {e}"
mqtt_stats = mqtt_service.get_statistics()
mqtt_status = "connected" if mqtt_stats["connected"] else "disconnected"
return HealthResponse( return HealthResponse(
status="healthy" if db_status == "healthy" else "degraded", status="healthy" if db_status == "healthy" and mqtt_stats["connected"] else "degraded",
database=db_status, database=db_status,
mqtt=mqtt_status,
websocket_connections=notification_service.manager.connection_count,
) )
@@ -66,7 +136,7 @@ async def health_check():
async def create_alert( async def create_alert(
data: str = Form(..., description="JSON数据"), data: str = Form(..., description="JSON数据"),
snapshot: Optional[UploadFile] = File(None), snapshot: Optional[UploadFile] = File(None),
service=Depends(get_alert_service), service=Depends(get_alert_svc),
): ):
try: try:
alert_data = json.loads(data) alert_data = json.loads(data)
@@ -105,7 +175,7 @@ async def list_alerts(
end_time: Optional[datetime] = Query(None, description="结束时间"), end_time: Optional[datetime] = Query(None, description="结束时间"),
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100), page_size: int = Query(20, ge=1, le=100),
service=Depends(get_alert_service), service=Depends(get_alert_svc),
): ):
alerts, total = service.get_alerts( alerts, total = service.get_alerts(
camera_id=camera_id, camera_id=camera_id,
@@ -125,10 +195,18 @@ async def list_alerts(
) )
@app.get("/api/v1/alerts/statistics", response_model=AlertStatisticsResponse)
async def get_statistics(
service=Depends(get_alert_svc),
):
stats = service.get_statistics()
return AlertStatisticsResponse(**stats)
@app.get("/api/v1/alerts/{alert_id}", response_model=AlertResponse) @app.get("/api/v1/alerts/{alert_id}", response_model=AlertResponse)
async def get_alert( async def get_alert(
alert_id: int, alert_id: int,
service=Depends(get_alert_service), service=Depends(get_alert_svc),
): ):
alert = service.get_alert(alert_id) alert = service.get_alert(alert_id)
if not alert: if not alert:
@@ -141,7 +219,7 @@ async def handle_alert(
alert_id: int, alert_id: int,
handle_data: AlertHandleRequest, handle_data: AlertHandleRequest,
handled_by: Optional[str] = Query(None, description="处理人"), handled_by: Optional[str] = Query(None, description="处理人"),
service=Depends(get_alert_service), service=Depends(get_alert_svc),
): ):
alert = service.handle_alert(alert_id, handle_data, handled_by) alert = service.handle_alert(alert_id, handle_data, handled_by)
if not alert: if not alert:
@@ -149,18 +227,10 @@ async def handle_alert(
return AlertResponse(**alert.to_dict()) return AlertResponse(**alert.to_dict())
@app.get("/api/v1/alerts/statistics", response_model=AlertStatisticsResponse)
async def get_statistics(
service=Depends(get_alert_service),
):
stats = service.get_statistics()
return AlertStatisticsResponse(**stats)
@app.get("/api/v1/alerts/{alert_id}/image") @app.get("/api/v1/alerts/{alert_id}/image")
async def get_alert_image( async def get_alert_image(
alert_id: int, alert_id: int,
service=Depends(get_alert_service), service=Depends(get_alert_svc),
): ):
alert = service.get_alert(alert_id) alert = service.get_alert(alert_id)
if not alert: if not alert:
@@ -172,6 +242,77 @@ async def get_alert_image(
return FileResponse(alert.snapshot_path) return FileResponse(alert.snapshot_path)
# ==================== WebSocket 端点 ====================
@app.websocket("/ws/alerts")
async def websocket_alerts(websocket: WebSocket):
"""WebSocket 连接,接收实时告警和设备状态推送"""
await notification_service.manager.connect(websocket)
try:
while True:
# 保持连接,接收客户端心跳
data = await websocket.receive_text()
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
notification_service.manager.disconnect(websocket)
except Exception as e:
logger.warning(f"WebSocket 异常: {e}")
notification_service.manager.disconnect(websocket)
# ==================== 设备管理端点 ====================
@app.get("/api/v1/devices", response_model=DeviceListResponse)
async def list_devices(
status: Optional[str] = Query(None, description="设备状态: online/offline/error"),
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
service=Depends(get_device_svc),
):
"""获取设备列表"""
devices, total = service.get_devices(
status=status,
page=page,
page_size=page_size,
)
return DeviceListResponse(
devices=[DeviceResponse(**d.to_dict()) for d in devices],
total=total,
page=page,
page_size=page_size,
)
@app.get("/api/v1/devices/statistics", response_model=DeviceStatisticsResponse)
async def get_device_statistics(
service=Depends(get_device_svc),
):
"""获取设备统计"""
stats = service.get_statistics()
return DeviceStatisticsResponse(**stats)
@app.get("/api/v1/devices/{device_id}", response_model=DeviceResponse)
async def get_device(
device_id: str,
service=Depends(get_device_svc),
):
"""获取设备详情"""
device = service.get_device(device_id)
if not device:
raise HTTPException(status_code=404, detail="设备不存在")
return DeviceResponse(**device.to_dict())
# ==================== MQTT 状态端点 ====================
@app.get("/api/v1/mqtt/statistics")
async def get_mqtt_statistics():
"""获取 MQTT 服务统计"""
return mqtt_service.get_statistics()
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run( uvicorn.run(