Initial commit

This commit is contained in:
2026-01-12 17:38:39 +08:00
commit ad6282c6d5
15 changed files with 891 additions and 0 deletions

10
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,10 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 已忽略包含查询文件的默认文件夹
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/

8
.idea/Detector.iml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="C:\Users\16337\miniconda3\envs\yolov11" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@@ -0,0 +1,12 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyStubPackagesAdvertiser" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<list>
<option value="pandas" />
</list>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="human_identify" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="C:\Users\16337\miniconda3\envs\yolov11" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Detector.iml" filepath="$PROJECT_DIR$/.idea/Detector.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

Binary file not shown.

Binary file not shown.

43
config.yaml Normal file
View File

@@ -0,0 +1,43 @@
model:
path: "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
imgsz: 480
conf_threshold: 0.5
device: "cuda" # cuda, cpu
common:
working_hours: [9, 17] # 工作时间 9:00 ~ 17:00(24小时制)
process_every_n_frames: 3 # 每3帧处理1帧
alert_cooldown_sec: 300 # 离岗告警冷却(秒)
crowd_cooldown_sec: 180 # 聚集告警冷却(秒)
entry_grace_period_sec: 1.0 # 入岗保护期(防漏检)
cameras:
- id: "cam_01"
rtsp_url: "rtsp://admin:admin@172.16.8.19:554/cam/realmonitor?channel=16&subtype=1"
roi_points: [[380, 50], [530, 100], [550, 550], [140, 420]] # 离岗检测区域
crowd_roi_points: [[220, 50], [380, 60], [180, 525], [0, 500]] # 聚集检测区域
off_duty_threshold_sec: 300 # 离岗超时告警(秒)
on_duty_confirm_sec: 5 # 上岗确认时间(秒)
process_every_n_frames: 5
off_duty_confirm_sec: 30 # 离岗确认时间(秒)
crowd_threshold: 5 # 聚集人数阈值(最低触发)
- id: "cam_02"
rtsp_url: "rtsp://admin:admin@172.16.8.13:554/cam/realmonitor?channel=7&subtype=1"
roi_points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ] # 离岗检测区域
crowd_roi_points: [ [ 220, 50 ], [ 380, 60 ], [ 180, 525 ], [ 0, 500 ] ] # 聚集检测区域
off_duty_threshold_sec: 600
on_duty_confirm_sec: 10
off_duty_confirm_sec: 20
crowd_threshold: 3
- id: "cam_03"
rtsp_url: "rtsp://admin:admin@172.16.8.26:554/cam/realmonitor?channel=3&subtype=1"
roi_points: [ [ 380, 50 ], [ 530, 100 ], [ 550, 550 ], [ 140, 420 ] ] # 离岗检测区域
crowd_roi_points: [ [ 220, 50 ], [ 380, 60 ], [ 180, 525 ], [ 0, 500 ] ] # 聚集检测区域
off_duty_threshold_sec: 600
on_duty_confirm_sec: 10
off_duty_confirm_sec: 20
crowd_threshold: 3

218
detector.py Normal file
View File

@@ -0,0 +1,218 @@
import cv2
import numpy as np
from ultralytics import YOLO
from sort import Sort
import time
import datetime
import threading
import queue
import torch
from collections import deque
class ThreadedFrameReader:
    """Background RTSP frame grabber that keeps only the freshest frame.

    A daemon thread continuously drains the capture; stale frames are
    discarded so a slow consumer never falls behind a live stream.
    """

    def __init__(self, src, maxsize=1):
        self.cap = cv2.VideoCapture(src, cv2.CAP_FFMPEG)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)  # keep FFmpeg's own buffer minimal
        self.q = queue.Queue(maxsize=maxsize)
        self.running = True
        self.thread = threading.Thread(target=self._reader)
        self.thread.daemon = True
        self.thread.start()

    def _reader(self):
        """Reader loop: drop any queued stale frame, store the newest one."""
        while self.running:
            ret, frame = self.cap.read()
            if not ret:
                time.sleep(0.1)  # transient stream hiccup; retry
                continue
            if not self.q.empty():
                try:
                    self.q.get_nowait()  # discard the stale frame
                except queue.Empty:
                    pass
            try:
                # put_nowait: never let the grabber block if a consumer raced us
                # between the empty() check and this put.
                self.q.put_nowait(frame)
            except queue.Full:
                pass

    def read(self):
        """Return (True, frame) if a fresh frame is available, else (False, None)."""
        if not self.q.empty():
            return True, self.q.get()
        return False, None

    def release(self):
        """Stop the reader thread, then release the capture.

        BUG FIX: joining the thread before ``cap.release()`` avoids a race
        where the reader is still inside ``cap.read()`` on a released capture.
        """
        self.running = False
        self.thread.join(timeout=2.0)
        self.cap.release()
