Files
Security_AI_integrated/inference/pipeline.py

365 lines
11 KiB
Python

import asyncio
import json
import os
import threading
import time
from collections import deque
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import cv2
import numpy as np
from config import get_config
from db.crud import (
create_alarm,
update_camera_status,
)
from db.models import init_db, get_session_factory
from inference.engine import YOLOEngine
from inference.roi.cache_manager import get_roi_cache
from inference.roi.roi_filter import ROIFilter
from inference.rules.algorithms import AlgorithmManager
from inference.stream import StreamManager
class InferencePipeline:
    """Multi-camera person-detection and alerting pipeline.

    For every enabled camera a daemon thread reads frames from the shared
    ``StreamManager``, runs the YOLO engine (class 0 only) on every
    ``camera.process_every_n_frames``-th frame, filters detections through the
    camera's ROIs, feeds matches to the rule algorithms in ``AlgorithmManager``
    and turns resulting alerts into DB alarms, JPEG snapshots and entries in a
    bounded in-memory event queue.
    """

    def __init__(self):
        self.config = get_config()
        self.db_initialized = False
        self.yolo_engine = YOLOEngine(use_trt=True)
        self.stream_manager = StreamManager(
            buffer_size=self.config.stream.buffer_size,
            reconnect_delay=self.config.stream.reconnect_delay,
        )
        self.roi_filter = ROIFilter()
        self.roi_cache = get_roi_cache()
        self.algo_manager = AlgorithmManager(working_hours=[
            {
                "start": [wh.start[0], wh.start[1]],
                "end": [wh.end[0], wh.end[1]],
            }
            for wh in self.config.working_hours
        ])
        # Per-camera bookkeeping, keyed by the integer camera id.
        self.camera_threads: Dict[int, threading.Thread] = {}
        self.camera_stop_events: Dict[int, threading.Event] = {}
        self.camera_latest_frames: Dict[int, Any] = {}
        self.camera_frame_times: Dict[int, datetime] = {}
        self.camera_process_counts: Dict[int, int] = {}
        # Bounded deque: the oldest events are dropped once maxlen is reached.
        self.event_queue: deque = deque(maxlen=self.config.inference.event_queue_maxlen)
        self.running = False

    def _init_database(self):
        """Initialise the DB once and start the background ROI-cache refresh."""
        if self.db_initialized:
            return
        init_db()
        self.db_initialized = True
        self.roi_cache.initialize(get_session_factory(), refresh_interval=10.0)
        self.roi_cache.start_background_refresh()

    def _get_db_session(self):
        """Open a new DB session; the caller is responsible for closing it."""
        session_factory = get_session_factory()
        return session_factory()

    def _load_cameras(self):
        """Start an inference thread for every enabled camera stored in the DB."""
        db = self._get_db_session()
        try:
            from db.crud import get_all_cameras
            cameras = get_all_cameras(db, enabled_only=True)
            for camera in cameras:
                self.add_camera(camera)
        finally:
            db.close()

    def add_camera(self, camera) -> bool:
        """Register *camera*'s stream and start its inference thread.

        Returns:
            True if the camera was started, False if it was already running.
        """
        camera_id = camera.id
        if camera_id in self.camera_threads:
            return False
        # The stop event and counter must exist before the thread starts,
        # because the inference loop reads them immediately.
        self.camera_stop_events[camera_id] = threading.Event()
        self.camera_process_counts[camera_id] = 0
        self.stream_manager.add_stream(
            str(camera_id),
            camera.rtsp_url,
            self.config.stream.buffer_size,
        )
        thread = threading.Thread(
            target=self._camera_inference_loop,
            args=(camera,),
            daemon=True,
        )
        thread.start()
        self.camera_threads[camera_id] = thread
        self._update_camera_status(camera_id, is_running=True)
        return True

    def remove_camera(self, camera_id: int):
        """Stop *camera_id*'s inference thread and drop all of its state.

        Waits up to 10 s for the thread to exit. No-op for unknown cameras.
        """
        thread = self.camera_threads.get(camera_id)
        if thread is None:
            return
        self.camera_stop_events[camera_id].set()
        thread.join(timeout=10.0)
        if thread.is_alive():
            # The thread keeps its own reference to the (already set) stop
            # event and exits at its next loop check; warn so hangs are visible.
            print(f"[{camera_id}] 推理线程未在超时内退出")
        del self.camera_threads[camera_id]
        self.camera_stop_events.pop(camera_id, None)
        self.stream_manager.remove_stream(str(camera_id))
        self.roi_filter.clear_cache(camera_id)
        # NOTE(review): register_algorithm() is keyed by roi_id, but this
        # passes the camera id; confirm what AlgorithmManager.remove_roi expects.
        self.algo_manager.remove_roi(str(camera_id))
        self.camera_latest_frames.pop(camera_id, None)
        self.camera_frame_times.pop(camera_id, None)
        self.camera_process_counts.pop(camera_id, None)
        self._update_camera_status(camera_id, is_running=False)

    def _update_camera_status(
        self,
        camera_id: int,
        is_running: Optional[bool] = None,
        fps: Optional[float] = None,
        error_message: Optional[str] = None,
    ):
        """Best-effort write of a camera's runtime status to the DB.

        Failures (including failing to open a session) are logged, never raised.
        """
        try:
            db = self._get_db_session()
        except Exception as e:
            # Nothing to close: the session was never created.
            print(f"[{camera_id}] 更新状态失败: {e}")
            return
        try:
            update_camera_status(
                db,
                camera_id,
                is_running=is_running,
                fps=fps,
                error_message=error_message,
            )
        except Exception as e:
            print(f"[{camera_id}] 更新状态失败: {e}")
        finally:
            db.close()

    def _camera_inference_loop(self, camera):
        """Thread body: read frames for *camera* until its stop event is set.

        Every frame updates the latest-frame cache; only every
        ``camera.process_every_n_frames``-th frame is sent to inference.
        """
        camera_id = camera.id
        stream_key = str(camera_id)
        stop_event = self.camera_stop_events[camera_id]
        # 0/None would raise ZeroDivisionError/TypeError on the first frame and
        # kill the thread; treat them as "process every frame".
        every_n = max(1, int(camera.process_every_n_frames or 1))
        # Local counter: the shared dict entry may be removed by remove_camera
        # while this thread is still winding down.
        frame_count = 0
        while not stop_event.is_set():
            ret, frame = self.stream_manager.read(stream_key)
            if not ret or frame is None:
                # Wakes immediately if a stop is requested during the back-off.
                stop_event.wait(0.1)
                continue
            if stop_event.is_set():
                break
            self.camera_latest_frames[camera_id] = frame
            self.camera_frame_times[camera_id] = datetime.now()
            frame_count += 1
            self.camera_process_counts[camera_id] = frame_count
            if frame_count % every_n != 0:
                continue
            try:
                self._process_frame(camera_id, frame, camera)
            except Exception as e:
                print(f"[{camera_id}] 处理帧失败: {e}")
        print(f"[{camera_id}] 推理线程已停止")

    def _process_frame(self, camera_id: int, frame: np.ndarray, camera):
        """Detect people in *frame* and run each matching ROI's rule algorithm.

        Alerts produced by the algorithms are passed to ``_handle_alert``.
        """
        roi_configs = self.roi_cache.get_rois(camera_id)
        if not roi_configs:
            # Without ROIs no detection can match a rule, so skip inference.
            return
        self.roi_filter.update_cache(camera_id, roi_configs)
        for roi_config in roi_configs:
            # NOTE(review): runs on every processed frame; assumes
            # register_algorithm is idempotent for an existing roi_id — confirm.
            self.algo_manager.register_algorithm(
                roi_config["roi_id"],
                roi_config["rule"],
                {
                    "threshold_sec": roi_config.get("threshold_sec", 360),
                    "confirm_sec": roi_config.get("confirm_sec", 30),
                    "return_sec": roi_config.get("return_sec", 5),
                },
            )
        results = self.yolo_engine(frame, verbose=False, classes=[0])
        if not results:
            return
        detections = self._extract_detections(results[0])
        filtered_detections = self.roi_filter.filter_detections(detections, roi_configs)
        # One timestamp per frame so every ROI sees the same observation time.
        now = datetime.now()
        for detection in filtered_detections:
            for roi_conf in detection.get("matched_rois", []):
                alerts = self.algo_manager.process(
                    roi_conf["roi_id"],
                    str(camera_id),
                    roi_conf["rule"],
                    [detection],
                    now,
                )
                for alert in alerts:
                    self._handle_alert(camera_id, alert, frame, roi_conf)

    @staticmethod
    def _extract_detections(result) -> List[Dict[str, Any]]:
        """Convert one YOLO result into ``{"bbox", "conf", "cls"}`` dicts.

        ``bbox`` is the box's ``xyxy`` row as a list; ``cls`` is always 0
        because inference only requests class 0. Returns [] when the result
        has no boxes.
        """
        boxes = getattr(result, "boxes", None)
        if boxes is None:
            return []
        xyxy = boxes.xyxy.cpu().numpy()
        confs = boxes.conf.cpu().numpy()
        return [
            {"bbox": box.tolist(), "conf": float(conf), "cls": 0}
            for box, conf in zip(xyxy, confs)
        ]

    def _save_snapshot(
        self,
        camera_id: int,
        roi_id: Any,
        alert_type: Any,
        alert: Dict[str, Any],
        frame: np.ndarray,
        now: datetime,
    ) -> Optional[str]:
        """Save *frame* as a JPEG if the alert carries a bbox.

        The whole frame is written (the bbox only gates whether a snapshot is
        taken). Returns the file path, or None when no snapshot was written.
        """
        bbox = alert.get("bbox")
        if bbox is None or len(bbox) < 4:
            return None
        snapshot_dir = self.config.alert.snapshot_path
        os.makedirs(snapshot_dir, exist_ok=True)
        # Microseconds keep alerts raised within the same second from
        # overwriting each other's snapshot.
        timestamp = now.strftime("%Y%m%d_%H%M%S_%f")
        filename = f"cam_{camera_id}_{roi_id}_{alert_type}_{timestamp}.jpg"
        snapshot_path = os.path.join(snapshot_dir, filename)
        ok = cv2.imwrite(
            snapshot_path,
            frame,
            [cv2.IMWRITE_JPEG_QUALITY, self.config.alert.image_quality],
        )
        if not ok:
            # Don't record a path to a file that doesn't exist.
            print(f"[{camera_id}] 快照保存失败: {snapshot_path}")
            return None
        return snapshot_path

    def _handle_alert(
        self,
        camera_id: int,
        alert: Dict[str, Any],
        frame: np.ndarray,
        roi_config: Dict[str, Any],
    ):
        """Persist an alert: snapshot, DB alarm row, and an event-queue entry.

        Sets ``alert["alarm_id"]`` on success. Any failure is logged and the
        alert is dropped; nothing is raised to the caller.
        """
        try:
            now = datetime.now()
            roi_id = roi_config["roi_id"]
            alert_type = alert["alert_type"]
            confidence = alert.get("confidence", alert.get("conf", 0.0))
            snapshot_path = self._save_snapshot(
                camera_id, roi_id, alert_type, alert, frame, now
            )
            db = self._get_db_session()
            try:
                alarm = create_alarm(
                    db,
                    camera_id=camera_id,
                    event_type=alert_type,
                    confidence=confidence,
                    snapshot_path=snapshot_path,
                    roi_id=roi_id,
                )
                alert["alarm_id"] = alarm.id
            finally:
                db.close()
            self.event_queue.append({
                "camera_id": camera_id,
                "roi_id": roi_id,
                "event_type": alert_type,
                "confidence": confidence,
                "message": alert.get("message", ""),
                "snapshot_path": snapshot_path,
                "timestamp": now.isoformat(),
                "llm_checked": False,
            })
            print(f"[{camera_id}] 🚨 告警: {alert['alert_type']} - {alert.get('message', '')}")
        except Exception as e:
            print(f"[{camera_id}] 处理告警失败: {e}")

    def get_latest_frame(self, camera_id: int) -> Optional[np.ndarray]:
        """Return the most recent frame read for *camera_id*, or None."""
        return self.camera_latest_frames.get(camera_id)

    def get_camera_fps(self, camera_id: int) -> float:
        """Return the stream's reported FPS, or 0.0 if the stream is unknown."""
        stream = self.stream_manager.get_stream(str(camera_id))
        if stream:
            return stream.fps
        return 0.0

    def get_event_queue(self) -> List[Dict[str, Any]]:
        """Return a snapshot copy of the queued alert events."""
        return list(self.event_queue)

    def start(self):
        """Initialise the DB and start every enabled camera (idempotent).

        If loading cameras fails part-way, the cameras already started are
        stopped again before the exception propagates.
        """
        if self.running:
            return
        self._init_database()
        # Mark running first so stop() can clean up a partially started state.
        self.running = True
        try:
            self._load_cameras()
        except Exception:
            self.stop()
            raise

    def stop(self):
        """Stop all camera threads and streams and reset algorithm state."""
        if not self.running:
            return
        self.running = False
        # Signal every thread first so they wind down in parallel instead of
        # each remove_camera() join waiting out the others sequentially.
        for stop_event in list(self.camera_stop_events.values()):
            stop_event.set()
        for camera_id in list(self.camera_threads):
            self.remove_camera(camera_id)
        self.stream_manager.stop_all()
        self.algo_manager.reset_all()
        print("推理pipeline已停止")

    def get_status(self) -> Dict[str, Any]:
        """Return a JSON-friendly summary of the pipeline and each camera."""
        result = {
            "running": self.running,
            "camera_count": len(self.camera_threads),
            "cameras": {},
            "event_queue_size": len(self.event_queue),
        }
        # Iterate over a copy: cameras may be added/removed concurrently.
        for cid in list(self.camera_threads):
            stop_event = self.camera_stop_events.get(cid)
            frame_time = self.camera_frame_times.get(cid)
            result["cameras"][str(cid)] = {
                "is_running": stop_event is not None and not stop_event.is_set(),
                "fps": self.get_camera_fps(cid),
                "last_check_time": frame_time.isoformat() if frame_time else None,
            }
        return result
# Process-wide singleton, created lazily by get_pipeline().
_pipeline: Optional[InferencePipeline] = None


def get_pipeline() -> InferencePipeline:
    """Return the shared pipeline, constructing it on the first call.

    The pipeline is only constructed here, not started; use start_pipeline()
    to start it.
    """
    global _pipeline
    instance = _pipeline
    if instance is None:
        instance = InferencePipeline()
        _pipeline = instance
    return instance
def start_pipeline():
    """Start the shared pipeline (a no-op if already running) and return it."""
    shared = get_pipeline()
    shared.start()
    return shared
def stop_pipeline():
    """Stop and discard the shared pipeline, if one exists.

    The singleton is cleared only after ``stop()`` returns, so a stop that
    raises leaves it in place.
    """
    global _pipeline
    if _pipeline is None:
        return
    _pipeline.stop()
    _pipeline = None