commit ad6282c6d523dc08bc6c7933b28fe5238ae41737 Author: 16337 <1633794139@qq.com> Date: Mon Jan 12 17:38:39 2026 +0800 Initial commit diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..b6b1ecf --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,10 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 已忽略包含查询文件的默认文件夹 +/queries/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ diff --git a/.idea/Detector.iml b/.idea/Detector.iml new file mode 100644 index 0000000..0dd87bd --- /dev/null +++ b/.idea/Detector.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..710d32f --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,12 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..68ee80e --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..8fdd359 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/__pycache__/detector.cpython-310.pyc b/__pycache__/detector.cpython-310.pyc new file mode 100644 index 0000000..f3a3898 Binary files /dev/null and b/__pycache__/detector.cpython-310.pyc differ diff --git 
a/__pycache__/sort.cpython-310.pyc b/__pycache__/sort.cpython-310.pyc new file mode 100644 index 0000000..0a9760e Binary files /dev/null and b/__pycache__/sort.cpython-310.pyc differ diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..ff37923 --- /dev/null +++ b/config.yaml @@ -0,0 +1,43 @@ +model: + path: "C:/Users/16337/PycharmProjects/Security/yolo11n.pt" + imgsz: 480 + conf_threshold: 0.5 + device: "cuda" # cuda, cpu + +common: + working_hours: [9, 17] # 工作时间:9:00 ~ 17:00(24小时制) + process_every_n_frames: 3 # 每3帧处理1帧 + alert_cooldown_sec: 300 # 离岗告警冷却(秒) + crowd_cooldown_sec: 180 # 聚集告警冷却(秒) + entry_grace_period_sec: 1.0 # 入岗保护期(防漏检) + +cameras: + - id: "cam_01" + rtsp_url: "rtsp://admin:admin@172.16.8.19:554/cam/realmonitor?channel=16&subtype=1" + roi_points: [[380, 50], [530, 100], [550, 550], [140, 420]] # 离岗检测区域 + crowd_roi_points: [[220, 50], [380, 60], [180, 525], [0, 500]] # 聚集检测区域 + off_duty_threshold_sec: 300 # 离岗超时告警(秒) + on_duty_confirm_sec: 5 # 上岗确认时间(秒) + process_every_n_frames: 5 + off_duty_confirm_sec: 30 # 离岗确认时间(秒) + crowd_threshold: 5 # 聚集人数阈值(最低触发) + + + - id: "cam_02" + rtsp_url: "rtsp://admin:admin@172.16.8.13:554/cam/realmonitor?channel=7&subtype=1" + roi_points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ] # 离岗检测区域 + crowd_roi_points: [ [ 220, 50 ], [ 380, 60 ], [ 180, 525 ], [ 0, 500 ] ] # 聚集检测区域 + off_duty_threshold_sec: 600 + on_duty_confirm_sec: 10 + off_duty_confirm_sec: 20 + crowd_threshold: 3 + + + - id: "cam_03" + rtsp_url: "rtsp://admin:admin@172.16.8.26:554/cam/realmonitor?channel=3&subtype=1" + roi_points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ] # 离岗检测区域 + crowd_roi_points: [ [ 220, 50 ], [ 380, 60 ], [ 180, 525 ], [ 0, 500 ] ] # 聚集检测区域 + off_duty_threshold_sec: 600 + on_duty_confirm_sec: 10 + off_duty_confirm_sec: 20 + crowd_threshold: 3 \ No newline at end of file diff --git a/detector.py b/detector.py new file mode 100644 index 0000000..f729806 --- /dev/null +++ 
# ===== detector.py =====
import cv2
import numpy as np
from ultralytics import YOLO
from sort import Sort
import time
import datetime
import threading
import queue
import torch
from collections import deque