def is_point_in_roi(x, y, roi):
    """Return True when point (x, y) lies inside or on the polygon *roi*."""
    signed_side = cv2.pointPolygonTest(roi, (int(x), int(y)), False)
    return signed_side >= 0
class OffDutyCrowdDetector:
    """Single-camera monitor: off-duty detection plus crowd-gathering alerts.

    Owns its own frame-read loop (``run``) and shares a single YOLO model
    instance across cameras; intended to be launched in one daemon thread
    per camera (see run_detectors.py).
    """

    def __init__(self, config, model, device, use_half):
        """
        Args:
            config: merged camera + common + model settings (dict).
            model: shared ultralytics YOLO model instance.
            device: inference device string ("cuda" / "cpu").
            use_half: run inference in FP16 when True.
        """
        self.config = config
        self.model = model
        self.device = device
        self.use_half = use_half
        # ROI polygons (pixel coordinates)
        self.roi = np.array(config["roi_points"], dtype=np.int32)
        self.crowd_roi = np.array(config["crowd_roi_points"], dtype=np.int32)
        # Optional ID tracking over detections
        self.tracker = Sort(
            max_age=30,
            min_hits=2,
            iou_threshold=0.3
        )
        # Duty state machine
        self.is_on_duty = False
        self.on_duty_start_time = None
        self.is_off_duty = True
        self.last_no_person_time = None
        self.off_duty_timer_start = None
        self.last_alert_time = 0
        self.last_crowd_alert_time = 0
        self.crowd_history = deque(maxlen=1500)  # ~5 minutes at ~5 fps
        self.last_person_seen_time = None
        self.frame_count = 0
        # Cached config.
        # BUG FIX: config.yaml supplies "conf_threshold" and "*_sec" key names,
        # but the previous lookups used "conf_thresh", "on_duty_confirm", etc.,
        # so every camera silently ran on the hard-coded defaults. Prefer the
        # documented keys and keep the old names as backward-compatible fallbacks.
        self.working_start_min = config.get("working_hours", [9, 17])[0] * 60
        self.working_end_min = config.get("working_hours", [9, 17])[1] * 60
        self.process_every = config.get("process_every_n_frames", 3)
        self.conf_thresh = config.get("conf_thresh", config.get("conf_threshold", 0.4))
        self.on_duty_confirm_sec = config.get("on_duty_confirm_sec", config.get("on_duty_confirm", 5))
        self.off_duty_confirm_sec = config.get("off_duty_confirm_sec", config.get("off_duty_confirm", 30))
        self.off_duty_threshold_sec = config.get("off_duty_threshold_sec", config.get("off_duty_threshold", 300))
        self.alert_cooldown_sec = config.get("alert_cooldown_sec", config.get("alert_cooldown", 300))
        self.crowd_cooldown_sec = config.get("crowd_cooldown_sec", config.get("crowd_cooldown", 180))
        # Entry grace period (was hard-coded 1.0; config already documents it).
        self.entry_grace_sec = config.get("entry_grace_period_sec", 1.0)
        # Visualization toggle (was a hard-coded ``if True:``); set
        # show_window: false in config for headless deployment.
        self.show_window = config.get("show_window", True)

    def in_working_hours(self):
        """True when the wall-clock time is inside working hours (inclusive)."""
        now = datetime.datetime.now()
        total_min = now.hour * 60 + now.minute
        return self.working_start_min <= total_min <= self.working_end_min

    def count_people_in_roi(self, boxes, roi):
        """Count detection boxes whose centre point lies inside polygon *roi*."""
        count = 0
        for box in boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            if is_point_in_roi(cx, cy, roi):
                count += 1
        return count

    def run(self):
        """Main loop: intended as the target of a per-camera thread."""
        frame_reader = ThreadedFrameReader(self.config["rtsp_url"])
        try:
            while True:
                ret, frame = frame_reader.read()
                if not ret:
                    time.sleep(0.01)
                    continue
                self.frame_count += 1
                if self.frame_count % self.process_every != 0:
                    continue  # frame skipping to reduce load
                current_time = time.time()
                # YOLO inference, person class only
                results = self.model(
                    frame,
                    imgsz=self.config.get("imgsz", 480),
                    conf=self.conf_thresh,
                    verbose=False,
                    device=self.device,
                    half=self.use_half,
                    classes=[0]  # person class
                )
                boxes = results[0].boxes
                # Feed the tracker (optional ID tracking)
                dets = []
                for box in boxes:
                    x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
                    conf = float(box.conf)
                    dets.append([x1, y1, x2, y2, conf])
                dets = np.array(dets) if dets else np.empty((0, 5))
                self.tracker.update(dets)
                # === off-duty detection ===
                if self.in_working_hours():
                    roi_has_person = self.count_people_in_roi(boxes, self.roi) > 0
                    if roi_has_person:
                        self.last_person_seen_time = current_time
                    # Entry grace period: bridges momentary missed detections.
                    effective_on_duty = (
                        self.last_person_seen_time is not None and
                        (current_time - self.last_person_seen_time) < self.entry_grace_sec
                    )
                    if effective_on_duty:
                        self.last_no_person_time = None
                        if self.is_off_duty:
                            if self.on_duty_start_time is None:
                                self.on_duty_start_time = current_time
                            elif current_time - self.on_duty_start_time >= self.on_duty_confirm_sec:
                                self.is_on_duty = True
                                self.is_off_duty = False
                                self.on_duty_start_time = None
                                print(f"[{self.config['id']}] ✅ 上岗确认")
                    else:
                        self.on_duty_start_time = None
                        self.last_person_seen_time = None
                        if not self.is_off_duty:
                            if self.last_no_person_time is None:
                                self.last_no_person_time = current_time
                            elif current_time - self.last_no_person_time >= self.off_duty_confirm_sec:
                                self.is_off_duty = True
                                self.is_on_duty = False
                                self.off_duty_timer_start = current_time
                                print(f"[{self.config['id']}] ⏳ 开始离岗计时")
                    # Off-duty alert, rate-limited by the cooldown.
                    if self.is_off_duty and self.off_duty_timer_start:
                        elapsed = current_time - self.off_duty_timer_start
                        if elapsed >= self.off_duty_threshold_sec:
                            if current_time - self.last_alert_time >= self.alert_cooldown_sec:
                                print(f"[{self.config['id']}] 🚨 离岗告警!已离岗 {elapsed/60:.1f} 分钟")
                                self.last_alert_time = current_time
                # === crowd detection ===
                crowd_count = self.count_people_in_roi(boxes, self.crowd_roi)
                self.crowd_history.append((current_time, crowd_count))
                # Dynamic dwell requirement: bigger crowds trigger sooner.
                if crowd_count >= 10:
                    req_dur = 60
                elif crowd_count >= 7:
                    req_dur = 120
                elif crowd_count >= 5:
                    req_dur = 300
                else:
                    req_dur = float('inf')
                if req_dur < float('inf'):
                    recent = [(t, c) for t, c in self.crowd_history if current_time - t <= req_dur]
                    if recent:
                        # Alert only if >=4 people were present in >=90% of the window.
                        valid = [c for t, c in recent if c >= 4]
                        ratio = len(valid) / len(recent)
                        if ratio >= 0.9 and (current_time - self.last_crowd_alert_time) >= self.crowd_cooldown_sec:
                            print(f"[{self.config['id']}] 🚨 聚集告警!{crowd_count}人持续{req_dur//60}分钟")
                            self.last_crowd_alert_time = current_time
                # Visualization (disabled via show_window=false when headless)
                if self.show_window:
                    vis = results[0].plot()
                    overlay = vis.copy()
                    cv2.fillPoly(overlay, [self.roi], (0, 255, 0))
                    cv2.fillPoly(overlay, [self.crowd_roi], (0, 0, 255))
                    cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis)
                    cv2.imshow(f"Monitor - {self.config['id']}", vis)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
        except Exception as e:
            # Per-camera loop: report and fall through to cleanup.
            print(f"[{self.config['id']}] Error: {e}")
        finally:
            frame_reader.release()
            cv2.destroyAllWindows()

