Files
iot-device-management-service/app/main.py
16337 885665ecf0 feat: 集成所有服务到主程序
- 使用 lifespan 管理服务生命周期
- 启动时自动连接 MQTT 并订阅告警主题
- 新增 WebSocket 端点 /ws/alerts
- 新增设备管理 API /api/v1/devices
- 新增 MQTT 状态 API /api/v1/mqtt/statistics
- 增强健康检查返回 MQTT 和 WebSocket 状态

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 13:57:25 +08:00

324 lines
9.5 KiB
Python

import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends, Query, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from typing import Optional
from datetime import datetime
from pathlib import Path
from app.config import settings
from app.models import init_db, get_engine, AlertStatus
from app.schemas import (
AlertCreate,
AlertResponse,
AlertListResponse,
AlertHandleRequest,
AlertStatisticsResponse,
HealthResponse,
DeviceResponse,
DeviceListResponse,
DeviceStatisticsResponse,
)
from app.services.alert_service import alert_service, get_alert_service
from app.services.ai_analyzer import trigger_async_analysis
from app.services.mqtt_service import get_mqtt_service
from app.services.notification_service import get_notification_service
from app.services.device_service import get_device_service
from app.utils.logger import logger
import json
# Module-level service singletons, shared by the MQTT callbacks, the lifespan
# hook and the HTTP/WebSocket endpoints below.
mqtt_service, notification_service, device_service = (
    get_mqtt_service(),
    get_notification_service(),
    get_device_service(),
)
def handle_mqtt_alert(payload: dict):
    """Handle an alert message received over MQTT.

    Builds an alert from the payload via ``alert_service.create_alert_from_mqtt``
    and, when that returns an alert, pushes it to WebSocket clients as a
    ``new_alert`` event. Any exception is logged and swallowed so it does not
    propagate back into the MQTT callback that invoked this handler.

    Args:
        payload: Decoded MQTT message body (schema defined by the edge device;
            not validated here).
    """
    try:
        alert = alert_service.create_alert_from_mqtt(payload)
        if alert:
            # Push the new alert to connected WebSocket clients.
            notification_service.notify_sync("new_alert", alert.to_dict())
            logger.info(f"MQTT 告警已处理并推送: {alert.alert_no}")
    except Exception as e:
        # logger.exception records the traceback; plain logger.error dropped it,
        # leaving failures in this background callback undiagnosable.
        logger.exception(f"处理 MQTT 告警失败: {e}")