class ThreadedFrameReader:
    """Reads an RTSP stream on a daemon thread, keeping only the newest frame."""

    def __init__(self, src, maxsize=1):
        self.cap = cv2.VideoCapture(src, cv2.CAP_FFMPEG)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
        self.q = queue.Queue(maxsize=maxsize)
        self.running = True
        self.thread = threading.Thread(target=self._reader)
        self.thread.daemon = True
        self.thread.start()

    def _reader(self):
        # Pull frames continuously; drop the queued frame so the consumer never lags.
        while self.running:
            ret, frame = self.cap.read()
            if not ret:
                time.sleep(0.1)
                continue
            if not self.q.empty():
                try:
                    self.q.get_nowait()
                except queue.Empty:
                    pass
            self.q.put(frame)

    def read(self):
        """Non-blocking read: (True, frame) when available, otherwise (False, None)."""
        if not self.q.empty():
            return True, self.q.get()
        return False, None

    def release(self):
        self.running = False
        self.cap.release()


def is_point_in_roi(x, y, roi):
    """Return True if point (x, y) is inside or on the edge of polygon ``roi``."""
    return cv2.pointPolygonTest(roi, (int(x), int(y)), False) >= 0


class OffDutyCrowdDetector:
    """Per-camera off-duty and crowding detector driven by a shared YOLO model.

    ``config`` is a merged dict: the ``common`` section plus one camera entry
    from config.yaml (see run_detectors.py).
    """

    def __init__(self, config, model, device, use_half):
        self.config = config
        self.model = model
        self.device = device
        self.use_half = use_half

        # ROI polygons (pixel coordinates)
        self.roi = np.array(config["roi_points"], dtype=np.int32)
        self.crowd_roi = np.array(config["crowd_roi_points"], dtype=np.int32)

        # Tracker state (IDs currently unused downstream, kept for future logic)
        self.tracker = Sort(
            max_age=30,
            min_hits=2,
            iou_threshold=0.3
        )

        # Duty state machine
        self.is_on_duty = False
        self.on_duty_start_time = None
        self.is_off_duty = True
        self.last_no_person_time = None
        self.off_duty_timer_start = None
        self.last_alert_time = 0

        self.last_crowd_alert_time = 0
        self.crowd_history = deque(maxlen=1500)  # auto-bounded ~5 min at ~5 fps

        self.last_person_seen_time = None
        self.frame_count = 0

        # Cached config
        self.working_start_min = config.get("working_hours", [9, 17])[0] * 60
        self.working_end_min = config.get("working_hours", [9, 17])[1] * 60
        self.process_every = config.get("process_every_n_frames", 3)
        # BUG FIX: the entry grace period was hard-coded to 1.0 s even though
        # config.yaml declares `entry_grace_period_sec`; honor it (same default).
        self.grace_period = config.get("entry_grace_period_sec", 1.0)

    def in_working_hours(self):
        """True when the current wall-clock time falls inside working hours (inclusive)."""
        now = datetime.datetime.now()
        total_min = now.hour * 60 + now.minute
        return self.working_start_min <= total_min <= self.working_end_min

    def count_people_in_roi(self, boxes, roi):
        """Count detection boxes whose centre lies inside polygon ``roi``."""
        count = 0
        for box in boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            if is_point_in_roi(cx, cy, roi):
                count += 1
        return count

    def run(self):
        """Main loop (intended to run on its own thread): read, detect, alert, display."""
        frame_reader = ThreadedFrameReader(self.config["rtsp_url"])
        try:
            while True:
                ret, frame = frame_reader.read()
                if not ret:
                    time.sleep(0.01)
                    continue

                self.frame_count += 1
                if self.frame_count % self.process_every != 0:
                    continue

                current_time = time.time()

                # YOLO inference (person class only)
                results = self.model(
                    frame,
                    imgsz=self.config.get("imgsz", 480),
                    conf=self.config.get("conf_thresh", 0.4),
                    verbose=False,
                    device=self.device,
                    half=self.use_half,
                    classes=[0]  # person class
                )
                boxes = results[0].boxes

                # Feed detections to the SORT tracker (optional ID tracking)
                dets = []
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf)
                    dets.append([x1, y1, x2, y2, conf])
                dets = np.array(dets) if dets else np.empty((0, 5))
                self.tracker.update(dets)

                # === Off-duty detection ===
                if self.in_working_hours():
                    roi_has_person = self.count_people_in_roi(boxes, self.roi) > 0
                    if roi_has_person:
                        self.last_person_seen_time = current_time

                    # Entry grace period (tolerates a briefly missed detection)
                    effective_on_duty = (
                        self.last_person_seen_time is not None and
                        (current_time - self.last_person_seen_time) < self.grace_period
                    )

                    if effective_on_duty:
                        self.last_no_person_time = None
                        if self.is_off_duty:
                            if self.on_duty_start_time is None:
                                self.on_duty_start_time = current_time
                            # BUG FIX: config.yaml / run_detectors.py supply
                            # `on_duty_confirm_sec` — the old key `on_duty_confirm`
                            # never matched, so the configured value was ignored.
                            elif current_time - self.on_duty_start_time >= self.config.get("on_duty_confirm_sec", 5):
                                self.is_on_duty = True
                                self.is_off_duty = False
                                self.on_duty_start_time = None
                                print(f"[{self.config['id']}] ✅ 上岗确认")
                    else:
                        self.on_duty_start_time = None
                        self.last_person_seen_time = None
                        if not self.is_off_duty:
                            if self.last_no_person_time is None:
                                self.last_no_person_time = current_time
                            # BUG FIX: key renamed to match config (`off_duty_confirm_sec`).
                            elif current_time - self.last_no_person_time >= self.config.get("off_duty_confirm_sec", 30):
                                self.is_off_duty = True
                                self.is_on_duty = False
                                self.off_duty_timer_start = current_time
                                print(f"[{self.config['id']}] ⏳ 开始离岗计时")

                    # Off-duty alert
                    if self.is_off_duty and self.off_duty_timer_start:
                        elapsed = current_time - self.off_duty_timer_start
                        # BUG FIX: keys renamed to match config
                        # (`off_duty_threshold_sec`, `alert_cooldown_sec`).
                        if elapsed >= self.config.get("off_duty_threshold_sec", 300):
                            if current_time - self.last_alert_time >= self.config.get("alert_cooldown_sec", 300):
                                print(f"[{self.config['id']}] 🚨 离岗告警!已离岗 {elapsed/60:.1f} 分钟")
                                self.last_alert_time = current_time

                # === Crowd detection ===
                crowd_count = self.count_people_in_roi(boxes, self.crowd_roi)
                self.crowd_history.append((current_time, crowd_count))

                # Dynamic duration threshold: larger crowds alert sooner.
                if crowd_count >= 10:
                    req_dur = 60
                elif crowd_count >= 7:
                    req_dur = 120
                elif crowd_count >= 5:
                    req_dur = 300
                else:
                    req_dur = float('inf')

                if req_dur < float('inf'):
                    recent = [(t, c) for t, c in self.crowd_history if current_time - t <= req_dur]
                    if recent:
                        valid = [c for t, c in recent if c >= 4]
                        ratio = len(valid) / len(recent)
                        # BUG FIX: key renamed to match config (`crowd_cooldown_sec`).
                        if ratio >= 0.9 and (current_time - self.last_crowd_alert_time) >= self.config.get("crowd_cooldown_sec", 180):
                            print(f"[{self.config['id']}] 🚨 聚集告警!{crowd_count}人持续{req_dur//60}分钟")
                            self.last_crowd_alert_time = current_time

                # Visualization (set `show_window: false` in config for headless deploys;
                # defaults to the previous always-on behavior)
                if self.config.get("show_window", True):
                    vis = results[0].plot()
                    overlay = vis.copy()
                    cv2.fillPoly(overlay, [self.roi], (0, 255, 0))
                    cv2.fillPoly(overlay, [self.crowd_roi], (0, 0, 255))
                    cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis)
                    cv2.imshow(f"Monitor - {self.config['id']}", vis)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break

        except Exception as e:
            print(f"[{self.config['id']}] Error: {e}")
        finally:
            frame_reader.release()
            cv2.destroyAllWindows()