16
main.py Normal file
View File

@@ -0,0 +1,16 @@
# Sample PyCharm starter script (auto-generated).
# Press Shift+F10 to run, or replace this with your own code.


def print_hi(name):
    """Greet *name* on stdout (PyCharm starter demo)."""
    # Set a breakpoint on the line below to try the debugger (Ctrl+F8).
    print(f'Hi, {name}')


if __name__ == '__main__':
    print_hi('PyCharm')

291
monitor.py Normal file
View File

@@ -0,0 +1,291 @@
import cv2
import numpy as np
import yaml
import torch
from ultralytics import YOLO
import time
import datetime
import threading
import queue
import sys
import argparse
from collections import defaultdict
class ThreadedFrameReader:
    """Background RTSP frame grabber for one camera, keeping only fresh frames.

    A daemon thread drains the capture continuously; when the small queue is
    full the oldest frame is dropped so consumers always see recent video.
    """

    def __init__(self, cam_id, rtsp_url):
        self.cam_id = cam_id
        self.rtsp_url = rtsp_url
        self.cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
        self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)  # keep FFmpeg's own buffer minimal
        self.q = queue.Queue(maxsize=2)
        self.running = True
        self.thread = threading.Thread(target=self._reader, daemon=True)
        self.thread.start()

    def _reader(self):
        """Reader loop: evict the oldest frame when full, store the newest."""
        while self.running:
            ret, frame = self.cap.read()
            if not ret:
                time.sleep(0.1)  # transient stream hiccup; retry
                continue
            if self.q.full():
                try:
                    self.q.get_nowait()  # drop the oldest frame
                except queue.Empty:
                    pass
            try:
                # put_nowait: never block the grabber if a consumer raced us.
                self.q.put_nowait(frame)
            except queue.Full:
                pass

    def read(self):
        """Return (True, frame) if a frame is queued, else (False, None)."""
        if not self.q.empty():
            return True, self.q.get()
        return False, None

    def release(self):
        """Stop the reader thread, then release the capture.

        BUG FIX: joining the thread before ``cap.release()`` avoids a race
        where the reader is still inside ``cap.read()`` on a released capture.
        """
        self.running = False
        self.thread.join(timeout=2.0)
        self.cap.release()
