Files
security-ai-edge/main.py
16337 b0ddb6ee1a feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from
  edge_inference_service/ to the repository root ai_edge/
- Updated model path in config/settings.py to reflect new structure
- Revised usage paths in __init__.py documentation
2026-01-29 18:43:19 +08:00

413 lines
14 KiB
Python

"""
主入口模块
整合所有模块,启动推理服务
"""
import logging
import os
import sys
import threading
import signal
import time
from datetime import datetime
from typing import Dict, Any, Optional
from config.settings import get_settings, Settings
from config.database import init_database
from core.config_sync import get_config_sync_manager, ConfigSyncManager
from core.video_stream import MultiStreamManager, VideoFrame
from core.preprocessor import ImagePreprocessor
from core.tensorrt_engine import TensorRTEngine, EngineManager
from core.postprocessor import PostProcessor
from core.result_reporter import ResultReporter
from utils.logger import get_logger, StructuredLogger
from utils.version_control import get_version_control
logger = logging.getLogger(__name__)
class EdgeInferenceService:
    """Main class of the edge inference service.

    Wires together all subsystems — config sync, multi-camera video
    streaming, image preprocessing, TensorRT inference, postprocessing
    and result reporting — into a single long-running service.

    Lifecycle: ``initialize()`` → ``start()`` (blocks until shutdown)
    → ``stop()`` (idempotent; also invoked by SIGINT/SIGTERM handlers).
    """

    def __init__(self):
        self._running = False
        self._settings = get_settings()
        self._logger = get_logger("main")
        self._version_control = get_version_control()
        # All heavyweight components are created lazily in initialize()
        # so constructing the service object itself has no side effects.
        self._db_manager = None
        self._config_manager: Optional[ConfigSyncManager] = None
        self._stream_manager: Optional[MultiStreamManager] = None
        self._preprocessor: Optional[ImagePreprocessor] = None
        self._engine_manager: Optional[EngineManager] = None
        self._postprocessor: Optional[PostProcessor] = None
        self._reporter: Optional[ResultReporter] = None
        self._processing_threads: Dict[str, threading.Thread] = {}
        # Set once to request shutdown; checked by the heartbeat thread
        # and by _wait_for_shutdown().
        self._stop_event = threading.Event()
        self._performance_stats = {
            "start_time": None,  # datetime set in initialize()
            "total_frames_processed": 0,
            "total_alerts_generated": 0,
            "uptime_seconds": 0,
        }
        self._logger.info("Edge_Inference_Service 初始化开始")

    def _init_database(self):
        """Initialize the database; degrade to DB-less mode on failure."""
        try:
            from config.database import DatabaseManager
            self._db_manager = DatabaseManager()
            if self._db_manager.is_available:
                self._logger.info("数据库初始化成功")
            else:
                self._logger.warning("数据库不可见,服务将在无数据库模式下运行")
        except Exception as e:
            # Deliberate best-effort: the service can run without a DB.
            self._logger.warning(f"数据库初始化失败,服务将在无数据库模式下运行: {e}")
            self._db_manager = None

    def _init_config_manager(self):
        """Initialize the config-sync manager and start its subscription.

        Raises:
            Exception: re-raised — config sync is mandatory for operation.
        """
        try:
            self._config_manager = get_config_sync_manager()
            self._config_manager.start_config_subscription()
            self._logger.info("配置管理器初始化成功")
        except Exception as e:
            self._logger.error(f"配置管理器初始化失败: {e}")
            raise

    def _init_stream_manager(self):
        """Initialize the multi-camera video stream manager."""
        self._stream_manager = MultiStreamManager()
        self._logger.info("流管理器初始化成功")

    def _init_preprocessor(self):
        """Initialize the image preprocessor."""
        self._preprocessor = ImagePreprocessor()
        self._logger.info("预处理器初始化成功")

    def _init_engine(self):
        """Initialize the TensorRT inference engine.

        A missing engine file is only a warning (inference is skipped per
        frame until an engine is loaded); other failures are fatal.
        """
        try:
            self._engine_manager = EngineManager()
            engine_path = self._settings.inference.model_path
            if os.path.exists(engine_path):
                self._engine_manager.load_engine("default", engine_path)
                self._logger.info(f"推理引擎加载成功: {engine_path}")
            else:
                self._logger.warning(f"引擎文件不存在: {engine_path}")
        except Exception as e:
            self._logger.error(f"推理引擎初始化失败: {e}")
            raise

    def _init_postprocessor(self):
        """Initialize the detection postprocessor."""
        self._postprocessor = PostProcessor()
        self._logger.info("后处理器初始化成功")

    def _init_reporter(self):
        """Initialize the result reporter."""
        self._reporter = ResultReporter()
        self._logger.info("结果上报器初始化成功")

    def initialize(self):
        """Initialize every component, in dependency order."""
        self._logger.info("=" * 50)
        self._logger.info("Edge_Inference_Service 启动")
        self._logger.info("=" * 50)
        self._init_database()
        self._init_config_manager()
        self._init_stream_manager()
        self._init_preprocessor()
        self._init_engine()
        self._init_postprocessor()
        self._init_reporter()
        self._performance_stats["start_time"] = datetime.now()
        self._version_control.record_update(
            version="1.0.0",
            update_type="启动",
            description="Edge_Inference_Service 启动运行",
            updated_by="系统",
            affected_items=["全局"],
        )
        self._logger.info("所有组件初始化完成")

    def _load_cameras(self):
        """Register all cameras from the synced config with the stream manager.

        One camera failing to register does not prevent the others.
        """
        cameras = self._config_manager.get_cameras()
        for camera in cameras:
            try:
                self._stream_manager.add_stream(
                    camera_id=camera.camera_id,
                    rtsp_url=camera.rtsp_url,
                    target_fps=self._settings.video_stream.default_fps,
                    on_frame_callback=self._create_frame_callback(camera.camera_id)
                )
                self._logger.info(f"已添加摄像头: {camera.camera_id}")
            except Exception as e:
                self._logger.error(f"添加摄像头失败 {camera.camera_id}: {e}")

    def _create_frame_callback(self, camera_id: str):
        """Return a per-camera frame callback bound to *camera_id*.

        The closure binds camera_id by value, so each stream gets its own
        callback (avoids the late-binding-closure pitfall).
        """
        def callback(frame):
            self._process_frame(camera_id, frame)
        return callback

    def _process_frame(self, camera_id: str, frame: VideoFrame):
        """Process one decoded video frame: run every configured ROI through
        the inference pipeline and record latency stats.
        """
        try:
            start_time = time.perf_counter()
            roi_configs = self._config_manager.get_roi_configs(camera_id)
            for roi in roi_configs:
                self._process_roi_frame(camera_id, frame, roi)
            processing_time_ms = (time.perf_counter() - start_time) * 1000
            self._performance_stats["total_frames_processed"] += 1
            self._logger.log_inference_latency(
                processing_time_ms,
                batch_size=1
            )
        except Exception as e:
            self._logger.error(f"处理帧失败 {camera_id}: {e}")

    def _process_roi_frame(
        self,
        camera_id: str,
        frame: VideoFrame,
        roi
    ):
        """Crop/preprocess one ROI, run inference, and forward detections.

        Silently returns when no engine named "default" is loaded (e.g.
        the engine file was missing at startup).
        """
        try:
            cropped = self._preprocessor.preprocess_single(frame.image, roi)
            processed_image, scale_info = cropped
            # NOTE(review): reaches into a private helper of the
            # preprocessor; a public batch API would be preferable.
            batch_data = self._preprocessor._batch_preprocessor._stack_and_normalize(
                [processed_image]
            )
            engine = self._engine_manager.get_engine("default")
            if engine is None:
                return
            outputs, inference_time_ms = engine.infer(batch_data)
            boxes, scores, class_ids = self._postprocessor.process_detections(
                outputs,
                conf_threshold=self._settings.inference.conf_threshold
            )
            if len(boxes) > 0:
                self._handle_detections(
                    camera_id, roi.roi_id, frame,
                    boxes, scores, class_ids,
                    scale_info
                )
        except Exception as e:
            self._logger.error(f"处理ROI帧失败: {e}")

    def _handle_detections(
        self,
        camera_id: str,
        roi_id: str,
        frame: VideoFrame,
        boxes: Any,
        scores: Any,
        class_ids: Any,
        scale_info: tuple
    ):
        """Map detections back to frame coordinates and report alerts.

        Each detection is checked against the ROI's alarm condition; a
        positive check increments the alert counter and pushes a report
        with a frame screenshot.
        """
        try:
            mapped_boxes = self._postprocessor.map_coordinates(
                boxes, scale_info,
                (frame.width, frame.height)
            )
            for i, box in enumerate(mapped_boxes):
                detection_result = {
                    "class_id": int(class_ids[i]) if len(class_ids) > 0 else 0,
                    "confidence": float(scores[i]),
                    "bbox": box,
                }
                # Called once per box on purpose: check_alarm_condition
                # may keep per-ROI debounce state, so do not hoist it.
                alert_result = self._postprocessor.check_alarm_condition(
                    roi_id, True, frame.timestamp
                )
                if alert_result["should_alert"]:
                    self._performance_stats["total_alerts_generated"] += 1
                    screenshot = frame.image
                    self._reporter.report_detection_alert(
                        camera_id=camera_id,
                        roi_id=roi_id,
                        alert_type="detection",
                        detection={
                            "class_name": f"class_{detection_result['class_id']}",
                            "confidence": detection_result["confidence"],
                            "bbox": detection_result["bbox"],
                            "message": "检测到目标"
                        },
                        screenshot=screenshot
                    )
                    self._logger.log_alert(
                        "detection",
                        camera_id,
                        roi_id,
                        detection_result["confidence"]
                    )
        except Exception as e:
            self._logger.error(f"处理检测结果失败: {e}")

    def start(self):
        """Start all camera streams and block until shutdown is requested."""
        if self._running:
            return
        self._running = True
        self._stop_event.clear()
        self._load_cameras()
        self._stream_manager.start_all()
        self._logger.info("Edge_Inference_Service 已启动")
        self._start_heartbeat_thread()
        self._register_signal_handlers()
        self._wait_for_shutdown()

    def _start_heartbeat_thread(self):
        """Start the daemon thread that reports a heartbeat every 30 s."""
        def heartbeat():
            while not self._stop_event.is_set():
                try:
                    start_time = self._performance_stats["start_time"]
                    # Guard against start() without initialize():
                    # start_time may still be None.
                    uptime = (
                        (datetime.now() - start_time).total_seconds()
                        if start_time is not None else 0.0
                    )
                    self._performance_stats["uptime_seconds"] = uptime
                    status = {
                        "running": True,
                        "uptime_seconds": uptime,
                        "frames_processed": self._performance_stats["total_frames_processed"],
                        "alerts_generated": self._performance_stats["total_alerts_generated"],
                        "stream_stats": self._stream_manager.get_statistics() if self._stream_manager else {},
                    }
                    if self._reporter:
                        self._reporter.report_heartbeat("edge_inference_device", status)
                except Exception as e:
                    self._logger.error(f"心跳上报失败: {e}")
                # Event.wait instead of time.sleep so shutdown is not
                # delayed by up to 30 seconds.
                self._stop_event.wait(30)
        thread = threading.Thread(target=heartbeat, name="Heartbeat", daemon=True)
        thread.start()

    def _register_signal_handlers(self):
        """Install SIGINT/SIGTERM handlers that trigger a clean stop."""
        def handle_signal(signum, frame):
            self._logger.info(f"收到信号 {signum}, 正在停止服务...")
            self.stop()
        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

    def _wait_for_shutdown(self):
        """Block the main thread until the stop event is set.

        Waits in 1-second slices so pending signals keep being serviced
        promptly on all platforms.
        """
        while not self._stop_event.is_set():
            self._stop_event.wait(1)

    def stop(self):
        """Stop the service and release all resources. Idempotent."""
        if not self._running:
            return
        self._running = False
        self._stop_event.set()
        if self._stream_manager:
            self._stream_manager.stop_all()
            self._stream_manager.close()
        if self._engine_manager:
            self._engine_manager.release_all()
        if self._config_manager:
            self._config_manager.stop_config_subscription()
            self._config_manager.close()
        if self._reporter:
            self._reporter.close()
        start_time = self._performance_stats["start_time"]
        # start_time is None when stop() runs before initialize()
        # completed; skip the uptime calculation in that case.
        if start_time is not None:
            self._performance_stats["uptime_seconds"] = (
                (datetime.now() - start_time).total_seconds()
            )
        self._logger.info("Edge_Inference_Service 已停止")
        self._logger.info(f"运行统计: {self._performance_stats}")

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of the service state as a plain dict."""
        return {
            "running": self._running,
            "start_time": (
                self._performance_stats["start_time"].isoformat()
                if self._performance_stats["start_time"] else None
            ),
            "uptime_seconds": self._performance_stats["uptime_seconds"],
            "total_frames_processed": self._performance_stats["total_frames_processed"],
            "total_alerts_generated": self._performance_stats["total_alerts_generated"],
            "stream_manager": (
                self._stream_manager.get_statistics()
                if self._stream_manager else {}
            ),
            "config_version": (
                self._config_manager.config_version
                if self._config_manager else None
            ),
        }
def main():
    """Entry point: build the service, run it, and guarantee cleanup.

    Blocks inside ``service.start()`` until a shutdown signal arrives.
    ``KeyboardInterrupt`` is treated as a normal shutdown request; any
    other exception is logged and re-raised. ``stop()`` is idempotent,
    so running it in ``finally`` is safe on every exit path (including
    after the SIGINT handler has already called it).
    """
    service = EdgeInferenceService()
    try:
        service.initialize()
        service.start()
    except KeyboardInterrupt:
        # Normal operator-initiated shutdown; cleanup happens in finally.
        pass
    except Exception as e:
        logger.error(f"服务异常: {e}")
        raise
    finally:
        # Ensure streams, engine and subscriptions are always released,
        # even when initialize()/start() raised.
        service.stop()


if __name__ == "__main__":
    main()