# ===== main.py (PyCharm starter scaffold) =====
def print_hi(name):
    # Placeholder generated by PyCharm; kept verbatim from the initial commit.
    print(f'Hi, {name}')


if __name__ == '__main__':
    print_hi('PyCharm')


# ===== monitor.py =====
import cv2
import numpy as np
import yaml
import torch
from ultralytics import YOLO
import time
import datetime
import threading
import queue
import sys
import argparse
from collections import defaultdict


class ThreadedFrameReader:
    """Daemon-thread RTSP reader; keeps at most two queued frames, dropping stale ones."""

    def __init__(self, cam_id, rtsp_url):
        self.cam_id = cam_id
        self.rtsp_url = rtsp_url
        self.cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
        self.q = queue.Queue(maxsize=2)
        self.running = True
        self.thread = threading.Thread(target=self._reader, daemon=True)
        self.thread.start()

    def _reader(self):
        while self.running:
            ok, img = self.cap.read()
            if not ok:
                # Stream hiccup: back off briefly and retry.
                time.sleep(0.1)
                continue
            if self.q.full():
                try:
                    self.q.get_nowait()  # discard the oldest frame
                except queue.Empty:
                    pass
            self.q.put(img)

    def read(self):
        """Non-blocking read: (True, frame) when available, otherwise (False, None)."""
        if self.q.empty():
            return False, None
        return True, self.q.get()

    def release(self):
        self.running = False
        self.cap.release()


