Files
security-ai-edge/main.py

470 lines
16 KiB
Python
Raw Normal View History

"""
主入口模块
整合所有模块启动推理服务
"""
2026-01-29 18:33:12 +08:00
import logging
import os
import sys
import threading
import signal
import time
from datetime import datetime
from typing import Dict, Any, Optional
2026-01-29 18:33:12 +08:00
from config.settings import get_settings, Settings
from core.config_sync import get_config_sync_manager, ConfigSyncManager
from core.video_stream import MultiStreamManager, VideoFrame
from core.preprocessor import ImagePreprocessor
from core.tensorrt_engine import TensorRTEngine, EngineManager
from core.postprocessor import PostProcessor
from core.result_reporter import ResultReporter
from algorithms import AlgorithmManager
from utils.logger import get_logger, StructuredLogger
from utils.version_control import get_version_control
2026-01-29 18:33:12 +08:00
# Module-level stdlib logger. In this file it is only used by main() to report
# fatal service errors; EdgeInferenceService logs through get_logger("main").
logger = logging.getLogger(__name__)
2026-01-29 18:33:12 +08:00
class EdgeInferenceService:
    """Main edge inference service.

    Wires together configuration sync, video stream capture, preprocessing,
    the inference engine, post-processing, per-ROI algorithms and result
    reporting, and drives the per-frame processing pipeline.

    Typical use is ``initialize()`` followed by ``start()`` (which blocks
    until ``stop()`` is called, e.g. from a signal handler).
    """

    # Name under which the single inference engine is registered and looked up.
    _DEFAULT_ENGINE_NAME = "default"
    # Device identifier sent with every heartbeat.
    _DEVICE_ID = "edge_inference_device"
    # Seconds between two heartbeat reports.
    _HEARTBEAT_INTERVAL_SEC = 30

    def __init__(self):
        """Create the service with all components unset; no I/O beyond logging."""
        self._running = False
        self._settings = get_settings()
        self._logger = get_logger("main")
        self._version_control = get_version_control()
        self._db_manager = None
        self._config_manager: Optional[ConfigSyncManager] = None
        self._stream_manager: Optional[MultiStreamManager] = None
        self._preprocessor: Optional[ImagePreprocessor] = None
        self._engine_manager: Optional[EngineManager] = None
        self._postprocessor: Optional[PostProcessor] = None
        self._reporter: Optional[ResultReporter] = None
        self._algorithm_manager: Optional[AlgorithmManager] = None
        self._processing_threads: Dict[str, threading.Thread] = {}
        self._stop_event = threading.Event()
        self._performance_stats = {
            "start_time": None,
            "total_frames_processed": 0,
            "total_alerts_generated": 0,
            "uptime_seconds": 0,
        }
        self._logger.info("Edge_Inference_Service 初始化开始")

    def _init_database(self):
        """Initialise the SQLite manager; failure is tolerated (no-database mode)."""
        try:
            # Imported lazily so a missing/broken database module only
            # degrades the service instead of preventing start-up.
            from config.database import SQLiteManager

            self._db_manager = SQLiteManager()
            if self._db_manager._conn:
                self._logger.info("数据库初始化成功")
            else:
                self._logger.warning("数据库不可见,服务将在无数据库模式下运行")
        except Exception as e:
            self._logger.warning(f"数据库初始化失败,服务将在无数据库模式下运行: {e}")
            self._db_manager = None

    def _init_config_manager(self):
        """Initialise the config sync manager and start its subscription.

        Raises:
            Exception: re-raised after logging; the service cannot run
                without configuration.
        """
        try:
            self._config_manager = get_config_sync_manager()
            self._config_manager.start_config_subscription()
            self._logger.info("配置管理器初始化成功")
        except Exception as e:
            self._logger.error(f"配置管理器初始化失败: {e}")
            raise

    def _init_stream_manager(self):
        """Create the multi-stream manager."""
        self._stream_manager = MultiStreamManager()
        self._logger.info("流管理器初始化成功")

    def _init_preprocessor(self):
        """Create the image preprocessor."""
        self._preprocessor = ImagePreprocessor()
        self._logger.info("预处理器初始化成功")

    def _init_engine(self):
        """Create the engine manager and load the configured engine file.

        A missing engine file only produces a warning; frames are then
        skipped in ``_process_roi_frame`` because no engine is registered.

        Raises:
            Exception: re-raised after logging if the manager cannot be
                created or the engine fails to load.
        """
        try:
            self._engine_manager = EngineManager()
            engine_path = self._settings.inference.model_path
            if os.path.exists(engine_path):
                self._engine_manager.load_engine(self._DEFAULT_ENGINE_NAME, engine_path)
                self._logger.info(f"推理引擎加载成功: {engine_path}")
            else:
                self._logger.warning(f"引擎文件不存在: {engine_path}")
        except Exception as e:
            self._logger.error(f"推理引擎初始化失败: {e}")
            raise

    def _init_postprocessor(self):
        """Create the post-processor."""
        self._postprocessor = PostProcessor()
        self._logger.info("后处理器初始化成功")

    def _init_reporter(self):
        """Create the result reporter."""
        self._reporter = ResultReporter()
        self._logger.info("结果上报器初始化成功")

    def _init_algorithm_manager(self):
        """Create the algorithm manager; failure is logged but not fatal."""
        try:
            self._algorithm_manager = AlgorithmManager()
            self._algorithm_manager.start_config_subscription()
            self._logger.info("算法管理器初始化成功")
        except Exception as e:
            # Deliberately best-effort: detections are skipped (with a
            # warning) when no algorithm manager is available.
            self._logger.error(f"算法管理器初始化失败: {e}")

    def initialize(self):
        """Initialise every component in dependency order and record start time."""
        self._logger.info("=" * 50)
        self._logger.info("Edge_Inference_Service 启动")
        self._logger.info("=" * 50)

        self._init_database()
        self._init_config_manager()
        self._init_stream_manager()
        self._init_preprocessor()
        self._init_engine()
        self._init_postprocessor()
        self._init_reporter()
        self._init_algorithm_manager()

        self._performance_stats["start_time"] = datetime.now()
        self._version_control.record_update(
            version="1.0.0",
            update_type="启动",
            description="Edge_Inference_Service 启动运行",
            updated_by="系统",
            affected_items=["全局"],
        )
        self._logger.info("所有组件初始化完成")

    def _load_cameras(self):
        """Register one stream per configured camera; a bad camera is skipped."""
        cameras = self._config_manager.get_cameras()
        for camera in cameras:
            try:
                self._stream_manager.add_stream(
                    camera_id=camera.camera_id,
                    rtsp_url=camera.rtsp_url,
                    target_fps=self._settings.video_stream.default_fps,
                    on_frame_callback=self._create_frame_callback(camera.camera_id),
                )
                self._logger.info(f"已添加摄像头: {camera.camera_id}")
            except Exception as e:
                self._logger.error(f"添加摄像头失败 {camera.camera_id}: {e}")

    def _create_frame_callback(self, camera_id: str):
        """Return a one-argument callback that processes frames for ``camera_id``."""

        def callback(frame):
            self._process_frame(camera_id, frame)

        return callback

    def _process_frame(self, camera_id: str, frame: "VideoFrame"):
        """Run every ROI configured for ``camera_id`` against ``frame``.

        Errors are logged and swallowed so one bad frame cannot kill the
        stream callback.
        """
        try:
            start_time = time.perf_counter()
            roi_configs = self._config_manager.get_roi_configs(camera_id)
            for roi in roi_configs:
                self._process_roi_frame(camera_id, frame, roi)
            processing_time_ms = (time.perf_counter() - start_time) * 1000
            self._performance_stats["total_frames_processed"] += 1
            self._logger.log_inference_latency(
                processing_time_ms,
                batch_size=1,
            )
        except Exception as e:
            self._logger.error(f"处理帧失败 {camera_id}: {e}")

    def _process_roi_frame(
        self,
        camera_id: str,
        frame: "VideoFrame",
        roi,
    ):
        """Preprocess, infer and post-process one ROI of one frame.

        Disabled ROIs are skipped, and nothing is done when no engine is
        registered. Detections (if any) are forwarded to
        ``_handle_detections``.
        """
        try:
            if not roi.enabled:
                return

            # Look the engine up first: when it is missing there is no point
            # paying for preprocessing on every frame.
            engine = self._engine_manager.get_engine(self._DEFAULT_ENGINE_NAME)
            if engine is None:
                return

            processed_image, scale_info = self._preprocessor.preprocess_single(
                frame.image, roi
            )
            batch_data = self._preprocessor._batch_preprocessor._stack_and_normalize(
                [processed_image]
            )

            # The second element is the engine-reported inference time; unused here.
            outputs, _ = engine.infer(batch_data)
            boxes, scores, class_ids = self._postprocessor.process_detections(
                outputs,
                conf_threshold=self._settings.inference.conf_threshold,
            )

            # len() rather than truthiness: boxes may be an array-like whose
            # truth value is ambiguous.
            if len(boxes) > 0:
                self._handle_detections(
                    camera_id, roi, frame,
                    boxes, scores, class_ids,
                    scale_info,
                )
        except Exception as e:
            self._logger.error(f"处理ROI帧失败: {e}")

    def _build_tracks(
        self,
        roi,
        boxes: Any,
        scores: Any,
        class_ids: Any,
        scale_info: tuple,
    ) -> list:
        """Convert raw detections into the ``tracks`` list the algorithms consume.

        Args:
            roi: ROI config; only ``roi.roi_id`` is read.
            boxes: indexable/iterable of bounding boxes, passed through as-is.
            scores: per-detection confidence, indexable in step with ``boxes``.
            class_ids: per-detection class id, indexable in step with ``boxes``;
                a ``None`` entry is treated as class 0.
            scale_info: accepted for interface compatibility; currently unused.

        Returns:
            One dict per detection with keys ``track_id``, ``class``,
            ``confidence``, ``bbox`` and ``matched_rois``.
        """
        class_names = getattr(self._settings, 'class_names', ['person'])
        matched_roi_id = roi.roi_id
        tracks = []
        for i, box in enumerate(boxes):
            raw_class_id = class_ids[i]
            class_id = int(raw_class_id) if raw_class_id is not None else 0
            # Require a non-negative index: a negative id must not wrap
            # around and silently pick a name from the end of the list.
            if 0 <= class_id < len(class_names):
                class_name = class_names[class_id]
            else:
                class_name = f"class_{class_id}"
            tracks.append({
                "track_id": f"{matched_roi_id}_{i}",
                "class": class_name,
                "confidence": float(scores[i]),
                "bbox": box,
                "matched_rois": [{"roi_id": matched_roi_id}],
            })
        return tracks

    def _handle_detections(
        self,
        camera_id: str,
        roi,
        frame: "VideoFrame",
        boxes: Any,
        scores: Any,
        class_ids: Any,
        scale_info: tuple,
    ):
        """Hand detections to the ROI's algorithm and report any alerts it raises.

        The algorithm manager, not this method, decides whether detections
        become alerts. Errors are logged and swallowed.
        """
        try:
            if self._algorithm_manager is None:
                self._logger.warning("算法管理器不可用,跳过算法处理")
                return

            roi_id = roi.roi_id
            algorithm_type = roi.algorithm_type
            # Accept both enum members and plain strings.
            algo_type_str = algorithm_type.value if hasattr(algorithm_type, 'value') else algorithm_type

            # NOTE(review): called on every frame with detections; presumably
            # register_algorithm is idempotent — confirm against AlgorithmManager.
            self._algorithm_manager.register_algorithm(
                roi_id=roi_id,
                algorithm_type=algo_type_str,
                params={
                    "working_hours": roi.working_hours,
                    "confirm_on_duty_sec": roi.confirm_on_duty_sec,
                    "confirm_leave_sec": roi.confirm_leave_sec,
                    "cooldown_sec": roi.cooldown_sec,
                    "target_class": roi.target_class,
                },
            )

            tracks = self._build_tracks(roi, boxes, scores, class_ids, scale_info)
            if not tracks:
                return

            alerts = self._algorithm_manager.process(
                roi_id=roi_id,
                camera_id=camera_id,
                algorithm_type=algo_type_str,
                tracks=tracks,
                current_time=frame.timestamp,
            )
            if not alerts:
                return

            # Imported once per call instead of once per alert.
            from core.result_reporter import AlertInfo

            base_alert_id = f"{roi_id}_{int(frame.timestamp.timestamp())}"
            for index, alert in enumerate(alerts):
                self._performance_stats["total_alerts_generated"] += 1
                # The first alert keeps the historical id format; later alerts
                # from the same frame get an index suffix so ids do not collide.
                alert_id = base_alert_id if index == 0 else f"{base_alert_id}_{index}"
                alert_type = alert.get("alert_type", "detection")
                confidence = alert.get("confidence", 1.0)
                alert_info = AlertInfo(
                    alert_id=alert_id,
                    camera_id=camera_id,
                    roi_id=roi_id,
                    alert_type=alert_type,
                    target_class=alert.get("class", "unknown"),
                    confidence=confidence,
                    bbox=alert.get("bbox", []),
                    message=alert.get("message", ""),
                    timestamp=frame.timestamp,
                    duration_minutes=alert.get("duration_minutes"),
                )
                self._reporter.report_alert(alert_info, screenshot=frame.image)
                self._logger.log_alert(
                    alert_type,
                    camera_id,
                    roi_id,
                    confidence,
                )
        except Exception as e:
            self._logger.error(f"处理检测结果失败: {e}")

    def start(self):
        """Start streaming and block until ``stop()`` is called.

        ``initialize()`` must have been called first: the config and stream
        managers used here are only created there. Calling ``start()`` while
        already running is a no-op.
        """
        if self._running:
            return
        self._running = True
        self._stop_event.clear()

        self._load_cameras()
        self._stream_manager.start_all()
        self._logger.info("Edge_Inference_Service 已启动")

        self._start_heartbeat_thread()
        self._register_signal_handlers()
        self._wait_for_shutdown()

    def _start_heartbeat_thread(self):
        """Spawn a daemon thread that reports service status periodically."""

        def heartbeat():
            while not self._stop_event.is_set():
                try:
                    started_at = self._performance_stats["start_time"]
                    if started_at is not None:
                        self._performance_stats["uptime_seconds"] = (
                            datetime.now() - started_at
                        ).total_seconds()
                    uptime = self._performance_stats["uptime_seconds"]
                    status = {
                        "running": True,
                        "uptime_seconds": uptime,
                        "frames_processed": self._performance_stats["total_frames_processed"],
                        "alerts_generated": self._performance_stats["total_alerts_generated"],
                        "stream_stats": self._stream_manager.get_statistics() if self._stream_manager else {},
                    }
                    if self._reporter:
                        self._reporter.report_heartbeat(self._DEVICE_ID, status)
                except Exception as e:
                    self._logger.error(f"心跳上报失败: {e}")
                # Wait on the event instead of sleeping so the thread exits
                # as soon as stop() is called rather than up to 30 s later.
                self._stop_event.wait(self._HEARTBEAT_INTERVAL_SEC)

        thread = threading.Thread(target=heartbeat, name="Heartbeat", daemon=True)
        thread.start()

    def _register_signal_handlers(self):
        """Install SIGINT/SIGTERM handlers that stop the service.

        ``signal.signal`` may only be called from the main thread; when
        ``start()`` runs elsewhere the registration is skipped with a warning
        and the caller is responsible for invoking ``stop()``.
        """
        if threading.current_thread() is not threading.main_thread():
            self._logger.warning("非主线程,跳过信号处理器注册")
            return

        def handle_signal(signum, frame):
            self._logger.info(f"收到信号 {signum}, 正在停止服务...")
            self.stop()

        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

    def _wait_for_shutdown(self):
        """Block until the stop event is set."""
        # A bounded wait in a loop (rather than one unbounded wait) returns
        # to the interpreter at least once a second, as the original
        # sleep-based loop did, so pending signal handlers get a chance to run.
        while not self._stop_event.wait(timeout=1):
            pass

    def stop(self):
        """Stop the service and release every component.

        Safe to call more than once. Each shutdown step runs independently:
        a failure in one component is logged and the remaining components
        are still released.
        """
        if not self._running:
            return
        self._running = False
        self._stop_event.set()

        shutdown_steps = []
        if self._stream_manager:
            shutdown_steps.append(("stream_manager.stop_all", self._stream_manager.stop_all))
            shutdown_steps.append(("stream_manager.close", self._stream_manager.close))
        if self._engine_manager:
            shutdown_steps.append(("engine_manager.release_all", self._engine_manager.release_all))
        if self._config_manager:
            shutdown_steps.append(
                ("config_manager.stop_config_subscription", self._config_manager.stop_config_subscription)
            )
            shutdown_steps.append(("config_manager.close", self._config_manager.close))
        if self._algorithm_manager:
            shutdown_steps.append(
                ("algorithm_manager.stop_config_subscription", self._algorithm_manager.stop_config_subscription)
            )
        if self._reporter:
            shutdown_steps.append(("reporter.close", self._reporter.close))

        for label, step in shutdown_steps:
            try:
                step()
            except Exception as e:
                self._logger.error(f"停止组件失败 {label}: {e}")

        # start_time is None when initialize() never completed.
        started_at = self._performance_stats["start_time"]
        if started_at is not None:
            self._performance_stats["uptime_seconds"] = (
                datetime.now() - started_at
            ).total_seconds()

        self._logger.info("Edge_Inference_Service 已停止")
        self._logger.info(f"运行统计: {self._performance_stats}")

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of service state and counters.

        While the service is running, ``uptime_seconds`` is computed at call
        time; otherwise the last recorded value is returned.
        """
        started_at = self._performance_stats["start_time"]
        if self._running and started_at is not None:
            uptime_seconds = (datetime.now() - started_at).total_seconds()
        else:
            uptime_seconds = self._performance_stats["uptime_seconds"]

        return {
            "running": self._running,
            "start_time": started_at.isoformat() if started_at else None,
            "uptime_seconds": uptime_seconds,
            "total_frames_processed": self._performance_stats["total_frames_processed"],
            "total_alerts_generated": self._performance_stats["total_alerts_generated"],
            "stream_manager": (
                self._stream_manager.get_statistics()
                if self._stream_manager else {}
            ),
            "config_version": (
                self._config_manager.config_version
                if self._config_manager else None
            ),
        }
2026-01-29 18:33:12 +08:00
def main():
    """Program entry point: build the service, initialise it and run until stopped.

    ``start()`` blocks until the service is stopped. Whatever way this
    function is left (normal return, Ctrl+C before the service's own signal
    handlers are installed, or an unexpected error), ``stop()`` is invoked so
    streams, engines and subscriptions are released; ``stop()`` is a no-op
    when the service is not running.

    Raises:
        Exception: any non-KeyboardInterrupt error from initialisation or
            start-up is logged with its traceback and re-raised.
    """
    service = EdgeInferenceService()
    try:
        service.initialize()
        service.start()
    except KeyboardInterrupt:
        # Shutdown is handled in the finally clause below.
        pass
    except Exception as e:
        # logger.exception records the traceback in the log output as well.
        logger.exception(f"服务异常: {e}")
        raise
    finally:
        try:
            service.stop()
        except Exception:
            # Never let a cleanup failure mask the original error.
            logger.exception("停止服务时发生异常")


if __name__ == "__main__":
    main()