class MultiCameraMonitor:
    """Multi-camera supervisor: one shared YOLO model, one inference thread.

    Frames are grabbed per camera by ThreadedFrameReader, inferred in a single
    round-robin inference thread, and the results are handed to each camera's
    CameraLogic in the main thread (which also owns the OpenCV windows).
    """

    def __init__(self, config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            self.cfg = yaml.safe_load(f)
        # === global model, loaded exactly once ===
        model_cfg = self.cfg['model']
        self.device = model_cfg.get('device', 'auto')
        if self.device == 'auto' or not self.device:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"🚀 全局加载模型到 {self.device}...")
        self.model = YOLO(model_cfg['path'])
        self.model.to(self.device)
        self.use_half = (self.device == 'cuda')
        if self.use_half:
            print("✅ 启用 FP16 推理")
        self.imgsz = model_cfg['imgsz']
        self.conf_thresh = model_cfg['conf_threshold']
        # === per-camera state ===
        self.common = self.cfg['common']
        self.cameras = {}
        self.frame_readers = {}
        self.queues = {}  # cam_id -> queue of (frame, result) pairs
        for cam_cfg in self.cfg['cameras']:
            cam_id = cam_cfg['id']
            self.cameras[cam_id] = CameraLogic(cam_id, cam_cfg, self.common)
            self.frame_readers[cam_id] = ThreadedFrameReader(cam_id, cam_cfg['rtsp_url'])
            self.queues[cam_id] = queue.Queue(maxsize=1)  # holds detection results
        # === control ===
        self.running = True
        self.inference_thread = threading.Thread(target=self._inference_loop, daemon=True)
        self.inference_thread.start()

    def _inference_loop(self):
        """Unified inference thread: poll each camera's latest frame, infer one by one."""
        while self.running:
            processed = False
            for cam_id, reader in self.frame_readers.items():
                ret, frame = reader.read()
                if not ret:
                    continue
                # honour each camera's own frame-skipping cadence
                cam_logic = self.cameras[cam_id]
                if cam_logic.should_skip_frame():
                    continue
                results = self.model(
                    frame,
                    imgsz=self.imgsz,
                    conf=self.conf_thresh,
                    verbose=False,
                    device=self.device,
                    half=self.use_half,
                    classes=[0]  # person
                )
                # hand the result back to that camera's logic
                if not self.queues[cam_id].full():
                    self.queues[cam_id].put((frame.copy(), results[0]))
                processed = True
            if not processed:
                time.sleep(0.01)

    def run(self):
        """Main thread: per-camera display and alert logic until 'q' or Ctrl+C."""
        try:
            while self.running:
                for cam_id, cam_logic in self.cameras.items():
                    if not self.queues[cam_id].empty():
                        frame, results = self.queues[cam_id].get()
                        cam_logic.process(frame, results)
                # refresh all OpenCV windows once per cycle
                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    # BUG FIX: also clear self.running so the inference thread
                    # stops; a bare break alone left it spinning.
                    self.running = False
                    break
                time.sleep(0.01)  # avoid pegging the CPU
        except KeyboardInterrupt:
            pass
        finally:
            self.stop()

    def stop(self):
        """Shut down the inference thread, frame readers and windows."""
        self.running = False
        # BUG FIX: join the inference thread before tearing down the readers
        # it is still polling.
        if self.inference_thread.is_alive():
            self.inference_thread.join(timeout=2.0)
        for reader in self.frame_readers.values():
            reader.release()
        cv2.destroyAllWindows()