class MultiCameraMonitor:
    """Owns the shared YOLO model, one frame reader per camera, and the display loop."""

    def __init__(self, config_path):
        with open(config_path, 'r', encoding='utf-8') as fh:
            self.cfg = yaml.safe_load(fh)

        # --- shared model, loaded exactly once ---
        model_cfg = self.cfg['model']
        self.device = model_cfg.get('device', 'auto')
        if self.device == 'auto' or not self.device:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"🚀 全局加载模型到 {self.device}...")
        self.model = YOLO(model_cfg['path'])
        self.model.to(self.device)
        self.use_half = (self.device == 'cuda')
        if self.use_half:
            print("✅ 启用 FP16 推理")

        self.imgsz = model_cfg['imgsz']
        self.conf_thresh = model_cfg['conf_threshold']

        # --- per-camera state ---
        self.common = self.cfg['common']
        self.cameras = {}
        self.frame_readers = {}
        self.queues = {}  # cam_id -> queue of (frame, result) pairs

        for cam_cfg in self.cfg['cameras']:
            cam_id = cam_cfg['id']
            self.cameras[cam_id] = CameraLogic(cam_id, cam_cfg, self.common)
            self.frame_readers[cam_id] = ThreadedFrameReader(cam_id, cam_cfg['rtsp_url'])
            self.queues[cam_id] = queue.Queue(maxsize=1)

        # --- single shared inference thread ---
        self.running = True
        self.inference_thread = threading.Thread(target=self._inference_loop, daemon=True)
        self.inference_thread.start()

    def _inference_loop(self):
        """Round-robin over cameras: grab each camera's newest frame, run inference."""
        while self.running:
            did_work = False
            for cam_id, reader in self.frame_readers.items():
                ok, frame = reader.read()
                if not ok:
                    continue

                # Respect each camera's frame-skipping cadence.
                logic = self.cameras[cam_id]
                if logic.should_skip_frame():
                    continue

                preds = self.model(
                    frame,
                    imgsz=self.imgsz,
                    conf=self.conf_thresh,
                    verbose=False,
                    device=self.device,
                    half=self.use_half,
                    classes=[0]  # person
                )

                # Hand the result back to the camera's logic queue (drop if busy).
                if not self.queues[cam_id].full():
                    self.queues[cam_id].put((frame.copy(), preds[0]))
                did_work = True

            if not did_work:
                time.sleep(0.01)

    def run(self):
        """Main-thread loop: drain result queues, update alert logic, drive the UI."""
        try:
            while self.running:
                for cam_id, logic in self.cameras.items():
                    if not self.queues[cam_id].empty():
                        frame, preds = self.queues[cam_id].get()
                        logic.process(frame, preds)
                # Pump the OpenCV event loop so every window refreshes.
                if (cv2.waitKey(1) & 0xFF) == ord('q'):  # optional manual quit
                    break
                time.sleep(0.01)  # keep CPU usage sane
        except KeyboardInterrupt:
            pass
        finally:
            self.stop()

    def stop(self):
        self.running = False
        for rdr in self.frame_readers.values():
            rdr.release()
        cv2.destroyAllWindows()