def handle_mqtt_heartbeat(payload: dict):
    """Handle a device heartbeat message received over MQTT.

    Passes the payload to ``device_service.handle_heartbeat`` and, when that
    returns a device, pushes its state to WebSocket clients as a
    ``device_status`` event. Any exception is logged and swallowed so it does
    not propagate back into the MQTT callback that invoked this handler.

    Args:
        payload: Decoded MQTT heartbeat body (schema defined by the device;
            not validated here).
    """
    try:
        device = device_service.handle_heartbeat(payload)
        if device:
            # Push the device status to connected WebSocket clients.
            notification_service.notify_sync("device_status", device.to_dict())
    except Exception as e:
        # logger.exception records the traceback; plain logger.error dropped it.
        logger.exception(f"处理 MQTT 心跳失败: {e}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application startup and shutdown.

    Startup steps:
      * initialise the database;
      * hand the running event loop to the notification service, so that the
        synchronous MQTT callbacks can call its async methods;
      * register the MQTT alert and heartbeat handlers and start the MQTT service.

    Shutdown: stop the MQTT service. This runs in a ``finally`` so the client is
    stopped even if an exception is thrown into the context at ``yield``.
    """
    init_db()
    logger.info("数据库初始化完成")

    # Inside a coroutine, get_running_loop() is the documented way to get the
    # loop serving this app. get_event_loop() has policy-dependent behaviour.
    loop = asyncio.get_running_loop()
    notification_service.set_event_loop(loop)

    mqtt_service.register_alert_handler(handle_mqtt_alert)
    mqtt_service.register_heartbeat_handler(handle_mqtt_heartbeat)
    mqtt_service.start()
    logger.info("AI 告警平台启动完成")
    try:
        yield
    finally:
        mqtt_service.stop()
        logger.info("AI 告警平台已关闭")
# The FastAPI application. Startup/shutdown are delegated to `lifespan` above.
app = FastAPI(
    title="AI 告警平台",
    version="2.0.0",
    description="接收边缘端告警,提供告警查询与处理能力,支持 WebSocket 实时推送",
    lifespan=lifespan,
)

# NOTE(review): wildcard origins together with allow_credentials=True means any
# site can issue credentialed cross-origin requests (Starlette echoes the caller's
# Origin in that case). Consider restricting origins via configuration.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_alert_svc():
    """FastAPI dependency that yields the shared module-level alert service."""
    svc = alert_service
    return svc
def get_device_svc():
    """FastAPI dependency that yields the shared module-level device service."""
    svc = device_service
    return svc
def _probe_database() -> str:
    """Run ``SELECT 1`` on the configured engine; blocking, call off the event loop.

    Returns:
        ``"healthy"`` on success, otherwise ``"unhealthy: <error>"``.
    """
    from sqlalchemy import text

    try:
        engine = get_engine()
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except Exception as e:
        return f"unhealthy: {e}"
    return "healthy"


@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report database, MQTT and WebSocket health.

    ``status`` is ``"healthy"`` only when the database probe succeeds and the
    MQTT statistics report ``connected``; otherwise it is ``"degraded"``.
    """
    # The SQLAlchemy probe is blocking I/O. Running it directly in this coroutine
    # would stall the event loop, and with it every other request and WebSocket
    # push, while the database is slow or unreachable. Offload it to the default
    # thread pool instead.
    loop = asyncio.get_running_loop()
    db_status = await loop.run_in_executor(None, _probe_database)

    mqtt_stats = mqtt_service.get_statistics()
    mqtt_connected = mqtt_stats["connected"]
    return HealthResponse(
        status="healthy" if db_status == "healthy" and mqtt_connected else "degraded",
        database=db_status,
        mqtt="connected" if mqtt_connected else "disconnected",
        websocket_connections=notification_service.manager.connection_count,
    )
@app.post("/api/v1/alerts", response_model=AlertResponse)
async def create_alert(
    data: str = Form(..., description="JSON数据"),
    snapshot: Optional[UploadFile] = File(None),
    service=Depends(get_alert_svc),
):
    """Create an alert from a multipart form submission.

    Args:
        data: JSON object with the ``AlertCreate`` fields.
        snapshot: Optional image file stored alongside the alert.
        service: Alert service (injected).

    Returns:
        The created alert. If it has a snapshot URL and an AI model endpoint is
        configured, an asynchronous AI analysis is triggered first.

    Raises:
        HTTPException: 400 if ``data`` is not valid JSON or not a JSON object;
            422 if it fails ``AlertCreate`` validation.
    """
    try:
        alert_data = json.loads(data)
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail=f"JSON解析失败: {e}") from e
    # `**alert_data` on a list/number/string raises TypeError, which would
    # surface as a 500. Reject it as a client error instead.
    if not isinstance(alert_data, dict):
        raise HTTPException(status_code=400, detail="JSON数据必须是对象")
    try:
        alert_create = AlertCreate(**alert_data)
    except ValueError as e:
        # pydantic's ValidationError subclasses ValueError (v1 and v2). Before
        # this handler, schema violations escaped as HTTP 500.
        raise HTTPException(status_code=422, detail=f"告警数据校验失败: {e}") from e

    snapshot_data = None
    if snapshot:
        snapshot_data = await snapshot.read()
    alert = service.create_alert(alert_create, snapshot_data)

    if alert.snapshot_url and settings.ai_model.endpoint:
        trigger_async_analysis(
            alert_id=alert.id,
            snapshot_url=alert.snapshot_url,
            alert_info={
                "camera_id": alert.camera_id,
                "roi_id": alert.roi_id,
                "alert_type": alert.alert_type,
                "confidence": alert.confidence,
                "duration_minutes": alert.duration_minutes,
            },
        )
    return AlertResponse(**alert.to_dict())
@app.get("/api/v1/alerts", response_model=AlertListResponse)
async def list_alerts(
    camera_id: Optional[str] = Query(None, description="摄像头ID"),
    alert_type: Optional[str] = Query(None, description="告警类型"),
    status: Optional[str] = Query(None, description="处理状态"),
    start_time: Optional[datetime] = Query(None, description="开始时间"),
    end_time: Optional[datetime] = Query(None, description="结束时间"),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    service=Depends(get_alert_svc),
):
    """Return one page of alerts, filtered by the optional query parameters.

    All filters are forwarded to ``service.get_alerts`` as-is (including
    ``None``); how ``None`` is interpreted is up to the service.
    """
    filters = dict(
        camera_id=camera_id,
        alert_type=alert_type,
        status=status,
        start_time=start_time,
        end_time=end_time,
    )
    rows, total_count = service.get_alerts(**filters, page=page, page_size=page_size)
    items = [AlertResponse(**row.to_dict()) for row in rows]
    return AlertListResponse(
        alerts=items,
        total=total_count,
        page=page,
        page_size=page_size,
    )
