Files
Security_AI_integrated/inference/pipeline.py

365 lines
11 KiB
Python
Raw Normal View History

2026-01-20 17:42:18 +08:00
import asyncio
import json
import os
import threading
import time
from collections import deque
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
import cv2
2026-01-20 17:42:18 +08:00
import numpy as np
from config import get_config
from db.crud import (
create_alarm,
update_camera_status,
)
2026-01-20 17:46:32 +08:00
from db.models import init_db, get_session_factory
2026-01-20 17:42:18 +08:00
from inference.engine import YOLOEngine
2026-01-20 17:46:32 +08:00
from inference.roi.cache_manager import get_roi_cache
2026-01-20 17:42:18 +08:00
from inference.roi.roi_filter import ROIFilter
from inference.rules.algorithms import AlgorithmManager
from inference.stream import StreamManager
class InferencePipeline:
    """Run per-camera inference threads and turn ROI rule hits into alarms.

    Each enabled camera gets a daemon thread that reads frames from the
    ``StreamManager``, runs the detector every ``process_every_n_frames``
    frames, filters detections by the camera's ROIs and feeds them to the
    ``AlgorithmManager``. Alerts are persisted via ``create_alarm`` and also
    appended to the bounded in-memory ``event_queue``.
    """

    def __init__(self):
        self.config = get_config()
        self.db_initialized = False
        self.yolo_engine = YOLOEngine(use_trt=True)
        self.stream_manager = StreamManager(
            buffer_size=self.config.stream.buffer_size,
            reconnect_delay=self.config.stream.reconnect_delay,
        )
        self.roi_filter = ROIFilter()
        self.roi_cache = get_roi_cache()
        # Only the first two components of each configured start/end are
        # forwarded, as plain lists.
        self.algo_manager = AlgorithmManager(working_hours=[
            {
                "start": [wh.start[0], wh.start[1]],
                "end": [wh.end[0], wh.end[1]],
            }
            for wh in self.config.working_hours
        ])
        # Per-camera bookkeeping, all keyed by the integer camera id.
        self.camera_threads: Dict[int, threading.Thread] = {}
        self.camera_stop_events: Dict[int, threading.Event] = {}
        self.camera_latest_frames: Dict[int, Any] = {}
        self.camera_frame_times: Dict[int, datetime] = {}
        self.camera_process_counts: Dict[int, int] = {}
        # Bounded deque: the oldest events are discarded once maxlen is hit.
        self.event_queue: deque = deque(maxlen=self.config.inference.event_queue_maxlen)
        self.running = False

    def _init_database(self):
        """Create the schema and start the ROI cache refresher (first call only)."""
        if self.db_initialized:
            return
        init_db()
        self.db_initialized = True
        self.roi_cache.initialize(get_session_factory(), refresh_interval=10.0)
        self.roi_cache.start_background_refresh()

    def _get_db_session(self):
        """Open and return a new database session; the caller must close it."""
        from db.models import get_session_factory
        SessionLocal = get_session_factory()
        return SessionLocal()

    def _load_cameras(self):
        """Start inference for every enabled camera stored in the database.

        A camera that fails to start is logged and skipped, so one bad stream
        does not prevent the remaining cameras from being loaded.
        """
        db = self._get_db_session()
        try:
            from db.crud import get_all_cameras
            cameras = get_all_cameras(db, enabled_only=True)
            for camera in cameras:
                try:
                    self.add_camera(camera)
                except Exception as e:
                    print(f"[{camera.id}] 添加摄像头失败: {e}")
        finally:
            db.close()

    def add_camera(self, camera) -> bool:
        """Register *camera*'s stream and start its inference thread.

        Returns:
            False if a thread for this camera id already exists (nothing is
            changed), True once the new thread has been started.
        """
        camera_id = camera.id
        if camera_id in self.camera_threads:
            return False
        self.camera_stop_events[camera_id] = threading.Event()
        self.camera_process_counts[camera_id] = 0
        try:
            self.stream_manager.add_stream(
                str(camera_id),
                camera.rtsp_url,
                self.config.stream.buffer_size,
            )
        except Exception:
            # Roll back the bookkeeping so a retry starts from a clean slate.
            self.camera_stop_events.pop(camera_id, None)
            self.camera_process_counts.pop(camera_id, None)
            raise
        thread = threading.Thread(
            target=self._camera_inference_loop,
            args=(camera,),
            name=f"camera-{camera_id}",
            daemon=True,
        )
        thread.start()
        self.camera_threads[camera_id] = thread
        self._update_camera_status(camera_id, is_running=True)
        return True

    def remove_camera(self, camera_id: int):
        """Stop *camera_id*'s inference thread and drop all of its state."""
        thread = self.camera_threads.get(camera_id)
        if thread is None:
            return
        self.camera_stop_events[camera_id].set()
        thread.join(timeout=10.0)
        if thread.is_alive():
            # The loop re-checks its stop event after every read, so it will
            # exit without touching the per-camera dicts once read() returns.
            print(f"[{camera_id}] 推理线程未在超时内退出")
        del self.camera_threads[camera_id]
        self.camera_stop_events.pop(camera_id, None)
        self.stream_manager.remove_stream(str(camera_id))
        self.roi_filter.clear_cache(camera_id)
        # NOTE(review): algorithms are registered per roi_id in _process_frame,
        # but removal here uses the camera id string - confirm this is the key
        # AlgorithmManager.remove_roi expects.
        self.algo_manager.remove_roi(str(camera_id))
        self.camera_latest_frames.pop(camera_id, None)
        self.camera_frame_times.pop(camera_id, None)
        self.camera_process_counts.pop(camera_id, None)
        self._update_camera_status(camera_id, is_running=False)

    def _update_camera_status(
        self,
        camera_id: int,
        is_running: Optional[bool] = None,
        fps: Optional[float] = None,
        error_message: Optional[str] = None,
    ):
        """Best-effort write of a camera's runtime status; errors are only logged."""
        db = None
        try:
            db = self._get_db_session()
            update_camera_status(
                db,
                camera_id,
                is_running=is_running,
                fps=fps,
                error_message=error_message,
            )
        except Exception as e:
            print(f"[{camera_id}] 更新状态失败: {e}")
        finally:
            # The session may never have been opened if _get_db_session raised.
            if db is not None:
                db.close()

    def _camera_inference_loop(self, camera):
        """Thread body: read frames for *camera* until its stop event is set.

        Every frame is published as the camera's latest frame; only every
        ``process_every_n_frames``-th frame is passed to ``_process_frame``.
        Exceptions raised while processing a frame are logged and the loop
        keeps running.
        """
        camera_id = camera.id
        stop_event = self.camera_stop_events[camera_id]
        # A missing, zero or negative setting would otherwise raise
        # ZeroDivisionError (or misbehave) in the modulo below and kill the
        # thread; treat it as "process every frame".
        every_n = max(1, int(getattr(camera, "process_every_n_frames", 1) or 1))
        frame_count = self.camera_process_counts.get(camera_id, 0)
        while not stop_event.is_set():
            ret, frame = self.stream_manager.read(str(camera_id))
            if not ret or frame is None:
                time.sleep(0.1)
                continue
            # read() can block; if remove_camera() ran meanwhile, do not
            # repopulate the per-camera dicts it has just cleared.
            if stop_event.is_set():
                break
            frame_count += 1
            self.camera_latest_frames[camera_id] = frame
            self.camera_frame_times[camera_id] = datetime.now()
            self.camera_process_counts[camera_id] = frame_count
            if frame_count % every_n != 0:
                continue
            try:
                self._process_frame(camera_id, frame, camera)
            except Exception as e:
                print(f"[{camera_id}] 处理帧失败: {e}")
        print(f"[{camera_id}] 推理线程已停止")

    def _run_detector(self, frame: np.ndarray) -> List[Dict[str, Any]]:
        """Run the detector restricted to class 0 and return bbox/conf dicts."""
        results = self.yolo_engine(frame, verbose=False, classes=[0])
        if not results:
            return []
        boxes = getattr(results[0], "boxes", None)
        if boxes is None:
            return []
        xyxy = boxes.xyxy.cpu().numpy()
        confs = boxes.conf.cpu().numpy()
        return [
            {"bbox": box.tolist(), "conf": float(conf), "cls": 0}
            for box, conf in zip(xyxy, confs)
        ]

    def _process_frame(self, camera_id: int, frame: np.ndarray, camera):
        """Detect objects in *frame* and evaluate every ROI rule of the camera.

        Each ROI's algorithm is invoked exactly once per processed frame with
        the list of detections matched to that ROI. The list may be empty, so
        rules that depend on absence still advance. Returned alerts are handed
        to ``_handle_alert``.
        """
        roi_configs = self.roi_cache.get_rois(camera_id)
        if not roi_configs:
            # No ROI means no rule can fire; skip the detector entirely.
            return
        self.roi_filter.update_cache(camera_id, roi_configs)
        for roi_config in roi_configs:
            self.algo_manager.register_algorithm(
                roi_config["roi_id"],
                roi_config["rule"],
                {
                    "threshold_sec": roi_config.get("threshold_sec", 360),
                    "confirm_sec": roi_config.get("confirm_sec", 30),
                    "return_sec": roi_config.get("return_sec", 5),
                },
            )

        detections = self._run_detector(frame)
        filtered_detections = self.roi_filter.filter_detections(detections, roi_configs)

        # Bucket detections by the ROI ids they were matched to.
        per_roi: Dict[Any, List[Dict[str, Any]]] = {
            roi_config["roi_id"]: [] for roi_config in roi_configs
        }
        for detection in filtered_detections:
            for matched in detection.get("matched_rois", []):
                bucket = per_roi.get(matched["roi_id"])
                if bucket is not None:
                    bucket.append(detection)

        now = datetime.now()  # one timestamp shared by all ROIs of this frame
        for roi_config in roi_configs:
            roi_id = roi_config["roi_id"]
            alerts = self.algo_manager.process(
                roi_id,
                str(camera_id),
                roi_config["rule"],
                per_roi[roi_id],
                now,
            )
            for alert in alerts or []:
                self._handle_alert(camera_id, alert, frame, roi_config)

    def _save_snapshot(
        self,
        camera_id: int,
        roi_id: Any,
        event_type: Any,
        bbox: Any,
        frame: np.ndarray,
        now: datetime,
    ) -> Optional[str]:
        """Write a JPEG of *frame* with *bbox* drawn; return its path or None.

        Returns None when *bbox* has fewer than four values or the image
        could not be written.
        """
        if bbox is None or len(bbox) < 4:
            return None
        height, width = frame.shape[:2]
        # Only the first four values are used (assumed x1, y1, x2, y2 in pixels).
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width, x2), min(height, y2)
        # Draw on a copy: the same array is published as the camera's latest frame.
        annotated = frame.copy()
        if x2 > x1 and y2 > y1:
            cv2.rectangle(annotated, (x1, y1), (x2, y2), (0, 0, 255), 2)
        snapshot_dir = self.config.alert.snapshot_path
        os.makedirs(snapshot_dir, exist_ok=True)
        # Microseconds avoid overwriting snapshots taken within the same second.
        timestamp = now.strftime("%Y%m%d_%H%M%S_%f")
        filename = f"cam_{camera_id}_{roi_id}_{event_type}_{timestamp}.jpg"
        snapshot_path = os.path.join(snapshot_dir, filename)
        written = cv2.imwrite(
            snapshot_path,
            annotated,
            [cv2.IMWRITE_JPEG_QUALITY, self.config.alert.image_quality],
        )
        if not written:
            print(f"[{camera_id}] 保存快照失败: {snapshot_path}")
            return None
        return snapshot_path

    def _handle_alert(
        self,
        camera_id: int,
        alert: Dict[str, Any],
        frame: np.ndarray,
        roi_config: Dict[str, Any],
    ):
        """Persist *alert* as an alarm and push a matching event onto event_queue.

        A snapshot failure does not prevent the alarm from being recorded; it
        is stored with ``snapshot_path=None`` instead. Any other failure is
        logged and swallowed so it cannot stop the inference thread.
        """
        try:
            now = datetime.now()
            roi_id = roi_config["roi_id"]
            event_type = alert["alert_type"]
            confidence = alert.get("confidence", alert.get("conf", 0.0))
            try:
                snapshot_path = self._save_snapshot(
                    camera_id, roi_id, event_type, alert.get("bbox"), frame, now
                )
            except Exception as e:
                print(f"[{camera_id}] 保存快照失败: {e}")
                snapshot_path = None
            db = self._get_db_session()
            try:
                alarm = create_alarm(
                    db,
                    camera_id=camera_id,
                    event_type=event_type,
                    confidence=confidence,
                    snapshot_path=snapshot_path,
                    roi_id=roi_id,
                )
                alert["alarm_id"] = alarm.id
            finally:
                db.close()
            self.event_queue.append({
                "camera_id": camera_id,
                "roi_id": roi_id,
                "event_type": event_type,
                "confidence": confidence,
                "message": alert.get("message", ""),
                "snapshot_path": snapshot_path,
                "timestamp": now.isoformat(),
                "llm_checked": False,
                "alarm_id": alert.get("alarm_id"),
            })
            print(f"[{camera_id}] 🚨 告警: {event_type} - {alert.get('message', '')}")
        except Exception as e:
            print(f"[{camera_id}] 处理告警失败: {e}")

    def get_latest_frame(self, camera_id: int) -> Optional[np.ndarray]:
        """Return the most recently read frame for *camera_id*, or None."""
        return self.camera_latest_frames.get(camera_id)

    def get_camera_fps(self, camera_id: int) -> float:
        """Return the stream's reported fps, or 0.0 if the stream is unknown."""
        stream = self.stream_manager.get_stream(str(camera_id))
        if stream:
            return stream.fps
        return 0.0

    def get_event_queue(self) -> List[Dict[str, Any]]:
        """Return a list copy of the queued events (oldest first)."""
        return list(self.event_queue)

    def start(self):
        """Initialise the database and ROI cache, then start all enabled cameras.

        Does nothing if the pipeline is already running.
        """
        if self.running:
            return
        self._init_database()
        self._load_cameras()
        self.running = True

    def stop(self):
        """Stop all camera threads and streams and reset algorithm state."""
        if not self.running:
            return
        self.running = False
        for camera_id in list(self.camera_threads):
            self.remove_camera(camera_id)
        self.stream_manager.stop_all()
        self.algo_manager.reset_all()
        # NOTE(review): the ROI cache background refresh started in
        # _init_database() is not stopped here - confirm whether it should be.
        print("推理pipeline已停止")

    def get_status(self) -> Dict[str, Any]:
        """Return a JSON-friendly snapshot of the pipeline and per-camera state."""
        cameras: Dict[str, Dict[str, Any]] = {}
        # Copy the items: add_camera/remove_camera may mutate the dict concurrently.
        for cid, thread in list(self.camera_threads.items()):
            stop_event = self.camera_stop_events.get(cid)
            frame_time = self.camera_frame_times.get(cid)
            cameras[str(cid)] = {
                # A thread that died (e.g. on an unexpected exception) is not running.
                "is_running": (
                    stop_event is not None
                    and not stop_event.is_set()
                    and thread.is_alive()
                ),
                "fps": self.get_camera_fps(cid),
                "last_check_time": frame_time.isoformat() if frame_time else None,
            }
        return {
            "running": self.running,
            "camera_count": len(cameras),
            "cameras": cameras,
            "event_queue_size": len(self.event_queue),
        }
2026-01-20 17:42:18 +08:00
# Module-level pipeline instance: created lazily by get_pipeline(), reset to
# None by stop_pipeline().
_pipeline: Optional[InferencePipeline] = None
def get_pipeline() -> InferencePipeline:
    """Return the module-level pipeline, constructing it on first access."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline
    _pipeline = InferencePipeline()
    return _pipeline
def start_pipeline():
    """Obtain the shared pipeline, start it, and return it to the caller."""
    shared = get_pipeline()
    shared.start()
    return shared
def stop_pipeline():
    """Stop the shared pipeline, if one exists, and discard the reference."""
    global _pipeline
    if _pipeline is None:
        return
    _pipeline.stop()
    _pipeline = None