class CameraLogic:
    """Per-camera duty/crowd state machine and on-frame visualization.

    Holds all mutable alerting state for one camera; ``process`` is called
    from the main thread with each (frame, YOLO result) pair.
    """

    def __init__(self, cam_id, cam_cfg, common_cfg):
        self.cam_id = cam_id
        # Detection polygons (pixel coordinates).
        self.roi_off_duty = np.array(cam_cfg['roi_points'], dtype=np.int32)
        self.roi_crowd = np.array(cam_cfg['crowd_roi_points'], dtype=np.int32)
        # Per-camera thresholds (seconds / person counts).
        self.off_duty_threshold_sec = cam_cfg['off_duty_threshold_sec']
        self.on_duty_confirm_sec = cam_cfg['on_duty_confirm_sec']
        self.off_duty_confirm_sec = cam_cfg['off_duty_confirm_sec']
        self.crowd_threshold_min = cam_cfg['crowd_threshold']
        # Shared settings; camera-level keys override the common section.
        self.working_hours = common_cfg['working_hours']
        self.process_every_n = cam_cfg.get('process_every_n_frames', common_cfg['process_every_n_frames'])
        self.alert_cooldown_sec = cam_cfg.get('alert_cooldown_sec', common_cfg['alert_cooldown_sec'])
        self.crowd_cooldown_sec = cam_cfg.get('crowd_cooldown_sec', common_cfg['crowd_cooldown_sec'])
        # Mutable alerting state.
        self.frame_count = 0
        self.is_on_duty = False
        self.is_off_duty = True
        self.on_duty_start_time = None     # when a candidate on-duty period started
        self.last_no_person_time = None    # when the ROI was first seen empty
        self.off_duty_timer_start = None   # when confirmed off-duty began
        self.last_alert_time = 0
        self.last_crowd_alert_time = 0
        self.crowd_history = []            # (timestamp, count) pairs, pruned to 5 min
        self.last_person_seen_time = None

    def should_skip_frame(self):
        """Frame-skipping cadence: True for frames that should NOT be inferred."""
        self.frame_count += 1
        return self.frame_count % self.process_every_n != 0

    def is_point_in_roi(self, x, y, roi):
        """True when point (x, y) lies inside or on polygon *roi*."""
        return cv2.pointPolygonTest(roi, (int(x), int(y)), False) >= 0

    def in_working_hours(self):
        """True inside [start, end) working hours; [0, 0] means always on."""
        now = datetime.datetime.now()
        h = now.hour
        start_h, end_h = self.working_hours
        return start_h <= h < end_h or (start_h == end_h == 0)

    def detect_crowd_count(self, boxes, roi):
        """Count detection boxes whose centre point lies inside *roi*."""
        count = 0
        for box in boxes:
            x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
            if self.is_point_in_roi(cx, cy, roi):
                count += 1
        return count

    def process(self, frame, results):
        """Run duty + crowd state updates for one frame and draw the overlay.

        *results* is a single ultralytics result object (``results.boxes``).
        """
        current_time = time.time()
        now = datetime.datetime.now()
        in_work = self.in_working_hours()
        boxes = results.boxes
        # === off-duty logic ===
        # Raw observation: is any detection centre inside the duty ROI?
        roi_has_person_raw = any(
            self.is_point_in_roi((b.xyxy[0][0] + b.xyxy[0][2]) / 2,
                                 (b.xyxy[0][1] + b.xyxy[0][3]) / 2,
                                 self.roi_off_duty)
            for b in boxes
        )
        if in_work:
            if roi_has_person_raw:
                self.last_person_seen_time = current_time
            # 1-second grace period bridges momentary missed detections.
            effective = (
                self.last_person_seen_time is not None and
                (current_time - self.last_person_seen_time) < 1.0  # grace period
            )
            if effective:
                self.last_no_person_time = None
                # Debounced on-duty confirmation.
                if self.is_off_duty:
                    if self.on_duty_start_time is None:
                        self.on_duty_start_time = current_time
                    elif current_time - self.on_duty_start_time >= self.on_duty_confirm_sec:
                        self.is_on_duty, self.is_off_duty = True, False
                        self.on_duty_start_time = None
                        print(f"[{self.cam_id}] ✅ 上岗确认成功 ({now.strftime('%H:%M:%S')})")
            else:
                self.on_duty_start_time = None
                self.last_person_seen_time = None
                # Debounced off-duty confirmation; starts the away timer.
                if not self.is_off_duty:
                    if self.last_no_person_time is None:
                        self.last_no_person_time = current_time
                    elif current_time - self.last_no_person_time >= self.off_duty_confirm_sec:
                        self.is_off_duty, self.is_on_duty = True, False
                        self.last_no_person_time = None
                        self.off_duty_timer_start = current_time
                        print(f"[{self.cam_id}] 🚪 进入离岗计时")
        # === crowd detection ===
        crowd_count = self.detect_crowd_count(boxes, self.roi_crowd)
        self.crowd_history.append((current_time, crowd_count))
        # Keep only the last 5 minutes of history.
        self.crowd_history = [(t, c) for t, c in self.crowd_history if current_time - t <= 300]
        if in_work and crowd_count >= self.crowd_threshold_min:
            # Require the threshold to hold in >=4 of the last 5 samples.
            recent = [c for t, c in self.crowd_history[-10:]]
            if len(recent) >= 5 and sum(c >= self.crowd_threshold_min for c in recent[-5:]) >= 4:
                if current_time - self.last_crowd_alert_time >= self.crowd_cooldown_sec:
                    print(f"[{self.cam_id}] 🚨 聚集告警:{crowd_count}")
                    self.last_crowd_alert_time = current_time
        # === off-duty alert (rate-limited by cooldown) ===
        if in_work and self.is_off_duty and self.off_duty_timer_start:
            if (current_time - self.off_duty_timer_start) >= self.off_duty_threshold_sec:
                if (current_time - self.last_alert_time) >= self.alert_cooldown_sec:
                    print(f"[{self.cam_id}] 🚨 离岗告警!")
                    self.last_alert_time = current_time
        # === visualization ===
        vis = results.plot()
        overlay = vis.copy()
        cv2.fillPoly(overlay, [self.roi_off_duty], (0, 255, 0))
        cv2.fillPoly(overlay, [self.roi_crowd], (0, 0, 255))
        cv2.addWeighted(overlay, 0.2, vis, 0.8, 0, vis)
        # Status banner reflecting the current duty state.
        status = "OUT OF HOURS"
        color = (128, 128, 128)
        if in_work:
            if self.is_on_duty:
                status, color = "ON DUTY", (0, 255, 0)
            elif self.is_off_duty:
                if self.off_duty_timer_start:
                    elapsed = int(current_time - self.off_duty_timer_start)
                    if elapsed >= self.off_duty_threshold_sec:
                        status, color = "OFF DUTY!", (0, 0, 255)
                    else:
                        status = f"IDLE - {elapsed}s"
                        color = (0, 255, 255)
                else:
                    status, color = "OFF DUTY", (255, 0, 0)
        cv2.putText(vis, f"[{self.cam_id}] {status}", (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
        cv2.imshow(f"Monitor - {self.cam_id}", vis)
def main():
    """CLI entry point: parse arguments and run the multi-camera monitor."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", default="config.yaml", help="配置文件路径")
    opts = arg_parser.parse_args()
    MultiCameraMonitor(opts.config).run()


if __name__ == "__main__":
    main()

52
run_detectors.py Normal file
View File

@@ -0,0 +1,52 @@
import yaml
import threading
from ultralytics import YOLO
import torch
from detector import OffDutyCrowdDetector
import os
def load_config(config_path="config.yaml"):
    """Parse the YAML configuration file and return its contents as a dict."""
    with open(config_path, "r", encoding="utf-8") as cfg_file:
        parsed = yaml.safe_load(cfg_file)
    return parsed
def main():
    """Load config, set up the shared YOLO model and spawn one detector thread per camera."""
    config = load_config()
    # === shared global model ===
    model_path = config["model"]["path"]
    device = config["model"].get("device", "cuda" if torch.cuda.is_available() else "cpu")
    use_half = (device == "cuda")
    print(f"Loading model {model_path} on {device} (FP16: {use_half})")
    model = YOLO(model_path)
    model.to(device)
    if use_half:
        model.model.half()
    # BUG FIX: config.yaml names this key ``conf_threshold``; the old
    # ``config["model"]["conf_thresh"]`` raised KeyError on startup.
    conf_thresh = config["model"].get("conf_threshold", config["model"].get("conf_thresh", 0.5))
    # === one detector thread per camera ===
    threads = []
    for cam_cfg in config["cameras"]:
        # merge shared defaults with per-camera overrides (camera wins)
        full_cfg = {**config.get("common", {}), **cam_cfg}
        full_cfg["imgsz"] = config["model"]["imgsz"]
        full_cfg["conf_thresh"] = conf_thresh
        full_cfg["working_hours"] = config["common"]["working_hours"]
        detector = OffDutyCrowdDetector(full_cfg, model, device, use_half)
        thread = threading.Thread(target=detector.run, daemon=True)
        thread.start()
        threads.append(thread)
        print(f"Started detector for {cam_cfg['id']}")
    print(f"✅ 已启动 {len(threads)} 路摄像头检测,按 Ctrl+C 退出")
    try:
        for t in threads:
            t.join()
    except KeyboardInterrupt:
        print("\n🛑 Shutting down...")


if __name__ == "__main__":
    main()

214
sort.py Normal file
View File

@@ -0,0 +1,214 @@
import numpy as np
from scipy.optimize import linear_sum_assignment
from filterpy.kalman import KalmanFilter
def linear_assignment(cost_matrix):
    """Solve the linear assignment problem for *cost_matrix*.

    Returns an array of matched (row, col) index pairs minimizing total cost.
    """
    row_idx, col_idx = linear_sum_assignment(cost_matrix)
    pairs = list(zip(row_idx, col_idx))
    return np.array(pairs)
def iou_batch(bb_test, bb_gt):
    """
    From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]

    Returns a (len(bb_test), len(bb_gt)) matrix of pairwise overlap ratios.
    """
    boxes_a = np.expand_dims(bb_test, 1)  # (N, 1, 4)
    boxes_b = np.expand_dims(bb_gt, 0)    # (1, M, 4)
    # Intersection rectangle, clamped to zero extent when boxes are disjoint.
    ix1 = np.maximum(boxes_a[..., 0], boxes_b[..., 0])
    iy1 = np.maximum(boxes_a[..., 1], boxes_b[..., 1])
    ix2 = np.minimum(boxes_a[..., 2], boxes_b[..., 2])
    iy2 = np.minimum(boxes_a[..., 3], boxes_b[..., 3])
    inter = np.maximum(0., ix2 - ix1) * np.maximum(0., iy2 - iy1)
    area_a = (boxes_a[..., 2] - boxes_a[..., 0]) * (boxes_a[..., 3] - boxes_a[..., 1])
    area_b = (boxes_b[..., 2] - boxes_b[..., 0]) * (boxes_b[..., 3] - boxes_b[..., 1])
    # IoU = intersection over union of the two areas.
    return inter / (area_a + area_b - inter)
def convert_bbox_to_z(bbox):
    """
    Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
    [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
    the aspect ratio
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    centre_x = bbox[0] + width / 2.
    centre_y = bbox[1] + height / 2.
    area = width * height  # scale is just area
    aspect = width / float(height)
    return np.array([centre_x, centre_y, area, aspect]).reshape((4, 1))
def convert_x_to_bbox(x, score=None):
    """
    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
    """
    # Recover width/height from area s = w*h and aspect r = w/h.
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    half_w, half_h = w / 2., h / 2.
    corners = [x[0] - half_w, x[1] - half_h, x[0] + half_w, x[1] + half_h]
    if score is None:
        return np.array(corners).reshape((1, 4))
    return np.array(corners + [score]).reshape((1, 5))
class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.

    State vector x = [cx, cy, s, r, v_cx, v_cy, v_s]: box centre, area ``s``
    and aspect ratio ``r`` plus velocities (``r`` is modelled as constant).
    """
    # Monotonically increasing ID counter shared by all tracker instances.
    count = 0

    def __init__(self, bbox):
        """
        Initialises a tracker using initial bounding box.
        """
        # Constant-velocity Kalman filter: 7 state dims, 4 measured dims.
        self.kf = KalmanFilter(dim_x=7, dim_z=4)
        # State transition: each position term advances by its velocity term.
        self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
                              [0, 1, 0, 0, 0, 1, 0],
                              [0, 0, 1, 0, 0, 0, 1],
                              [0, 0, 0, 1, 0, 0, 0],
                              [0, 0, 0, 0, 1, 0, 0],
                              [0, 0, 0, 0, 0, 1, 0],
                              [0, 0, 0, 0, 0, 0, 1]])
        # Measurement model: only [cx, cy, s, r] are observed.
        self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
                              [0, 1, 0, 0, 0, 0, 0],
                              [0, 0, 1, 0, 0, 0, 0],
                              [0, 0, 0, 1, 0, 0, 0]])
        self.kf.R[2:, 2:] *= 10.  # trust area/aspect measurements less
        self.kf.P[4:, 4:] *= 1000.  # give high uncertainty to the unobservable initial velocities
        self.kf.P *= 10.
        self.kf.Q[-1, -1] *= 0.01  # low process noise on the area-velocity term
        self.kf.Q[4:, 4:] *= 0.01
        self.kf.x[:4] = convert_bbox_to_z(bbox)
        self.time_since_update = 0  # frames since the last matched detection
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []  # predicted boxes accumulated since the last update
        self.hits = 0  # total matched detections
        self.hit_streak = 0  # consecutive frames with a matched detection
        self.age = 0  # total frames this tracker has existed

    def update(self, bbox):
        """
        Updates the state vector with observed bbox.
        """
        self.time_since_update = 0
        self.history = []
        self.hits += 1
        self.hit_streak += 1
        self.kf.update(convert_bbox_to_z(bbox))

    def predict(self):
        """
        Advances the state vector and returns the predicted bounding box estimate.
        """
        # Guard: if the predicted area would go non-positive, zero its velocity.
        if (self.kf.x[6] + self.kf.x[2]) <= 0:
            self.kf.x[6] *= 0.0
        self.kf.predict()
        self.age += 1
        if self.time_since_update > 0:
            self.hit_streak = 0  # the streak is broken once a frame goes unmatched
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        return self.history[-1]

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return convert_x_to_bbox(self.kf.x)
class Sort(object):
    """SORT multi-object tracker: IoU association over per-object Kalman filters."""

    def __init__(self, max_age=5, min_hits=2, iou_threshold=0.3):
        """
        Sets key parameters for SORT
        """
        self.max_age = max_age  # frames an unmatched tracker survives before removal
        self.min_hits = min_hits  # matched frames required before a track is reported
        self.iou_threshold = iou_threshold
        self.trackers = []
        self.frame_count = 0

    # NOTE(review): np.empty((0, 5)) as a default is a shared mutable-looking
    # default, but it is only ever read here, so it is safe in practice.
    def update(self, dets=np.empty((0, 5))):
        """
        Params:
        dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],...]
        Requires: this method must be called once for each frame even with empty detections.
        Returns the a similar array, where the last column is the object ID.
        NOTE: The number of objects returned may differ from the number of detections provided.
        """
        self.frame_count += 1
        # Predict the new location of every existing tracker for this frame.
        trks = np.zeros((len(self.trackers), 5))
        to_del = []
        ret = []
        for t, trk in enumerate(trks):
            pos = self.trackers[t].predict()[0]
            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)  # diverged filter: schedule for removal
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        # Remove in reverse order so earlier indices stay valid.
        for t in reversed(to_del):
            self.trackers.pop(t)
        matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)
        # update matched trackers with assigned detections
        for m in matched:
            self.trackers[m[1]].update(dets[m[0], :])
        # create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            trk = KalmanBoxTracker(dets[i, :])
            self.trackers.append(trk)
        # Walk the tracker list backwards, reporting confirmed tracks and
        # pruning those unmatched for longer than max_age.
        i = len(self.trackers)
        for trk in reversed(self.trackers):
            d = trk.get_state()[0]
            if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
                ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1))  # +1 as MOT benchmark requires positive
            i -= 1
            # remove dead tracklet
            if trk.time_since_update > self.max_age:
                self.trackers.pop(i)
        if len(ret) > 0:
            return np.concatenate(ret)
        return np.empty((0, 5))
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Assigns detections to tracked object (both represented as bounding boxes)
    Returns 3 lists of matches, unmatched_detections and unmatched_trackers
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 1), dtype=int)

    iou_matrix = iou_batch(detections, trackers)
    if min(iou_matrix.shape) > 0:
        above = (iou_matrix > iou_threshold).astype(np.int32)
        # When every detection and every tracker has at most one candidate,
        # the matching is unambiguous; otherwise run the Hungarian algorithm.
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            matched_indices = np.stack(np.where(above), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [d for d in range(len(detections)) if d not in matched_indices[:, 0]]
    unmatched_trackers = [t for t in range(len(trackers)) if t not in matched_indices[:, 1]]

    # Reject candidate pairs whose overlap is still below the threshold.
    matches = []
    for det_idx, trk_idx in matched_indices:
        if iou_matrix[det_idx, trk_idx] < iou_threshold:
            unmatched_detections.append(det_idx)
            unmatched_trackers.append(trk_idx)
        else:
            matches.append(np.array([[det_idx, trk_idx]]))
    matches = np.concatenate(matches, axis=0) if matches else np.empty((0, 2), dtype=int)
    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)