@app.get("/api/v1/alerts/statistics", response_model=AlertStatisticsResponse)
async def get_statistics(
    service=Depends(get_alert_svc),
):
    """Return aggregate alert statistics computed by the alert service."""
    return AlertStatisticsResponse(**service.get_statistics())
@app.get("/api/v1/alerts/{alert_id}", response_model=AlertResponse)
async def get_alert(
    alert_id: int,
    service=Depends(get_alert_svc),
):
    """Fetch a single alert by its numeric id; respond 404 when none is found."""
    alert = service.get_alert(alert_id)
    if alert:
        return AlertResponse(**alert.to_dict())
    raise HTTPException(status_code=404, detail="告警不存在")
@app.put("/api/v1/alerts/{alert_id}/handle", response_model=AlertResponse)
async def handle_alert(
    alert_id: int,
    handle_data: AlertHandleRequest,
    handled_by: Optional[str] = Query(None, description="处理人"),
    service=Depends(get_alert_svc),
):
    """Record the handling of an alert, optionally attributing it to ``handled_by``.

    Responds 404 when the service returns no alert for ``alert_id``.
    """
    updated = service.handle_alert(alert_id, handle_data, handled_by)
    if updated:
        return AlertResponse(**updated.to_dict())
    raise HTTPException(status_code=404, detail="告警不存在")
@app.get("/api/v1/alerts/{alert_id}/image")
async def get_alert_image(
    alert_id: int,
    service=Depends(get_alert_svc),
):
    """Serve the snapshot image stored for an alert.

    Raises:
        HTTPException: 404 if the alert does not exist, has no snapshot path,
            or the snapshot file is missing on disk.
    """
    alert = service.get_alert(alert_id)
    if not alert:
        raise HTTPException(status_code=404, detail="告警不存在")
    # FileResponse only stats the path when the response is sent. A recorded
    # path whose file was deleted would surface there as a RuntimeError (500),
    # so check existence up front and report it as "image not found".
    if not alert.snapshot_path or not Path(alert.snapshot_path).is_file():
        raise HTTPException(status_code=404, detail="图片不存在")
    return FileResponse(alert.snapshot_path)
# ==================== WebSocket 端点 ====================
@app.websocket("/ws/alerts")
async def websocket_alerts(websocket: WebSocket):
    """WebSocket endpoint for real-time alert and device-status pushes.

    The connection is registered with the notification manager. Incoming text
    is read only as a keep-alive: ``"ping"`` is answered with ``"pong"``, and
    anything else is ignored.
    """
    await notification_service.manager.connect(websocket)
    try:
        while True:
            message = await websocket.receive_text()
            if message == "ping":
                await websocket.send_text("pong")
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.warning(f"WebSocket 异常: {e}")
    finally:
        # Always deregister. asyncio.CancelledError (raised e.g. on server
        # shutdown) is a BaseException and bypassed both handlers above, which
        # left a stale connection in the manager.
        notification_service.manager.disconnect(websocket)
# ==================== 设备管理端点 ====================
@app.get("/api/v1/devices", response_model=DeviceListResponse)
async def list_devices(
    status: Optional[str] = Query(None, description="设备状态: online/offline/error"),
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
    service=Depends(get_device_svc),
):
    """Return one page of devices, optionally filtered by ``status``."""
    rows, total_count = service.get_devices(status=status, page=page, page_size=page_size)
    items = [DeviceResponse(**row.to_dict()) for row in rows]
    return DeviceListResponse(
        devices=items,
        total=total_count,
        page=page,
        page_size=page_size,
    )
@app.get("/api/v1/devices/statistics", response_model=DeviceStatisticsResponse)
async def get_device_statistics(
    service=Depends(get_device_svc),
):
    """Return aggregate device statistics computed by the device service."""
    return DeviceStatisticsResponse(**service.get_statistics())
@app.get("/api/v1/devices/{device_id}", response_model=DeviceResponse)
async def get_device(
    device_id: str,
    service=Depends(get_device_svc),
):
    """Fetch a single device by its string id; respond 404 when none is found."""
    device = service.get_device(device_id)
    if device:
        return DeviceResponse(**device.to_dict())
    raise HTTPException(status_code=404, detail="设备不存在")
# ==================== MQTT 状态端点 ====================
@app.get("/api/v1/mqtt/statistics")
async def get_mqtt_statistics():
    """Expose the MQTT service's runtime statistics exactly as it reports them."""
    stats = mqtt_service.get_statistics()
    return stats
if __name__ == "__main__":
    # Development entry point: serve via uvicorn, with auto-reload following
    # the configured debug flag.
    import uvicorn

    app_cfg = settings.app
    uvicorn.run(
        "app.main:app",
        host=app_cfg.host,
        port=app_cfg.port,
        reload=app_cfg.debug,
    )