# ===== monitor.py (continued) =====
class CameraLogic:
    """Alert state machine for one camera: off-duty timing plus crowd detection."""

    def __init__(self, cam_id, cam_cfg, common_cfg):
        self.cam_id = cam_id
        # ROI polygons (pixel coordinates)
        self.roi_off_duty = np.array(cam_cfg['roi_points'], dtype=np.int32)
        self.roi_crowd = np.array(cam_cfg['crowd_roi_points'], dtype=np.int32)

        # Per-camera thresholds
        self.off_duty_threshold_sec = cam_cfg['off_duty_threshold_sec']
        self.on_duty_confirm_sec = cam_cfg['on_duty_confirm_sec']
        self.off_duty_confirm_sec = cam_cfg['off_duty_confirm_sec']
        self.crowd_threshold_min = cam_cfg['crowd_threshold']

        # Camera-level settings fall back to the common section.
        self.working_hours = common_cfg['working_hours']
        self.process_every_n = cam_cfg.get('process_every_n_frames', common_cfg['process_every_n_frames'])
        self.alert_cooldown_sec = cam_cfg.get('alert_cooldown_sec', common_cfg['alert_cooldown_sec'])
        self.crowd_cooldown_sec = cam_cfg.get('crowd_cooldown_sec', common_cfg['crowd_cooldown_sec'])

        # Mutable runtime state
        self.frame_count = 0
        self.is_on_duty = False
        self.is_off_duty = True
        self.on_duty_start_time = None
        self.last_no_person_time = None
        self.off_duty_timer_start = None
        self.last_alert_time = 0
        self.last_crowd_alert_time = 0
        self.crowd_history = []  # list of (timestamp, people_count)
        self.last_person_seen_time = None

    def should_skip_frame(self):
        """Advance the frame counter; True when this frame is not scheduled for inference."""
        self.frame_count += 1
        return self.frame_count % self.process_every_n != 0

    def is_point_in_roi(self, x, y, roi):
        """True when (x, y) falls inside or on polygon ``roi``."""
        return cv2.pointPolygonTest(roi, (int(x), int(y)), False) >= 0

    def in_working_hours(self):
        """True during configured working hours; [0, 0] means around the clock."""
        hour = datetime.datetime.now().hour
        start_h, end_h = self.working_hours
        return (start_h == end_h == 0) or start_h <= hour < end_h

    def detect_crowd_count(self, boxes, roi):
        """Count detections whose box centre lies inside ``roi``."""
        inside = 0
        for det in boxes:
            x1, y1, x2, y2 = det.xyxy[0].cpu().numpy()
            if self.is_point_in_roi((x1 + x2) / 2, (y1 + y2) / 2, roi):
                inside += 1
        return inside

    def process(self, frame, results):
        """Run one frame's worth of alert logic and draw the overlay window."""
        current_time = time.time()
        now = datetime.datetime.now()
        in_work = self.in_working_hours()
        boxes = results.boxes

        # --- raw per-frame presence inside the off-duty ROI ---
        person_in_roi = any(
            self.is_point_in_roi((b.xyxy[0][0] + b.xyxy[0][2]) / 2,
                                 (b.xyxy[0][1] + b.xyxy[0][3]) / 2,
                                 self.roi_off_duty)
            for b in boxes
        )

        if in_work:
            if person_in_roi:
                self.last_person_seen_time = current_time
            # Short grace period so a single missed detection does not flip state.
            effective = (
                self.last_person_seen_time is not None and
                (current_time - self.last_person_seen_time) < 1.0
            )

            if effective:
                self.last_no_person_time = None
                if self.is_off_duty:
                    if self.on_duty_start_time is None:
                        self.on_duty_start_time = current_time
                    elif current_time - self.on_duty_start_time >= self.on_duty_confirm_sec:
                        self.is_on_duty, self.is_off_duty = True, False
                        self.on_duty_start_time = None
                        print(f"[{self.cam_id}] ✅ 上岗确认成功 ({now.strftime('%H:%M:%S')})")
            else:
                self.on_duty_start_time = None
                self.last_person_seen_time = None
                if not self.is_off_duty:
                    if self.last_no_person_time is None:
                        self.last_no_person_time = current_time
                    elif current_time - self.last_no_person_time >= self.off_duty_confirm_sec:
                        self.is_off_duty, self.is_on_duty = True, False
                        self.last_no_person_time = None
                        self.off_duty_timer_start = current_time
                        print(f"[{self.cam_id}] 🚪 进入离岗计时")

        # --- crowd detection (history trimmed to the last 5 minutes) ---
        crowd_count = self.detect_crowd_count(boxes, self.roi_crowd)
        self.crowd_history.append((current_time, crowd_count))
        self.crowd_history = [(t, c) for t, c in self.crowd_history if current_time - t <= 300]

        if in_work and crowd_count >= self.crowd_threshold_min:
            recent_counts = [c for _, c in self.crowd_history[-10:]]
            # Debounce: need 4 of the last 5 samples above threshold.
            if len(recent_counts) >= 5 and sum(c >= self.crowd_threshold_min for c in recent_counts[-5:]) >= 4:
                if current_time - self.last_crowd_alert_time >= self.crowd_cooldown_sec:
                    print(f"[{self.cam_id}] 🚨 聚集告警:{crowd_count}人")
                    self.last_crowd_alert_time = current_time

        # --- off-duty alert ---
        if in_work and self.is_off_duty and self.off_duty_timer_start:
            if (current_time - self.off_duty_timer_start) >= self.off_duty_threshold_sec:
                if (current_time - self.last_alert_time) >= self.alert_cooldown_sec:
                    print(f"[{self.cam_id}] 🚨 离岗告警!")
                    self.last_alert_time = current_time

        # --- visualization ---
        vis = results.plot()
        overlay = vis.copy()
        cv2.fillPoly(overlay, [self.roi_off_duty], (0, 255, 0))
        cv2.fillPoly(overlay, [self.roi_crowd], (0, 0, 255))
        cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis)

        status, color = "OUT OF HOURS", (128, 128, 128)
        if in_work:
            if self.is_on_duty:
                status, color = "ON DUTY", (0, 255, 0)
            elif self.is_off_duty:
                if self.off_duty_timer_start:
                    elapsed = int(current_time - self.off_duty_timer_start)
                    if elapsed >= self.off_duty_threshold_sec:
                        status, color = "OFF DUTY!", (0, 0, 255)
                    else:
                        status = f"IDLE - {elapsed}s"
                        color = (0, 255, 255)
                else:
                    status, color = "OFF DUTY", (255, 0, 0)

        cv2.putText(vis, f"[{self.cam_id}] {status}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
        cv2.imshow(f"Monitor - {self.cam_id}", vis)


def main():
    """CLI entry point: parse arguments and run the multi-camera monitor."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config.yaml", help="配置文件路径")
    args = parser.parse_args()

    monitor = MultiCameraMonitor(args.config)
    monitor.run()


if __name__ == "__main__":
    main()
# ===== run_detectors.py =====
import yaml
import threading
from ultralytics import YOLO
import torch
from detector import OffDutyCrowdDetector
import os


def load_config(config_path="config.yaml"):
    """Read and parse the YAML configuration file."""
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def main():
    """Load config, share one YOLO model, and launch one detector thread per camera."""
    config = load_config()

    # Shared model (one copy across all camera threads)
    model_cfg = config["model"]
    model_path = model_cfg["path"]
    device = model_cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    use_half = (device == "cuda")

    print(f"Loading model {model_path} on {device} (FP16: {use_half})")
    model = YOLO(model_path)
    model.to(device)
    if use_half:
        model.model.half()

    # One detector thread per camera
    threads = []
    for cam_cfg in config["cameras"]:
        # Merge the common section with the per-camera section (camera keys win).
        full_cfg = {**config.get("common", {}), **cam_cfg}
        full_cfg["imgsz"] = model_cfg["imgsz"]
        # BUG FIX: config.yaml declares `conf_threshold`; reading `conf_thresh`
        # raised KeyError at startup. Accept either spelling, preferring the
        # documented one.
        full_cfg["conf_thresh"] = model_cfg.get("conf_threshold", model_cfg.get("conf_thresh", 0.4))
        full_cfg["working_hours"] = config["common"]["working_hours"]

        detector = OffDutyCrowdDetector(full_cfg, model, device, use_half)
        thread = threading.Thread(target=detector.run, daemon=True)
        thread.start()
        threads.append(thread)
        print(f"Started detector for {cam_cfg['id']}")

    print(f"✅ 已启动 {len(threads)} 路摄像头检测,按 Ctrl+C 退出")
    try:
        for t in threads:
            t.join()
    except KeyboardInterrupt:
        print("\n🛑 Shutting down...")


if __name__ == "__main__":
    main()


# ===== sort.py =====
import numpy as np
from scipy.optimize import linear_sum_assignment
from filterpy.kalman import KalmanFilter


def linear_assignment(cost_matrix):
    """Solve the assignment problem; return matched (row, col) index pairs."""
    x, y = linear_sum_assignment(cost_matrix)
    return np.array(list(zip(x, y)))


def iou_batch(bb_test, bb_gt):
    """Pairwise IoU between two box sets in [x1, y1, x2, y2] form (from SORT)."""
    bb_test = np.expand_dims(bb_test, 1)
    bb_gt = np.expand_dims(bb_gt, 0)
    xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
    yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
    xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
    yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
    w = np.maximum(0., xx2 - xx1)
    h = np.maximum(0., yy2 - yy1)
    wh = w * h
    o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
              + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
    return o


def convert_bbox_to_z(bbox):
    """Convert [x1, y1, x2, y2] to the Kalman measurement [x, y, s, r].

    (x, y) is the box centre, s the area, r the aspect ratio; returned as (4, 1).
    """
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    x = bbox[0] + w / 2.
    y = bbox[1] + h / 2.
    s = w * h  # scale is just area
    r = w / float(h)
    return np.array([x, y, s, r]).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
    """Convert a Kalman state box [x, y, s, r] back to corner form.

    Returns a (1, 4) array [x1, y1, x2, y2], or (1, 5) with ``score`` appended.
    """
    width = np.sqrt(x[2] * x[3])
    height = x[2] / width
    half_w = width / 2.
    half_h = height / 2.
    if score is None:
        return np.array([x[0] - half_w, x[1] - half_h,
                         x[0] + half_w, x[1] + half_h]).reshape((1, 4))
    return np.array([x[0] - half_w, x[1] - half_h,
                     x[0] + half_w, x[1] + half_h, score]).reshape((1, 5))


class KalmanBoxTracker(object):
    """Kalman-filtered state of one tracked box (constant-velocity model)."""

    count = 0  # class-wide id counter

    def __init__(self, bbox):
        """Initialise the filter from an initial [x1, y1, x2, y2] box."""
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        # Transition matrix: identity plus velocity coupling for x, y, s.
        self.kf.F = np.eye(7)
        self.kf.F[0, 4] = self.kf.F[1, 5] = self.kf.F[2, 6] = 1.0
        # Measurement maps the state directly onto [x, y, s, r].
        self.kf.H = np.eye(4, 7)

        self.kf.R[2:, 2:] *= 10.
        self.kf.P[4:, 4:] *= 1000.  # high uncertainty on unobserved initial velocities
        self.kf.P *= 10.
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[4:, 4:] *= 0.01

        self.kf.x[:4] = convert_bbox_to_z(bbox)
        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0

    def update(self, bbox):
        """Fold an observed box into the filter state."""
        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.hit_streak += 1
        self.kf.update(convert_bbox_to_z(bbox))

    def predict(self):
        """Advance the filter one step and return the predicted box estimate."""
        # Guard: never let the scale go non-positive after adding its velocity.
        if (self.kf.x[6] + self.kf.x[2]) <= 0:
            self.kf.x[6] *= 0.0
        self.kf.predict()
        self.age += 1
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        return self.history[-1]

    def get_state(self):
        """Current box estimate as [x1, y1, x2, y2]."""
        return convert_x_to_bbox(self.kf.x)


class Sort(object):
    """SORT multi-object tracker: Kalman prediction + IoU/Hungarian matching."""

    def __init__(self, max_age=5, min_hits=2, iou_threshold=0.3):
        self.max_age = max_age              # frames a track may go unseen before removal
        self.min_hits = min_hits            # consecutive hits before a track is reported
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0

    def update(self, dets=np.empty((0, 5))):
        """Advance all tracks one frame and match them against ``dets``.

        ``dets`` is an (N, 5) array of [x1, y1, x2, y2, score]. Must be called
        once per frame even with no detections. Each returned row is
        [x1, y1, x2, y2, track_id]; the row count may differ from ``len(dets)``.
        """
        self.frame_count += 1

        # Predict every existing track; note rows that went NaN.
        predicted = np.zeros((len(self.trackers), 5))
        invalid = []
        for idx, row in enumerate(predicted):
            pos = self.trackers[idx].predict()[0]
            row[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                invalid.append(idx)
        predicted = np.ma.compress_rows(np.ma.masked_invalid(predicted))
        for idx in reversed(invalid):
            self.trackers.pop(idx)

        matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(
            dets, predicted, self.iou_threshold)

        # Feed matched detections into their trackers.
        for pair in matched:
            self.trackers[pair[1]].update(dets[pair[0], :])

        # Spawn a fresh tracker for every unmatched detection.
        for det_idx in unmatched_dets:
            self.trackers.append(KalmanBoxTracker(dets[det_idx, :]))

        # Emit confirmed tracks and prune dead ones (walk backwards so pops are safe).
        output = []
        pos_idx = len(self.trackers)
        for trk in reversed(self.trackers):
            box = trk.get_state()[0]
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                # +1 because the MOT benchmark requires positive ids
                output.append(np.concatenate((box, [trk.id + 1])).reshape(1, -1))
            pos_idx -= 1
            # Remove dead tracklet.
            if trk.time_since_update > self.max_age:
                self.trackers.pop(pos_idx)
        if len(output) > 0:
            return np.concatenate(output)
        return np.empty((0, 5))
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """Match detections to predicted tracker boxes by IoU.

    Both inputs are arrays of [x1, y1, x2, y2, ...] boxes. Returns the triple
    (matches, unmatched_detections, unmatched_trackers), where ``matches`` is
    an (N, 2) array of [detection_index, tracker_index] pairs.
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 1), dtype=int)

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        hits = (iou_matrix > iou_threshold).astype(np.int32)
        if hits.sum(1).max() == 1 and hits.sum(0).max() == 1:
            # Thresholding alone already yields a one-to-one mapping:
            # skip the Hungarian solve entirely.
            matched_indices = np.stack(np.where(hits), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [d for d in range(len(detections)) if d not in matched_indices[:, 0]]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_indices[:, 1]]

    # Discard low-IoU assignments produced by the Hungarian solver.
    matches = []
    for pair in matched_indices:
        if iou_matrix[pair[0], pair[1]] < iou_threshold:
            unmatched_detections.append(pair[0])
            unmatched_trackers.append(pair[1])
        else:
            matches.append(pair.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)