feat: 集成所有服务到主程序
- 使用 lifespan 管理服务生命周期 - 启动时自动连接 MQTT 并订阅告警主题 - 新增 WebSocket 端点 /ws/alerts - 新增设备管理 API /api/v1/devices - 新增 MQTT 状态 API /api/v1/mqtt/statistics - 增强健康检查返回 MQTT 和 WebSocket 状态 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
191
app/main.py
191
app/main.py
@@ -1,4 +1,6 @@
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Query
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import Optional
|
||||
@@ -14,17 +16,80 @@ from app.schemas import (
|
||||
AlertHandleRequest,
|
||||
AlertStatisticsResponse,
|
||||
HealthResponse,
|
||||
DeviceResponse,
|
||||
DeviceListResponse,
|
||||
DeviceStatisticsResponse,
|
||||
)
|
||||
from app.services.alert_service import get_alert_service
|
||||
from app.services.alert_service import alert_service, get_alert_service
|
||||
from app.services.ai_analyzer import trigger_async_analysis
|
||||
from app.services.mqtt_service import get_mqtt_service
|
||||
from app.services.notification_service import get_notification_service
|
||||
from app.services.device_service import get_device_service
|
||||
from app.utils.logger import logger
|
||||
import json
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
mqtt_service = get_mqtt_service()
|
||||
notification_service = get_notification_service()
|
||||
device_service = get_device_service()
|
||||
|
||||
|
||||
def handle_mqtt_alert(payload: dict):
|
||||
"""处理 MQTT 告警消息"""
|
||||
try:
|
||||
alert = alert_service.create_alert_from_mqtt(payload)
|
||||
if alert:
|
||||
# 通过 WebSocket 推送新告警
|
||||
notification_service.notify_sync("new_alert", alert.to_dict())
|
||||
logger.info(f"MQTT 告警已处理并推送: {alert.alert_no}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MQTT 告警失败: {e}")
|
||||
|
||||
|
||||
def handle_mqtt_heartbeat(payload: dict):
|
||||
"""处理 MQTT 心跳消息"""
|
||||
try:
|
||||
device = device_service.handle_heartbeat(payload)
|
||||
if device:
|
||||
# 通过 WebSocket 推送设备状态
|
||||
notification_service.notify_sync("device_status", device.to_dict())
|
||||
except Exception as e:
|
||||
logger.error(f"处理 MQTT 心跳失败: {e}")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
# 启动
|
||||
init_db()
|
||||
logger.info("数据库初始化完成")
|
||||
|
||||
# 设置事件循环(用于从 MQTT 回调调用异步方法)
|
||||
loop = asyncio.get_event_loop()
|
||||
notification_service.set_event_loop(loop)
|
||||
|
||||
# 注册 MQTT 处理器
|
||||
mqtt_service.register_alert_handler(handle_mqtt_alert)
|
||||
mqtt_service.register_heartbeat_handler(handle_mqtt_heartbeat)
|
||||
|
||||
# 启动 MQTT 服务
|
||||
mqtt_service.start()
|
||||
|
||||
logger.info("AI 告警平台启动完成")
|
||||
|
||||
yield
|
||||
|
||||
# 关闭
|
||||
mqtt_service.stop()
|
||||
logger.info("AI 告警平台已关闭")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="AI 告警平台",
|
||||
description="接收边缘端告警,提供告警查询与处理能力",
|
||||
version="1.0.0",
|
||||
description="接收边缘端告警,提供告警查询与处理能力,支持 WebSocket 实时推送",
|
||||
version="2.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
@@ -36,29 +101,34 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
def get_alert_service():
|
||||
return get_alert_service()
|
||||
def get_alert_svc():
|
||||
return alert_service
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
init_db()
|
||||
logger.info("AI 告警平台启动")
|
||||
def get_device_svc():
|
||||
return device_service
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health_check():
|
||||
from sqlalchemy import text
|
||||
|
||||
db_status = "healthy"
|
||||
try:
|
||||
engine = get_engine()
|
||||
with engine.connect() as conn:
|
||||
conn.execute("SELECT 1")
|
||||
conn.execute(text("SELECT 1"))
|
||||
except Exception as e:
|
||||
db_status = f"unhealthy: {e}"
|
||||
|
||||
mqtt_stats = mqtt_service.get_statistics()
|
||||
mqtt_status = "connected" if mqtt_stats["connected"] else "disconnected"
|
||||
|
||||
return HealthResponse(
|
||||
status="healthy" if db_status == "healthy" else "degraded",
|
||||
status="healthy" if db_status == "healthy" and mqtt_stats["connected"] else "degraded",
|
||||
database=db_status,
|
||||
mqtt=mqtt_status,
|
||||
websocket_connections=notification_service.manager.connection_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -66,7 +136,7 @@ async def health_check():
|
||||
async def create_alert(
|
||||
data: str = Form(..., description="JSON数据"),
|
||||
snapshot: Optional[UploadFile] = File(None),
|
||||
service=Depends(get_alert_service),
|
||||
service=Depends(get_alert_svc),
|
||||
):
|
||||
try:
|
||||
alert_data = json.loads(data)
|
||||
@@ -105,7 +175,7 @@ async def list_alerts(
|
||||
end_time: Optional[datetime] = Query(None, description="结束时间"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
service=Depends(get_alert_service),
|
||||
service=Depends(get_alert_svc),
|
||||
):
|
||||
alerts, total = service.get_alerts(
|
||||
camera_id=camera_id,
|
||||
@@ -125,10 +195,18 @@ async def list_alerts(
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/v1/alerts/statistics", response_model=AlertStatisticsResponse)
|
||||
async def get_statistics(
|
||||
service=Depends(get_alert_svc),
|
||||
):
|
||||
stats = service.get_statistics()
|
||||
return AlertStatisticsResponse(**stats)
|
||||
|
||||
|
||||
@app.get("/api/v1/alerts/{alert_id}", response_model=AlertResponse)
|
||||
async def get_alert(
|
||||
alert_id: int,
|
||||
service=Depends(get_alert_service),
|
||||
service=Depends(get_alert_svc),
|
||||
):
|
||||
alert = service.get_alert(alert_id)
|
||||
if not alert:
|
||||
@@ -141,7 +219,7 @@ async def handle_alert(
|
||||
alert_id: int,
|
||||
handle_data: AlertHandleRequest,
|
||||
handled_by: Optional[str] = Query(None, description="处理人"),
|
||||
service=Depends(get_alert_service),
|
||||
service=Depends(get_alert_svc),
|
||||
):
|
||||
alert = service.handle_alert(alert_id, handle_data, handled_by)
|
||||
if not alert:
|
||||
@@ -149,18 +227,10 @@ async def handle_alert(
|
||||
return AlertResponse(**alert.to_dict())
|
||||
|
||||
|
||||
@app.get("/api/v1/alerts/statistics", response_model=AlertStatisticsResponse)
|
||||
async def get_statistics(
|
||||
service=Depends(get_alert_service),
|
||||
):
|
||||
stats = service.get_statistics()
|
||||
return AlertStatisticsResponse(**stats)
|
||||
|
||||
|
||||
@app.get("/api/v1/alerts/{alert_id}/image")
|
||||
async def get_alert_image(
|
||||
alert_id: int,
|
||||
service=Depends(get_alert_service),
|
||||
service=Depends(get_alert_svc),
|
||||
):
|
||||
alert = service.get_alert(alert_id)
|
||||
if not alert:
|
||||
@@ -172,6 +242,77 @@ async def get_alert_image(
|
||||
return FileResponse(alert.snapshot_path)
|
||||
|
||||
|
||||
# ==================== WebSocket 端点 ====================
|
||||
|
||||
@app.websocket("/ws/alerts")
|
||||
async def websocket_alerts(websocket: WebSocket):
|
||||
"""WebSocket 连接,接收实时告警和设备状态推送"""
|
||||
await notification_service.manager.connect(websocket)
|
||||
try:
|
||||
while True:
|
||||
# 保持连接,接收客户端心跳
|
||||
data = await websocket.receive_text()
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
except WebSocketDisconnect:
|
||||
notification_service.manager.disconnect(websocket)
|
||||
except Exception as e:
|
||||
logger.warning(f"WebSocket 异常: {e}")
|
||||
notification_service.manager.disconnect(websocket)
|
||||
|
||||
|
||||
# ==================== 设备管理端点 ====================
|
||||
|
||||
@app.get("/api/v1/devices", response_model=DeviceListResponse)
|
||||
async def list_devices(
|
||||
status: Optional[str] = Query(None, description="设备状态: online/offline/error"),
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
service=Depends(get_device_svc),
|
||||
):
|
||||
"""获取设备列表"""
|
||||
devices, total = service.get_devices(
|
||||
status=status,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return DeviceListResponse(
|
||||
devices=[DeviceResponse(**d.to_dict()) for d in devices],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/api/v1/devices/statistics", response_model=DeviceStatisticsResponse)
|
||||
async def get_device_statistics(
|
||||
service=Depends(get_device_svc),
|
||||
):
|
||||
"""获取设备统计"""
|
||||
stats = service.get_statistics()
|
||||
return DeviceStatisticsResponse(**stats)
|
||||
|
||||
|
||||
@app.get("/api/v1/devices/{device_id}", response_model=DeviceResponse)
|
||||
async def get_device(
|
||||
device_id: str,
|
||||
service=Depends(get_device_svc),
|
||||
):
|
||||
"""获取设备详情"""
|
||||
device = service.get_device(device_id)
|
||||
if not device:
|
||||
raise HTTPException(status_code=404, detail="设备不存在")
|
||||
return DeviceResponse(**device.to_dict())
|
||||
|
||||
|
||||
# ==================== MQTT 状态端点 ====================
|
||||
|
||||
@app.get("/api/v1/mqtt/statistics")
|
||||
async def get_mqtt_statistics():
|
||||
"""获取 MQTT 服务统计"""
|
||||
return mqtt_service.get_statistics()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
|
||||
Reference in New Issue
Block a user