ResultReporter 创建后需调用 initialize() 方法初始化 MQTT 连接 Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
567 lines
20 KiB
Python
567 lines
20 KiB
Python
"""
|
||
主入口模块
|
||
整合所有模块,启动推理服务
|
||
"""
|
||
|
||
import logging
|
||
import os
|
||
import sys
|
||
import threading
|
||
import signal
|
||
import time
|
||
from datetime import datetime
|
||
from typing import Dict, Any, Optional, List, Tuple
|
||
|
||
from config.settings import get_settings, Settings
|
||
from core.config_sync import get_config_sync_manager, ConfigSyncManager
|
||
from core.video_stream import MultiStreamManager, VideoFrame
|
||
from core.preprocessor import ImagePreprocessor
|
||
from core.tensorrt_engine import TensorRTEngine, EngineManager
|
||
from core.postprocessor import PostProcessor
|
||
from core.result_reporter import ResultReporter
|
||
from algorithms import AlgorithmManager
|
||
from utils.logger import get_logger, StructuredLogger
|
||
from utils.version_control import get_version_control
|
||
|
||
# Module-level logger; main() uses it to report top-level service failures.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class EdgeInferenceService:
    """Edge inference service main class.

    Wires config sync, video streams, preprocessing, inference engine,
    postprocessing, algorithms and result reporting into one service.

    Runtime structure (as implemented below):
      * per-camera frame callbacks (``_process_frame``) crop ROIs and push
        them onto ``_batch_roi_queue``;
      * the ``InferenceWorker`` thread drains that queue and runs batched
        inference (``_inference_worker`` / ``_batch_process_rois``);
      * the ``Heartbeat`` thread reports status every 30 seconds until stop.
    """

    def __init__(self):
        self._running = False
        self._settings = get_settings()
        self._logger = get_logger("main")
        self._version_control = get_version_control()

        # Components; populated by initialize().
        self._db_manager = None
        self._config_manager: Optional[ConfigSyncManager] = None
        self._stream_manager: Optional[MultiStreamManager] = None
        self._preprocessor: Optional[ImagePreprocessor] = None
        self._engine_manager: Optional[EngineManager] = None
        self._postprocessor: Optional[PostProcessor] = None
        self._reporter: Optional[ResultReporter] = None
        self._algorithm_manager: Optional[AlgorithmManager] = None

        self._processing_threads: Dict[str, threading.Thread] = {}
        # Set by stop(); every background loop polls it.
        self._stop_event = threading.Event()

        self._performance_stats = {
            "start_time": None,
            "total_frames_processed": 0,
            "total_alerts_generated": 0,
            "uptime_seconds": 0,
        }

        # Producer/consumer state for batched inference. Queue items are
        # (camera_id, roi, bind, frame, cropped, scale_info) tuples.
        self._batch_roi_queue: List[tuple] = []
        self._batch_lock = threading.Lock()
        self._batch_event = threading.Event()
        self._inference_thread: Optional[threading.Thread] = None
        # Largest number of ROI crops handed to a single engine.infer() call.
        self._max_batch_size = 8
        self._batch_timeout_sec = 0.05  # 50 ms batching window

        self._logger.info("Edge_Inference_Service 初始化开始")

    def _init_database(self):
        """Initialize the SQLite manager (best effort).

        On any failure ``_db_manager`` is reset to ``None`` and the service
        continues without a database.
        """
        try:
            from config.database import SQLiteManager

            self._db_manager = SQLiteManager()
            if self._db_manager._conn:
                self._logger.info("数据库初始化成功")
            else:
                self._logger.warning("数据库不可用,服务将在无数据库模式下运行")
        except Exception as e:
            self._logger.warning(f"数据库初始化失败,服务将在无数据库模式下运行: {e}")
            self._db_manager = None

    def _init_config_manager(self):
        """Obtain the config sync manager and start its subscription.

        Raises:
            Exception: re-raised after logging; the service cannot run without config.
        """
        try:
            self._config_manager = get_config_sync_manager()
            self._config_manager.start_config_subscription()
            self._logger.info("配置管理器初始化成功")
        except Exception as e:
            self._logger.error(f"配置管理器初始化失败: {e}")
            raise

    def _init_stream_manager(self):
        """Create the multi-stream video manager."""
        self._stream_manager = MultiStreamManager()
        self._logger.info("流管理器初始化成功")

    def _init_preprocessor(self):
        """Create the image preprocessor."""
        self._preprocessor = ImagePreprocessor()
        self._logger.info("预处理器初始化成功")

    def _init_engine(self):
        """Create the engine manager and load the "default" engine if the model file exists.

        A missing model file only logs a warning; batches are then dropped
        because ``get_engine("default")`` has nothing to return.

        Raises:
            Exception: re-raised after logging if the engine manager or load fails.
        """
        try:
            self._engine_manager = EngineManager()

            engine_path = self._settings.inference.model_path
            if os.path.exists(engine_path):
                self._engine_manager.load_engine("default", engine_path)
                self._logger.info(f"推理引擎加载成功: {engine_path}")
            else:
                self._logger.warning(f"引擎文件不存在: {engine_path}")
        except Exception as e:
            self._logger.error(f"推理引擎初始化失败: {e}")
            raise

    def _init_postprocessor(self):
        """Create the detection postprocessor."""
        self._postprocessor = PostProcessor()
        self._logger.info("后处理器初始化成功")

    def _init_reporter(self):
        """Create the result reporter and call its initialize() (storage and MQTT connection)."""
        self._reporter = ResultReporter()
        self._reporter.initialize()
        self._logger.info("结果上报器初始化成功")

    def _init_algorithm_manager(self):
        """Create the algorithm manager and start its config subscription (best effort).

        Failures are logged but not raised; _handle_detections skips algorithm
        processing while ``_algorithm_manager`` is None.
        """
        try:
            self._algorithm_manager = AlgorithmManager()
            self._algorithm_manager.start_config_subscription()
            self._logger.info("算法管理器初始化成功")
        except Exception as e:
            self._logger.error(f"算法管理器初始化失败: {e}")

    def initialize(self):
        """Initialize every component in dependency order and record the startup."""
        self._logger.info("=" * 50)
        self._logger.info("Edge_Inference_Service 启动")
        self._logger.info("=" * 50)

        self._init_database()
        self._init_config_manager()
        self._init_stream_manager()
        self._init_preprocessor()
        self._init_engine()
        self._init_postprocessor()
        self._init_reporter()
        self._init_algorithm_manager()

        self._performance_stats["start_time"] = datetime.now()

        self._version_control.record_update(
            version="1.0.0",
            update_type="启动",
            description="Edge_Inference_Service 启动运行",
            updated_by="系统",
            affected_items=["全局"],
        )

        self._logger.info("所有组件初始化完成")

    def _load_cameras(self):
        """Register one stream per configured camera; a failing camera is logged and skipped."""
        cameras = self._config_manager.get_cameras()

        for camera in cameras:
            try:
                self._stream_manager.add_stream(
                    camera_id=camera.camera_id,
                    rtsp_url=camera.rtsp_url,
                    target_fps=self._settings.video_stream.default_fps,
                    on_frame_callback=self._create_frame_callback(camera.camera_id),
                )
                self._logger.info(f"已添加摄像头: {camera.camera_id}")
            except Exception as e:
                self._logger.error(f"添加摄像头失败 {camera.camera_id}: {e}")

    def _create_frame_callback(self, camera_id: str):
        """Return a frame callback bound to ``camera_id``."""
        def callback(frame):
            self._process_frame(camera_id, frame)
        return callback

    def _process_frame(self, camera_id: str, frame: "VideoFrame"):
        """Crop every active ROI/binding of a frame and enqueue it for batched inference.

        Every enabled binding of every enabled ROI becomes one queue item
        ``(camera_id, roi, bind, frame, cropped, scale_info)``. The crop only
        depends on the ROI, so it is computed once and shared by that ROI's
        bindings. Only frames that enqueue at least one item are counted in
        ``total_frames_processed``; the status line is logged every 100 such frames.
        """
        try:
            roi_configs = self._config_manager.get_roi_configs_with_bindings(camera_id)

            roi_items = []
            for roi in roi_configs:
                if not roi.enabled or not roi.bindings:
                    continue
                active_binds = [bind for bind in roi.bindings if bind.enabled]
                if not active_binds:
                    continue

                try:
                    cropped, scale_info = self._preprocessor.preprocess_single(
                        frame.image, roi
                    )
                except Exception as e:
                    self._logger.error(f"预处理 ROI 失败 {roi.roi_id}: {e}")
                    continue

                roi_items.extend(
                    (camera_id, roi, bind, frame, cropped, scale_info)
                    for bind in active_binds
                )

            if not roi_items:
                return

            with self._batch_lock:
                self._batch_roi_queue.extend(roi_items)
                # Counted under the queue lock in case frame callbacks for
                # different cameras run on separate threads.
                self._performance_stats["total_frames_processed"] += 1
                frames_processed = self._performance_stats["total_frames_processed"]

            # Wake the inference thread.
            self._batch_event.set()

            # Checked after the increment so the status line fires once per
            # 100 counted frames rather than on every frame while the count is 0.
            if frames_processed % 100 == 0:
                self._logger.info(f"[{camera_id}] 已处理 {frames_processed} 帧, ROI数: {len(roi_configs)}")

        except Exception as e:
            self._logger.error(f"处理帧失败 {camera_id}: {e}")

    def _batch_process_rois(self):
        """Drain the ROI queue and run batched inference over everything in it.

        The drained items are split into chunks of at most ``_max_batch_size``
        so a single ``engine.infer`` call never receives more crops than that,
        however many cameras and ROIs queued work. If no "default" engine is
        loaded the drained items are discarded.
        """
        with self._batch_lock:
            roi_items = self._batch_roi_queue
            if not roi_items:
                return
            self._batch_roi_queue = []

        try:
            engine = self._engine_manager.get_engine("default")
        except Exception as e:
            self._logger.error(f"批量处理 ROI 失败: {e}")
            return
        if engine is None:
            return

        chunk_size = max(1, self._max_batch_size)
        for start in range(0, len(roi_items), chunk_size):
            self._infer_chunk(engine, roi_items[start:start + chunk_size])

    def _infer_chunk(self, engine, roi_items: List[tuple]):
        """Run one engine.infer() over ``roi_items`` and dispatch each ROI's detections.

        Errors are logged per chunk so one failing chunk does not drop the others.
        """
        try:
            images = [item[4] for item in roi_items]

            # Stack all ROI crops of this chunk into one batch tensor for a single inference.
            batch_data, _ = self._preprocessor._batch_preprocessor.preprocess_batch(
                images
            )
            outputs, _inference_time_ms = engine.infer(batch_data)

            batch_size = len(roi_items)
            batch_results = self._postprocessor.batch_process_detections(
                outputs,
                batch_size,
                conf_threshold=self._settings.inference.conf_threshold,
            )

            # Diagnostic: total detections across the chunk.
            total_detections = sum(len(r[0]) for r in batch_results)
            if total_detections > 0:
                self._logger.info(f"[推理] batch_size={batch_size}, 总检测数={total_detections}")

            for idx, (camera_id, roi, bind, frame, _, scale_info) in enumerate(roi_items):
                # Indexed (not zipped) so a short result list raises and is logged.
                boxes, scores, class_ids = batch_results[idx]
                if len(boxes) > 0:
                    self._handle_detections(
                        camera_id, roi, bind, frame,
                        boxes, scores, class_ids,
                        scale_info,
                    )

        except Exception as e:
            self._logger.error(f"批量处理 ROI 失败: {e}")

    def _process_roi_frame(self, camera_id: str, frame: "VideoFrame", roi):
        """Deprecated no-op kept for compatibility; ROI collection now happens in _process_frame."""
        pass

    def _build_tracks(
        self,
        roi,
        boxes: Any,
        scores: Any,
        class_ids: Any,
        scale_info: tuple,
    ) -> list:
        """Convert detections into the ``tracks`` list consumed by the algorithms.

        Boxes are mapped from letterbox space back into ROI-crop space via
        ``self._preprocessor.revert_boxes``. Class ids outside
        ``[0, len(class_names))`` are labelled ``class_<id>``.

        Returns:
            One dict per box with keys ``track_id``, ``class``, ``confidence``,
            ``bbox`` and ``matched_rois``.
        """
        class_names = getattr(self._settings, 'class_names', ['person'])

        reverted_boxes = self._preprocessor.revert_boxes(
            [box.tolist() if hasattr(box, 'tolist') else list(box) for box in boxes],
            scale_info,
        )

        tracks = []
        for i, box in enumerate(reverted_boxes):
            class_id = int(class_ids[i]) if class_ids[i] else 0
            # Negative ids must not index class_names from the end.
            if 0 <= class_id < len(class_names):
                class_name = class_names[class_id]
            else:
                class_name = f"class_{class_id}"
            tracks.append({
                "track_id": f"{roi.roi_id}_{i}",
                "class": class_name,
                "confidence": float(scores[i]),
                "bbox": box,
                "matched_rois": [{"roi_id": roi.roi_id}],
            })

        return tracks

    def _handle_detections(
        self,
        camera_id: str,
        roi,
        bind,
        frame: "VideoFrame",
        boxes: Any,
        scores: Any,
        class_ids: Any,
        scale_info: tuple,
    ):
        """Hand detections to the bound algorithm and report any alerts it raises.

        The algorithm decides whether an alert fires. Each alert is wrapped in
        an ``AlertInfo`` and sent through the reporter with the full frame as
        screenshot. When one call yields several alerts, the second and later
        ones get an ``_<index>`` suffix on their ``alert_id`` so ids stay unique.
        """
        try:
            if self._algorithm_manager is None:
                self._logger.warning("算法管理器不可用,跳过算法处理")
                return

            roi_id = roi.roi_id
            algo_code = bind.algo_code
            algo_params = bind.params or {}

            # Diagnostic (INFO on purpose so it is always visible).
            if len(boxes) > 0:
                self._logger.info(f"[{camera_id}] ROI={str(roi_id)[:8]} 检测到 {len(boxes)} 个目标, algo={algo_code}")

            self._algorithm_manager.register_algorithm(
                roi_id=roi_id,
                bind_id=bind.bind_id,
                algorithm_type=algo_code,
                params=algo_params,
            )

            tracks = self._build_tracks(roi, boxes, scores, class_ids, scale_info)
            if not tracks:
                return

            self._logger.info(f"[{camera_id}] tracks: {[t.get('class') for t in tracks]}, target_class={bind.target_class}")

            alerts = self._algorithm_manager.process(
                roi_id=roi_id,
                bind_id=bind.bind_id,
                camera_id=camera_id,
                algorithm_type=algo_code,
                tracks=tracks,
                current_time=frame.timestamp,
            )

            if not alerts:
                # Include the algorithm state to explain why nothing fired.
                algo_status = self._algorithm_manager.get_status(roi_id)
                self._logger.info(f"[{camera_id}] 算法 {algo_code} 无告警, 状态: {algo_status}")
                return

            self._logger.info(f"[{camera_id}] 算法 {algo_code} 返回 {len(alerts)} 个告警")

            from core.result_reporter import AlertInfo

            base_alert_id = f"{roi_id}_{bind.bind_id}_{int(frame.timestamp.timestamp())}"
            for idx, alert in enumerate(alerts):
                self._performance_stats["total_alerts_generated"] += 1

                alert_id = base_alert_id if idx == 0 else f"{base_alert_id}_{idx}"
                alert_info = AlertInfo(
                    alert_id=alert_id,
                    camera_id=camera_id,
                    roi_id=roi_id,
                    bind_id=bind.bind_id,
                    alert_type=alert.get("alert_type", "detection"),
                    target_class=alert.get("class", bind.target_class or "unknown"),
                    confidence=alert.get("confidence", 1.0),
                    bbox=alert.get("bbox", []),
                    message=alert.get("message", ""),
                    timestamp=frame.timestamp,
                    duration_minutes=alert.get("duration_minutes"),
                )
                self._reporter.report_alert(alert_info, screenshot=frame.image)

                self._logger.info(
                    f"告警已生成: type={alert.get('alert_type', 'detection')}, "
                    f"camera={camera_id}, roi={roi_id}, "
                    f"confidence={alert.get('confidence', 1.0)}"
                )

        except Exception as e:
            self._logger.error(f"处理检测结果失败: {e}")

    def _inference_worker(self):
        """Inference thread: collect ROI requests within a short window, then infer them in batches."""
        while not self._stop_event.is_set():
            # Wait for a producer signal or the batching window to expire.
            self._batch_event.wait(timeout=self._batch_timeout_sec)
            self._batch_event.clear()

            with self._batch_lock:
                queue_size = len(self._batch_roi_queue)

            if queue_size == 0:
                continue

            if queue_size < self._max_batch_size:
                # Partial batch: give producers one more window to add work.
                self._batch_event.wait(timeout=self._batch_timeout_sec)
                self._batch_event.clear()

            self._batch_process_rois()

    def start(self):
        """Start streams and background threads, then block until shutdown."""
        if self._running:
            return

        self._running = True
        self._stop_event.clear()

        self._load_cameras()

        # Dedicated inference thread (producer/consumer).
        self._inference_thread = threading.Thread(
            target=self._inference_worker,
            name="InferenceWorker",
            daemon=True,
        )
        self._inference_thread.start()
        self._logger.info("推理线程已启动")

        self._stream_manager.start_all()

        self._logger.info("Edge_Inference_Service 已启动")

        self._start_heartbeat_thread()
        self._register_signal_handlers()
        self._wait_for_shutdown()

    def _start_heartbeat_thread(self):
        """Start a daemon thread that reports service status every 30 seconds until stop."""
        def heartbeat():
            while not self._stop_event.is_set():
                try:
                    uptime = (datetime.now() - self._performance_stats["start_time"]).total_seconds()
                    self._performance_stats["uptime_seconds"] = uptime

                    status = {
                        "running": True,
                        "uptime_seconds": uptime,
                        "frames_processed": self._performance_stats["total_frames_processed"],
                        "alerts_generated": self._performance_stats["total_alerts_generated"],
                        "stream_stats": self._stream_manager.get_statistics() if self._stream_manager else {},
                    }

                    if self._reporter:
                        self._reporter.report_heartbeat("edge_inference_device", status)

                except Exception as e:
                    self._logger.error(f"心跳上报失败: {e}")

                # Interruptible wait: returns as soon as stop() sets the event.
                self._stop_event.wait(30)

        thread = threading.Thread(target=heartbeat, name="Heartbeat", daemon=True)
        thread.start()

    def _register_signal_handlers(self):
        """Install SIGINT/SIGTERM handlers that stop the service."""
        def handle_signal(signum, frame):
            self._logger.info(f"收到信号 {signum}, 正在停止服务...")
            self.stop()

        signal.signal(signal.SIGINT, handle_signal)
        signal.signal(signal.SIGTERM, handle_signal)

    def _wait_for_shutdown(self):
        """Block the calling thread until the stop event is set."""
        while not self._stop_event.is_set():
            time.sleep(1)

    def _shutdown_step(self, name: str, action) -> None:
        """Run one shutdown action; log (not raise) failures so later steps still run."""
        try:
            action()
        except Exception as e:
            self._logger.error(f"停止组件失败 {name}: {e}")

    def stop(self):
        """Stop the service and release all components.

        Returns immediately if the service is not running. Each component is
        shut down independently, so one failing ``close()`` does not leave
        the others open.
        """
        if not self._running:
            return

        self._running = False
        self._stop_event.set()
        self._batch_event.set()  # wake the inference thread so it can exit

        if self._inference_thread and self._inference_thread.is_alive():
            self._inference_thread.join(timeout=5)

        if self._stream_manager:
            self._shutdown_step("stream_manager.stop_all", self._stream_manager.stop_all)
            self._shutdown_step("stream_manager.close", self._stream_manager.close)

        if self._engine_manager:
            self._shutdown_step("engine_manager.release_all", self._engine_manager.release_all)

        if self._config_manager:
            self._shutdown_step(
                "config_manager.stop_config_subscription",
                self._config_manager.stop_config_subscription,
            )
            self._shutdown_step("config_manager.close", self._config_manager.close)

        if self._algorithm_manager:
            self._shutdown_step(
                "algorithm_manager.stop_config_subscription",
                self._algorithm_manager.stop_config_subscription,
            )

        if self._reporter:
            self._shutdown_step("reporter.close", self._reporter.close)

        start_time = self._performance_stats["start_time"]
        if start_time is not None:
            self._performance_stats["uptime_seconds"] = (
                (datetime.now() - start_time).total_seconds()
            )

        self._logger.info("Edge_Inference_Service 已停止")
        self._logger.info(f"运行统计: {self._performance_stats}")

    def get_status(self) -> Dict[str, Any]:
        """Return a snapshot of service state and counters."""
        return {
            "running": self._running,
            "start_time": (
                self._performance_stats["start_time"].isoformat()
                if self._performance_stats["start_time"] else None
            ),
            "uptime_seconds": self._performance_stats["uptime_seconds"],
            "total_frames_processed": self._performance_stats["total_frames_processed"],
            "total_alerts_generated": self._performance_stats["total_alerts_generated"],
            "stream_manager": (
                self._stream_manager.get_statistics()
                if self._stream_manager else {}
            ),
            "config_version": (
                self._config_manager.config_version
                if self._config_manager else None
            ),
        }
|
||
|
||
|
||
def main():
    """Program entry point: build, initialize and run the service until shutdown.

    ``service.stop()`` runs in ``finally`` so components started by ``start()``
    are released on every exit path: normal shutdown, Ctrl+C, or an
    unexpected error. ``stop()`` returns immediately when the service is not
    running, so calling it after a signal-triggered stop is harmless.
    """
    service = EdgeInferenceService()

    try:
        service.initialize()
        service.start()
    except KeyboardInterrupt:
        # Ctrl+C arriving before start() installed its SIGINT handler;
        # shutdown happens in the finally clause.
        pass
    except Exception as e:
        # logger.exception also records the traceback, which logger.error dropped.
        logger.exception(f"服务异常: {e}")
        raise
    finally:
        service.stop()
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|