ROI选区01
This commit is contained in:
376
inference/pipeline.py
Normal file
376
inference/pipeline.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import get_config
|
||||
from db.crud import (
|
||||
create_alarm,
|
||||
get_all_rois,
|
||||
update_camera_status,
|
||||
)
|
||||
from db.models import init_db
|
||||
from inference.engine import YOLOEngine
|
||||
from inference.roi.roi_filter import ROIFilter
|
||||
from inference.rules.algorithms import AlgorithmManager
|
||||
from inference.stream import StreamManager
|
||||
|
||||
|
||||
class InferencePipeline:
|
||||
def __init__(self):
|
||||
self.config = get_config()
|
||||
|
||||
self.db_initialized = False
|
||||
|
||||
self.yolo_engine = YOLOEngine(use_trt=True)
|
||||
self.stream_manager = StreamManager(
|
||||
buffer_size=self.config.stream.buffer_size,
|
||||
reconnect_delay=self.config.stream.reconnect_delay,
|
||||
)
|
||||
self.roi_filter = ROIFilter()
|
||||
self.algo_manager = AlgorithmManager(working_hours=[
|
||||
{
|
||||
"start": [wh.start[0], wh.start[1]],
|
||||
"end": [wh.end[0], wh.end[1]],
|
||||
}
|
||||
for wh in self.config.working_hours
|
||||
])
|
||||
|
||||
self.camera_threads: Dict[int, threading.Thread] = {}
|
||||
self.camera_stop_events: Dict[int, threading.Event] = {}
|
||||
self.camera_latest_frames: Dict[int, Any] = {}
|
||||
self.camera_frame_times: Dict[int, datetime] = {}
|
||||
self.camera_process_counts: Dict[int, int] = {}
|
||||
|
||||
self.event_queue: deque = deque(maxlen=self.config.inference.event_queue_maxlen)
|
||||
|
||||
self.running = False
|
||||
|
||||
def _init_database(self):
|
||||
if not self.db_initialized:
|
||||
init_db()
|
||||
self.db_initialized = True
|
||||
|
||||
def _get_db_session(self):
|
||||
from db.models import get_session_factory
|
||||
SessionLocal = get_session_factory()
|
||||
return SessionLocal()
|
||||
|
||||
def _load_cameras(self):
|
||||
db = self._get_db_session()
|
||||
try:
|
||||
from db.crud import get_all_cameras
|
||||
cameras = get_all_cameras(db, enabled_only=True)
|
||||
for camera in cameras:
|
||||
self.add_camera(camera)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def add_camera(self, camera) -> bool:
|
||||
camera_id = camera.id
|
||||
if camera_id in self.camera_threads:
|
||||
return False
|
||||
|
||||
self.camera_stop_events[camera_id] = threading.Event()
|
||||
self.camera_process_counts[camera_id] = 0
|
||||
|
||||
self.stream_manager.add_stream(
|
||||
str(camera_id),
|
||||
camera.rtsp_url,
|
||||
self.config.stream.buffer_size,
|
||||
)
|
||||
|
||||
thread = threading.Thread(
|
||||
target=self._camera_inference_loop,
|
||||
args=(camera,),
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
self.camera_threads[camera_id] = thread
|
||||
|
||||
self._update_camera_status(camera_id, is_running=True)
|
||||
return True
|
||||
|
||||
def remove_camera(self, camera_id: int):
|
||||
if camera_id not in self.camera_threads:
|
||||
return
|
||||
|
||||
self.camera_stop_events[camera_id].set()
|
||||
self.camera_threads[camera_id].join(timeout=10.0)
|
||||
|
||||
del self.camera_threads[camera_id]
|
||||
del self.camera_stop_events[camera_id]
|
||||
|
||||
self.stream_manager.remove_stream(str(camera_id))
|
||||
self.roi_filter.clear_cache(camera_id)
|
||||
self.algo_manager.remove_roi(str(camera_id))
|
||||
|
||||
if camera_id in self.camera_latest_frames:
|
||||
del self.camera_latest_frames[camera_id]
|
||||
if camera_id in self.camera_frame_times:
|
||||
del self.camera_frame_times[camera_id]
|
||||
if camera_id in self.camera_process_counts:
|
||||
del self.camera_process_counts[camera_id]
|
||||
|
||||
self._update_camera_status(camera_id, is_running=False)
|
||||
|
||||
def _update_camera_status(
|
||||
self,
|
||||
camera_id: int,
|
||||
is_running: Optional[bool] = None,
|
||||
fps: Optional[float] = None,
|
||||
error_message: Optional[str] = None,
|
||||
):
|
||||
try:
|
||||
db = self._get_db_session()
|
||||
update_camera_status(
|
||||
db,
|
||||
camera_id,
|
||||
is_running=is_running,
|
||||
fps=fps,
|
||||
error_message=error_message,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{camera_id}] 更新状态失败: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _camera_inference_loop(self, camera):
|
||||
camera_id = camera.id
|
||||
stop_event = self.camera_stop_events[camera_id]
|
||||
|
||||
while not stop_event.is_set():
|
||||
ret, frame = self.stream_manager.read(str(camera_id))
|
||||
if not ret or frame is None:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
self.camera_latest_frames[camera_id] = frame
|
||||
self.camera_frame_times[camera_id] = datetime.now()
|
||||
|
||||
self.camera_process_counts[camera_id] += 1
|
||||
|
||||
if self.camera_process_counts[camera_id] % camera.process_every_n_frames != 0:
|
||||
continue
|
||||
|
||||
try:
|
||||
self._process_frame(camera_id, frame, camera)
|
||||
except Exception as e:
|
||||
print(f"[{camera_id}] 处理帧失败: {e}")
|
||||
|
||||
print(f"[{camera_id}] 推理线程已停止")
|
||||
|
||||
def _process_frame(self, camera_id: int, frame: np.ndarray, camera):
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
db = self._get_db_session()
|
||||
try:
|
||||
rois = get_all_rois(db, camera_id)
|
||||
roi_configs = [
|
||||
{
|
||||
"id": roi.id,
|
||||
"roi_id": roi.roi_id,
|
||||
"type": roi.roi_type,
|
||||
"points": json.loads(roi.points),
|
||||
"rule": roi.rule_type,
|
||||
"direction": roi.direction,
|
||||
"enabled": roi.enabled,
|
||||
"threshold_sec": roi.threshold_sec,
|
||||
"confirm_sec": roi.confirm_sec,
|
||||
"return_sec": roi.return_sec,
|
||||
}
|
||||
for roi in rois
|
||||
]
|
||||
|
||||
if roi_configs:
|
||||
self.roi_filter.update_cache(camera_id, roi_configs)
|
||||
|
||||
for roi_config in roi_configs:
|
||||
roi_id = roi_config["roi_id"]
|
||||
rule_type = roi_config["rule"]
|
||||
|
||||
self.algo_manager.register_algorithm(
|
||||
roi_id,
|
||||
rule_type,
|
||||
{
|
||||
"threshold_sec": roi_config.get("threshold_sec", 360),
|
||||
"confirm_sec": roi_config.get("confirm_sec", 30),
|
||||
"return_sec": roi_config.get("return_sec", 5),
|
||||
},
|
||||
)
|
||||
|
||||
results = self.yolo_engine(frame, verbose=False, classes=[0])
|
||||
|
||||
if not results:
|
||||
return
|
||||
|
||||
result = results[0]
|
||||
detections = []
|
||||
if hasattr(result, "boxes") and result.boxes is not None:
|
||||
boxes = result.boxes.xyxy.cpu().numpy()
|
||||
confs = result.boxes.conf.cpu().numpy()
|
||||
for i, box in enumerate(boxes):
|
||||
detections.append({
|
||||
"bbox": box.tolist(),
|
||||
"conf": float(confs[i]),
|
||||
"cls": 0,
|
||||
})
|
||||
|
||||
if roi_configs:
|
||||
filtered_detections = self.roi_filter.filter_detections(
|
||||
detections, roi_configs
|
||||
)
|
||||
else:
|
||||
filtered_detections = detections
|
||||
|
||||
for detection in filtered_detections:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
for roi_conf in matched_rois:
|
||||
roi_id = roi_conf["roi_id"]
|
||||
rule_type = roi_conf["rule"]
|
||||
|
||||
alerts = self.algo_manager.process(
|
||||
roi_id,
|
||||
str(camera_id),
|
||||
rule_type,
|
||||
[detection],
|
||||
datetime.now(),
|
||||
)
|
||||
|
||||
for alert in alerts:
|
||||
self._handle_alert(camera_id, alert, frame, roi_conf)
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _handle_alert(
|
||||
self,
|
||||
camera_id: int,
|
||||
alert: Dict[str, Any],
|
||||
frame: np.ndarray,
|
||||
roi_config: Dict[str, Any],
|
||||
):
|
||||
try:
|
||||
snapshot_path = None
|
||||
bbox = alert.get("bbox", [])
|
||||
if bbox and len(bbox) >= 4:
|
||||
x1, y1, x2, y2 = [int(v) for v in bbox]
|
||||
x1, y1 = max(0, x1), max(0, y1)
|
||||
x2, y2 = min(frame.shape[1], x2), min(frame.shape[0], y2)
|
||||
|
||||
snapshot_dir = self.config.alert.snapshot_path
|
||||
os.makedirs(snapshot_dir, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"cam_{camera_id}_{roi_config['roi_id']}_{alert['alert_type']}_{timestamp}.jpg"
|
||||
snapshot_path = os.path.join(snapshot_dir, filename)
|
||||
|
||||
cv2.imwrite(snapshot_path, frame, [cv2.IMWRITE_JPEG_QUALITY, self.config.alert.image_quality])
|
||||
|
||||
db = self._get_db_session()
|
||||
try:
|
||||
alarm = create_alarm(
|
||||
db,
|
||||
camera_id=camera_id,
|
||||
event_type=alert["alert_type"],
|
||||
confidence=alert.get("confidence", alert.get("conf", 0.0)),
|
||||
snapshot_path=snapshot_path,
|
||||
roi_id=roi_config["roi_id"],
|
||||
)
|
||||
alert["alarm_id"] = alarm.id
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
self.event_queue.append({
|
||||
"camera_id": camera_id,
|
||||
"roi_id": roi_config["roi_id"],
|
||||
"event_type": alert["alert_type"],
|
||||
"confidence": alert.get("confidence", alert.get("conf", 0.0)),
|
||||
"message": alert.get("message", ""),
|
||||
"snapshot_path": snapshot_path,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"llm_checked": False,
|
||||
})
|
||||
|
||||
print(f"[{camera_id}] 🚨 告警: {alert['alert_type']} - {alert.get('message', '')}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{camera_id}] 处理告警失败: {e}")
|
||||
|
||||
def get_latest_frame(self, camera_id: int) -> Optional[np.ndarray]:
|
||||
return self.camera_latest_frames.get(camera_id)
|
||||
|
||||
def get_camera_fps(self, camera_id: int) -> float:
|
||||
stream = self.stream_manager.get_stream(str(camera_id))
|
||||
if stream:
|
||||
return stream.fps
|
||||
return 0.0
|
||||
|
||||
def get_event_queue(self) -> List[Dict[str, Any]]:
|
||||
return list(self.event_queue)
|
||||
|
||||
def start(self):
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self._init_database()
|
||||
self._load_cameras()
|
||||
self.running = True
|
||||
|
||||
def stop(self):
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
|
||||
for camera_id in list(self.camera_threads.keys()):
|
||||
self.remove_camera(camera_id)
|
||||
|
||||
self.stream_manager.stop_all()
|
||||
self.algo_manager.reset_all()
|
||||
|
||||
print("推理pipeline已停止")
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"running": self.running,
|
||||
"camera_count": len(self.camera_threads),
|
||||
"cameras": {
|
||||
cid: {
|
||||
"running": self.camera_stop_events[cid] is not None and not self.camera_stop_events[cid].is_set(),
|
||||
"fps": self.get_camera_fps(cid),
|
||||
"frame_time": self.camera_frame_times.get(cid).isoformat() if self.camera_frame_times.get(cid) else None,
|
||||
}
|
||||
for cid in self.camera_threads
|
||||
},
|
||||
"event_queue_size": len(self.event_queue),
|
||||
}
|
||||
|
||||
|
||||
_pipeline: Optional[InferencePipeline] = None
|
||||
|
||||
|
||||
def get_pipeline() -> InferencePipeline:
|
||||
global _pipeline
|
||||
if _pipeline is None:
|
||||
_pipeline = InferencePipeline()
|
||||
return _pipeline
|
||||
|
||||
|
||||
def start_pipeline():
|
||||
pipeline = get_pipeline()
|
||||
pipeline.start()
|
||||
return pipeline
|
||||
|
||||
|
||||
def stop_pipeline():
|
||||
global _pipeline
|
||||
if _pipeline is not None:
|
||||
_pipeline.stop()
|
||||
_pipeline = None
|
||||
Reference in New Issue
Block a user