This commit is contained in:
116
api/sync.py
Normal file
116
api/sync.py
Normal file
@@ -0,0 +1,116 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
|
||||
from db.models import get_db
|
||||
from services.sync_service import get_sync_service
|
||||
|
||||
router = APIRouter(prefix="/api/sync", tags=["同步"])
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
def get_sync_status(db: Session = Depends(get_db)):
|
||||
"""获取同步服务状态"""
|
||||
from sqlalchemy import text
|
||||
|
||||
service = get_sync_service()
|
||||
status = service.get_status()
|
||||
|
||||
pending_cameras = db.execute(
|
||||
text("SELECT COUNT(*) FROM cameras WHERE pending_sync = 1")
|
||||
).scalar() or 0
|
||||
|
||||
pending_rois = db.execute(
|
||||
text("SELECT COUNT(*) FROM rois WHERE pending_sync = 1")
|
||||
).scalar() or 0
|
||||
|
||||
pending_alarms = db.execute(
|
||||
text("SELECT COUNT(*) FROM alarms WHERE upload_status = 'pending' OR upload_status = 'retry'")
|
||||
).scalar() or 0
|
||||
|
||||
return {
|
||||
"running": status["running"],
|
||||
"cloud_enabled": status["cloud_enabled"],
|
||||
"network_status": status["network_status"],
|
||||
"device_id": status["device_id"],
|
||||
"pending_sync": pending_cameras + pending_rois,
|
||||
"pending_alarms": pending_alarms,
|
||||
"details": {
|
||||
"pending_cameras": pending_cameras,
|
||||
"pending_rois": pending_rois,
|
||||
"pending_alarms": pending_alarms
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.get("/pending")
|
||||
def get_pending_syncs(db: Session = Depends(get_db)):
|
||||
"""获取待同步项列表"""
|
||||
from sqlalchemy import text
|
||||
from db.models import Camera, ROI, Alarm
|
||||
|
||||
pending_cameras = db.query(Camera).filter(Camera.pending_sync == True).all()
|
||||
pending_rois = db.query(ROI).filter(ROI.pending_sync == True).all()
|
||||
|
||||
from db.session import SessionLocal
|
||||
temp_db = SessionLocal()
|
||||
try:
|
||||
pending_alarms = temp_db.query(Alarm).filter(
|
||||
Alarm.upload_status.in_(['pending', 'retry'])
|
||||
).limit(100).all()
|
||||
finally:
|
||||
temp_db.close()
|
||||
|
||||
return {
|
||||
"cameras": [{"id": c.id, "name": c.name} for c in pending_cameras],
|
||||
"rois": [{"id": r.id, "name": r.name, "camera_id": r.camera_id} for r in pending_rois],
|
||||
"alarms": [{"id": a.id, "camera_id": a.camera_id, "type": a.event_type} for a in pending_alarms]
|
||||
}
|
||||
|
||||
|
||||
@router.post("/trigger")
|
||||
def trigger_sync():
|
||||
"""手动触发同步"""
|
||||
service = get_sync_service()
|
||||
from db.session import SessionLocal
|
||||
from db.crud import get_all_cameras, get_all_rois
|
||||
from db.models import Camera, ROI
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
cameras = get_all_cameras(db)
|
||||
for camera in cameras:
|
||||
service.queue_camera_sync(camera.id, 'update', {
|
||||
'name': camera.name,
|
||||
'rtsp_url': camera.rtsp_url,
|
||||
'enabled': camera.enabled
|
||||
})
|
||||
db.query(Camera).filter(Camera.id == camera.id).update({'pending_sync': False})
|
||||
db.commit()
|
||||
|
||||
rois = get_all_rois(db)
|
||||
for roi in rois:
|
||||
service.queue_roi_sync(roi.id, 'update', {
|
||||
'name': roi.name,
|
||||
'roi_type': roi.roi_type,
|
||||
'points': roi.points,
|
||||
'enabled': roi.enabled
|
||||
})
|
||||
db.query(ROI).filter(ROI.id == roi.id).update({'pending_sync': False})
|
||||
db.commit()
|
||||
|
||||
return {"message": "同步任务已加入队列", "count": len(cameras) + len(rois)}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/clear-failed")
|
||||
def clear_failed_syncs(db: Session = Depends(get_db)):
|
||||
"""清除失败的同步标记"""
|
||||
from sqlalchemy import text
|
||||
|
||||
db.execute(text("UPDATE cameras SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0"))
|
||||
db.execute(text("UPDATE rois SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0"))
|
||||
db.commit()
|
||||
|
||||
return {"message": "已清除所有失败的同步标记"}
|
||||
461
services/sync_service.py
Normal file
461
services/sync_service.py
Normal file
@@ -0,0 +1,461 @@
|
||||
"""
|
||||
云端同步服务
|
||||
|
||||
实现"云端为主、本地为辅"的双层数据存储架构:
|
||||
- 配置双向同步
|
||||
- 报警单向上报
|
||||
- 设备状态上报
|
||||
- 断网容错机制
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from queue import Queue, Empty
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 添加项目根目录到路径
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from config import get_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SyncStatus(Enum):
|
||||
"""同步状态"""
|
||||
PENDING = "pending"
|
||||
SYNCING = "syncing"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
RETRY = "retry"
|
||||
|
||||
|
||||
class EntityType(Enum):
|
||||
"""实体类型"""
|
||||
CAMERA = "camera"
|
||||
ROI = "roi"
|
||||
ALARM = "alarm"
|
||||
STATUS = "status"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncTask:
|
||||
"""同步任务"""
|
||||
entity_type: EntityType
|
||||
entity_id: int
|
||||
operation: str # create, update, delete
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
status: SyncStatus = SyncStatus.PENDING
|
||||
retry_count: int = 0
|
||||
error_message: Optional[str] = None
|
||||
created_at: datetime = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.utcnow()
|
||||
|
||||
|
||||
class CloudAPIClient:
|
||||
"""云端 API 客户端"""
|
||||
|
||||
def __init__(self, base_url: str, api_key: str, device_id: str):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.api_key = api_key
|
||||
self.device_id = device_id
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
'X-Device-ID': device_id
|
||||
})
|
||||
|
||||
def request(self, method: str, path: str, **kwargs) -> requests.Response:
|
||||
"""发送 API 请求"""
|
||||
url = f"{self.base_url}{path}"
|
||||
response = self.session.request(method, url, **kwargs)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
def get(self, path: str, **kwargs):
|
||||
return self.request('GET', path, **kwargs)
|
||||
|
||||
def post(self, path: str, **kwargs):
|
||||
return self.request('POST', path, **kwargs)
|
||||
|
||||
def put(self, path: str, **kwargs):
|
||||
return self.request('PUT', path, **kwargs)
|
||||
|
||||
def delete(self, path: str, **kwargs):
|
||||
return self.request('DELETE', path, **kwargs)
|
||||
|
||||
|
||||
class SyncService:
|
||||
"""云端同步服务"""
|
||||
|
||||
def __init__(self):
|
||||
config = get_config()
|
||||
self.config = config
|
||||
|
||||
# 云端配置
|
||||
self.cloud_enabled = config.cloud.enabled
|
||||
self.cloud_url = config.cloud.api_url
|
||||
self.api_key = config.cloud.api_key
|
||||
self.device_id = config.cloud.device_id
|
||||
|
||||
# 同步配置
|
||||
self.sync_interval = config.cloud.sync_interval
|
||||
self.alarm_retry_interval = config.cloud.alarm_retry_interval
|
||||
self.status_report_interval = config.cloud.status_report_interval
|
||||
self.max_retries = config.cloud.max_retries
|
||||
|
||||
# 客户端
|
||||
self.client: Optional[CloudAPIClient] = None
|
||||
if self.cloud_enabled:
|
||||
self.client = CloudAPIClient(
|
||||
self.cloud_url,
|
||||
self.api_key,
|
||||
self.device_id
|
||||
)
|
||||
|
||||
# 任务队列
|
||||
self.sync_queue: Queue = Queue()
|
||||
self.alarm_queue: Queue = Queue()
|
||||
|
||||
# 状态
|
||||
self.running = False
|
||||
self.threads: List[threading.Thread] = []
|
||||
self.network_status = "disconnected"
|
||||
|
||||
# 重试配置
|
||||
self.retry_delays = [60, 300, 900, 3600] # 1分钟, 5分钟, 15分钟, 1小时
|
||||
|
||||
def start(self):
|
||||
"""启动同步服务"""
|
||||
if self.running:
|
||||
logger.warning("同步服务已在运行")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
|
||||
if self.cloud_enabled:
|
||||
logger.info(f"启动云端同步服务,设备ID: {self.device_id}")
|
||||
else:
|
||||
logger.info("云端同步已禁用,使用本地模式")
|
||||
|
||||
# 启动工作线程
|
||||
self.threads.append(threading.Thread(target=self._sync_worker, daemon=True))
|
||||
self.threads.append(threading.Thread(target=self._alarm_worker, daemon=True))
|
||||
self.threads.append(threading.Thread(target=self._status_worker, daemon=True))
|
||||
|
||||
for thread in self.threads:
|
||||
thread.start()
|
||||
|
||||
logger.info("同步服务已启动")
|
||||
|
||||
def stop(self):
|
||||
"""停止同步服务"""
|
||||
self.running = False
|
||||
|
||||
for thread in self.threads:
|
||||
if thread.is_alive():
|
||||
thread.join(timeout=5)
|
||||
|
||||
logger.info("同步服务已停止")
|
||||
|
||||
def _sync_worker(self):
|
||||
"""配置同步工作线程"""
|
||||
while self.running:
|
||||
try:
|
||||
task = self.sync_queue.get(timeout=1)
|
||||
self._execute_sync(task)
|
||||
except Empty:
|
||||
self._check_network_status()
|
||||
except Exception as e:
|
||||
logger.error(f"同步工作线程异常: {e}")
|
||||
|
||||
def _alarm_worker(self):
|
||||
"""报警上报工作线程"""
|
||||
while self.running:
|
||||
try:
|
||||
alarm_id = self.alarm_queue.get(timeout=1)
|
||||
self._upload_alarm(alarm_id)
|
||||
except Empty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"报警上报工作线程异常: {e}")
|
||||
|
||||
def _status_worker(self):
|
||||
"""状态上报工作线程"""
|
||||
while self.running:
|
||||
try:
|
||||
if self.network_status == "connected":
|
||||
self._report_status()
|
||||
except Exception as e:
|
||||
logger.error(f"状态上报失败: {e}")
|
||||
time.sleep(self.status_report_interval)
|
||||
|
||||
def _check_network_status(self):
|
||||
"""检查网络状态"""
|
||||
if not self.cloud_enabled:
|
||||
self.network_status = "disabled"
|
||||
return
|
||||
|
||||
try:
|
||||
self.client.get('/health')
|
||||
self.network_status = "connected"
|
||||
except:
|
||||
self.network_status = "disconnected"
|
||||
|
||||
def _execute_sync(self, task: SyncTask):
|
||||
"""执行同步任务"""
|
||||
logger.info(f"执行同步任务: {task.entity_type.value}/{task.entity_id} ({task.operation})")
|
||||
|
||||
task.status = SyncStatus.SYNCING
|
||||
|
||||
try:
|
||||
if task.entity_type == EntityType.CAMERA:
|
||||
self._sync_camera(task)
|
||||
elif task.entity_type == EntityType.ROI:
|
||||
self._sync_roi(task)
|
||||
|
||||
task.status = SyncStatus.SUCCESS
|
||||
logger.info(f"同步成功: {task.entity_type.value}/{task.entity_id}")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
self._handle_sync_error(task, str(e))
|
||||
except Exception as e:
|
||||
task.status = SyncStatus.FAILED
|
||||
task.error_message = str(e)
|
||||
logger.error(f"同步失败: {task.entity_type.value}/{task.entity_id}: {e}")
|
||||
|
||||
def _handle_sync_error(self, task: SyncTask, error: str):
|
||||
"""处理同步错误"""
|
||||
task.retry_count += 1
|
||||
|
||||
if task.retry_count < self.max_retries:
|
||||
task.status = SyncStatus.RETRY
|
||||
delay = self.retry_delays[task.retry_count - 1]
|
||||
task.error_message = f"第{task.retry_count}次失败: {error}"
|
||||
logger.warning(f"同步重试 ({task.retry_count}/{self.max_retries}): {task.entity_type.value}/{task.entity_id}")
|
||||
# 重新入队
|
||||
time.sleep(delay)
|
||||
self.sync_queue.put(task)
|
||||
else:
|
||||
task.status = SyncStatus.FAILED
|
||||
task.error_message = f"已超过最大重试次数: {error}"
|
||||
logger.error(f"同步失败,已达最大重试次数: {task.entity_type.value}/{task.entity_id}")
|
||||
|
||||
def _sync_camera(self, task: SyncTask):
|
||||
"""同步摄像头配置"""
|
||||
if task.operation == 'update':
|
||||
self.client.put(f"/api/v1/cameras/{task.entity_id}", json=task.data)
|
||||
elif task.operation == 'delete':
|
||||
self.client.delete(f"/api/v1/cameras/{task.entity_id}")
|
||||
|
||||
def _sync_roi(self, task: SyncTask):
|
||||
"""同步 ROI 配置"""
|
||||
if task.operation == 'update':
|
||||
self.client.put(f"/api/v1/rois/{task.entity_id}", json=task.data)
|
||||
elif task.operation == 'delete':
|
||||
self.client.delete(f"/api/v1/rois/{task.entity_id}")
|
||||
|
||||
def _upload_alarm(self, alarm_id: int):
|
||||
"""上传报警记录"""
|
||||
from db.crud import get_alarm_by_id, update_alarm_status
|
||||
from db.models import get_session_factory
|
||||
|
||||
SessionLocal = get_session_factory()
|
||||
db = SessionLocal()
|
||||
|
||||
try:
|
||||
alarm = get_alarm_by_id(db, alarm_id)
|
||||
if not alarm:
|
||||
logger.warning(f"报警记录不存在: {alarm_id}")
|
||||
return
|
||||
|
||||
# 准备数据
|
||||
alarm_data = {
|
||||
'device_id': self.device_id,
|
||||
'camera_id': alarm.camera_id,
|
||||
'alarm_type': alarm.event_type,
|
||||
'confidence': alarm.confidence,
|
||||
'timestamp': alarm.created_at.isoformat() if alarm.created_at else None,
|
||||
'region': alarm.region_data
|
||||
}
|
||||
|
||||
# 上传图片
|
||||
if alarm.snapshot_path and os.path.exists(alarm.snapshot_path):
|
||||
with open(alarm.snapshot_path, 'rb') as f:
|
||||
files = {'file': f}
|
||||
response = self.client.post('/api/v1/alarms/images', files=files)
|
||||
alarm_data['image_url'] = response.json().get('data', {}).get('url')
|
||||
|
||||
# 上报报警
|
||||
response = self.client.post('/api/v1/alarms/report', json=alarm_data)
|
||||
cloud_id = response.json().get('data', {}).get('alarm_id')
|
||||
|
||||
# 更新本地状态
|
||||
update_alarm_status(db, alarm_id, status='uploaded', cloud_id=cloud_id)
|
||||
logger.info(f"报警上报成功: {alarm_id} -> 云端ID: {cloud_id}")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
update_alarm_status(db, alarm_id, status='retry', error_message=str(e))
|
||||
self.alarm_queue.put(alarm_id) # 重试
|
||||
except Exception as e:
|
||||
update_alarm_status(db, alarm_id, status='failed', error_message=str(e))
|
||||
logger.error(f"报警处理失败: {alarm_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _report_status(self):
|
||||
"""上报设备状态"""
|
||||
import psutil
|
||||
from db.crud import get_active_camera_count
|
||||
|
||||
try:
|
||||
metrics = {
|
||||
'device_id': self.device_id,
|
||||
'cpu_percent': psutil.cpu_percent(),
|
||||
'memory_percent': psutil.virtual_memory().percent,
|
||||
'disk_usage': psutil.disk_usage('/').percent,
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
self.client.post('/api/v1/devices/status', json=metrics)
|
||||
logger.debug(f"设备状态上报成功: CPU={metrics['cpu_percent']}%")
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.warning(f"设备状态上报失败: {e}")
|
||||
|
||||
# 公共接口
|
||||
|
||||
def queue_camera_sync(self, camera_id: int, operation: str = 'update', data: Dict[str, Any] = None):
|
||||
"""将摄像头同步加入队列"""
|
||||
task = SyncTask(
|
||||
entity_type=EntityType.CAMERA,
|
||||
entity_id=camera_id,
|
||||
operation=operation,
|
||||
data=data
|
||||
)
|
||||
self.sync_queue.put(task)
|
||||
|
||||
def queue_roi_sync(self, roi_id: int, operation: str = 'update', data: Dict[str, Any] = None):
|
||||
"""将 ROI 同步加入队列"""
|
||||
task = SyncTask(
|
||||
entity_type=EntityType.ROI,
|
||||
entity_id=roi_id,
|
||||
operation=operation,
|
||||
data=data
|
||||
)
|
||||
self.sync_queue.put(task)
|
||||
|
||||
def queue_alarm_upload(self, alarm_id: int):
|
||||
"""将报警上传加入队列"""
|
||||
self.alarm_queue.put(alarm_id)
|
||||
|
||||
def sync_config_from_cloud(self, db: Session) -> Dict[str, int]:
|
||||
"""从云端拉取配置"""
|
||||
result = {'cameras': 0, 'rois': 0}
|
||||
|
||||
if not self.cloud_enabled:
|
||||
logger.info("云端同步已禁用,跳过配置拉取")
|
||||
return result
|
||||
|
||||
try:
|
||||
logger.info("从云端拉取配置...")
|
||||
|
||||
# 拉取设备配置
|
||||
response = self.client.get(f"/api/v1/devices/{self.device_id}/config")
|
||||
config = response.json().get('data', {})
|
||||
|
||||
# 处理摄像头
|
||||
cameras = config.get('cameras', [])
|
||||
for cloud_cam in cameras:
|
||||
self._merge_camera(db, cloud_cam)
|
||||
result['cameras'] += 1
|
||||
|
||||
logger.info(f"从云端拉取配置完成: {result['cameras']} 个摄像头")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"从云端拉取配置失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理云端配置时出错: {e}")
|
||||
|
||||
return result
|
||||
|
||||
def _merge_camera(self, db: Session, cloud_data: Dict[str, Any]):
|
||||
"""合并摄像头配置"""
|
||||
from db.crud import get_camera_by_cloud_id, create_camera, update_camera
|
||||
from db.models import Camera
|
||||
|
||||
cloud_id = cloud_data.get('id')
|
||||
existing = get_camera_by_cloud_id(db, cloud_id)
|
||||
|
||||
if existing:
|
||||
# 更新现有记录
|
||||
if not existing.pending_sync:
|
||||
update_camera(db, existing.id, {
|
||||
'name': cloud_data.get('name'),
|
||||
'rtsp_url': cloud_data.get('rtsp_url'),
|
||||
'enabled': cloud_data.get('enabled', True),
|
||||
'fps_limit': cloud_data.get('fps_limit', 30),
|
||||
'process_every_n_frames': cloud_data.get('process_every_n_frames', 3),
|
||||
})
|
||||
else:
|
||||
# 创建新记录
|
||||
camera = create_camera(db, {
|
||||
'name': cloud_data.get('name'),
|
||||
'rtsp_url': cloud_data.get('rtsp_url'),
|
||||
'enabled': cloud_data.get('enabled', True),
|
||||
'fps_limit': cloud_data.get('fps_limit', 30),
|
||||
'process_every_n_frames': cloud_data.get('process_every_n_frames', 3),
|
||||
})
|
||||
# 更新 cloud_id
|
||||
camera.cloud_id = cloud_id
|
||||
db.commit()
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""获取同步服务状态"""
|
||||
return {
|
||||
'running': self.running,
|
||||
'cloud_enabled': self.cloud_enabled,
|
||||
'network_status': self.network_status,
|
||||
'device_id': self.device_id,
|
||||
'pending_sync': self.sync_queue.qsize(),
|
||||
'pending_alarms': self.alarm_queue.qsize(),
|
||||
}
|
||||
|
||||
|
||||
# 单例
|
||||
_sync_service: Optional[SyncService] = None
|
||||
|
||||
|
||||
def get_sync_service() -> SyncService:
|
||||
"""获取同步服务单例"""
|
||||
global _sync_service
|
||||
if _sync_service is None:
|
||||
_sync_service = SyncService()
|
||||
return _sync_service
|
||||
|
||||
|
||||
def start_sync_service():
|
||||
"""启动同步服务"""
|
||||
service = get_sync_service()
|
||||
service.start()
|
||||
|
||||
|
||||
def stop_sync_service():
|
||||
"""停止同步服务"""
|
||||
global _sync_service
|
||||
if _sync_service:
|
||||
_sync_service.stop()
|
||||
_sync_service = None
|
||||
214
sort.py
214
sort.py
@@ -1,214 +0,0 @@
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from filterpy.kalman import KalmanFilter
|
||||
|
||||
def linear_assignment(cost_matrix):
|
||||
x, y = linear_sum_assignment(cost_matrix)
|
||||
return np.array(list(zip(x, y)))
|
||||
|
||||
def iou_batch(bb_test, bb_gt):
|
||||
"""
|
||||
From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2]
|
||||
"""
|
||||
bb_test = np.expand_dims(bb_test, 1)
|
||||
bb_gt = np.expand_dims(bb_gt, 0)
|
||||
xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0])
|
||||
yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1])
|
||||
xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2])
|
||||
yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3])
|
||||
w = np.maximum(0., xx2 - xx1)
|
||||
h = np.maximum(0., yy2 - yy1)
|
||||
wh = w * h
|
||||
o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1])
|
||||
+ (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh)
|
||||
return o
|
||||
|
||||
def convert_bbox_to_z(bbox):
|
||||
"""
|
||||
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
|
||||
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
|
||||
the aspect ratio
|
||||
"""
|
||||
w = bbox[2] - bbox[0]
|
||||
h = bbox[3] - bbox[1]
|
||||
x = bbox[0] + w / 2.
|
||||
y = bbox[1] + h / 2.
|
||||
s = w * h # scale is just area
|
||||
r = w / float(h)
|
||||
return np.array([x, y, s, r]).reshape((4, 1))
|
||||
|
||||
def convert_x_to_bbox(x, score=None):
|
||||
"""
|
||||
Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
|
||||
[x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
|
||||
"""
|
||||
w = np.sqrt(x[2] * x[3])
|
||||
h = x[2] / w
|
||||
if score is None:
|
||||
return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4))
|
||||
else:
|
||||
return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5))
|
||||
|
||||
class KalmanBoxTracker(object):
|
||||
"""
|
||||
This class represents the internal state of individual tracked objects observed as bbox.
|
||||
"""
|
||||
count = 0
|
||||
|
||||
def __init__(self, bbox):
|
||||
"""
|
||||
Initialises a tracker using initial bounding box.
|
||||
"""
|
||||
self.kf = KalmanFilter(dim_x=7, dim_z=4)
|
||||
self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0],
|
||||
[0, 1, 0, 0, 0, 1, 0],
|
||||
[0, 0, 1, 0, 0, 0, 1],
|
||||
[0, 0, 0, 1, 0, 0, 0],
|
||||
[0, 0, 0, 0, 1, 0, 0],
|
||||
[0, 0, 0, 0, 0, 1, 0],
|
||||
[0, 0, 0, 0, 0, 0, 1]])
|
||||
self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0],
|
||||
[0, 1, 0, 0, 0, 0, 0],
|
||||
[0, 0, 1, 0, 0, 0, 0],
|
||||
[0, 0, 0, 1, 0, 0, 0]])
|
||||
|
||||
self.kf.R[2:, 2:] *= 10.
|
||||
self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities
|
||||
self.kf.P *= 10.
|
||||
self.kf.Q[-1, -1] *= 0.01
|
||||
self.kf.Q[4:, 4:] *= 0.01
|
||||
|
||||
self.kf.x[:4] = convert_bbox_to_z(bbox)
|
||||
self.time_since_update = 0
|
||||
self.id = KalmanBoxTracker.count
|
||||
KalmanBoxTracker.count += 1
|
||||
self.history = []
|
||||
self.hits = 0
|
||||
self.hit_streak = 0
|
||||
self.age = 0
|
||||
|
||||
def update(self, bbox):
|
||||
"""
|
||||
Updates the state vector with observed bbox.
|
||||
"""
|
||||
self.time_since_update = 0
|
||||
self.history = []
|
||||
self.hits += 1
|
||||
self.hit_streak += 1
|
||||
self.kf.update(convert_bbox_to_z(bbox))
|
||||
|
||||
def predict(self):
|
||||
"""
|
||||
Advances the state vector and returns the predicted bounding box estimate.
|
||||
"""
|
||||
if (self.kf.x[6] + self.kf.x[2]) <= 0:
|
||||
self.kf.x[6] *= 0.0
|
||||
self.kf.predict()
|
||||
self.age += 1
|
||||
if self.time_since_update > 0:
|
||||
self.hit_streak = 0
|
||||
self.time_since_update += 1
|
||||
self.history.append(convert_x_to_bbox(self.kf.x))
|
||||
return self.history[-1]
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Returns the current bounding box estimate.
|
||||
"""
|
||||
return convert_x_to_bbox(self.kf.x)
|
||||
|
||||
class Sort(object):
|
||||
def __init__(self, max_age=5, min_hits=2, iou_threshold=0.3):
|
||||
"""
|
||||
Sets key parameters for SORT
|
||||
"""
|
||||
self.max_age = max_age
|
||||
self.min_hits = min_hits
|
||||
self.iou_threshold = iou_threshold
|
||||
self.trackers = []
|
||||
self.frame_count = 0
|
||||
|
||||
def update(self, dets=np.empty((0, 5))):
|
||||
"""
|
||||
Params:
|
||||
dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],...]
|
||||
Requires: this method must be called once for each frame even with empty detections.
|
||||
Returns the a similar array, where the last column is the object ID.
|
||||
|
||||
NOTE: The number of objects returned may differ from the number of detections provided.
|
||||
"""
|
||||
self.frame_count += 1
|
||||
trks = np.zeros((len(self.trackers), 5))
|
||||
to_del = []
|
||||
ret = []
|
||||
for t, trk in enumerate(trks):
|
||||
pos = self.trackers[t].predict()[0]
|
||||
trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
|
||||
if np.any(np.isnan(pos)):
|
||||
to_del.append(t)
|
||||
trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
|
||||
for t in reversed(to_del):
|
||||
self.trackers.pop(t)
|
||||
matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold)
|
||||
|
||||
# update matched trackers with assigned detections
|
||||
for m in matched:
|
||||
self.trackers[m[1]].update(dets[m[0], :])
|
||||
|
||||
# create and initialise new trackers for unmatched detections
|
||||
for i in unmatched_dets:
|
||||
trk = KalmanBoxTracker(dets[i, :])
|
||||
self.trackers.append(trk)
|
||||
i = len(self.trackers)
|
||||
for trk in reversed(self.trackers):
|
||||
d = trk.get_state()[0]
|
||||
if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits):
|
||||
ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) # +1 as MOT benchmark requires positive
|
||||
i -= 1
|
||||
# remove dead tracklet
|
||||
if trk.time_since_update > self.max_age:
|
||||
self.trackers.pop(i)
|
||||
if len(ret) > 0:
|
||||
return np.concatenate(ret)
|
||||
return np.empty((0, 5))
|
||||
|
||||
def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
|
||||
"""
|
||||
Assigns detections to tracked object (both represented as bounding boxes)
|
||||
Returns 3 lists of matches, unmatched_detections, unmatched_trackers
|
||||
"""
|
||||
if len(trackers) == 0:
|
||||
return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 1), dtype=int)
|
||||
iou_matrix = iou_batch(detections, trackers)
|
||||
|
||||
if min(iou_matrix.shape) > 0:
|
||||
a = (iou_matrix > iou_threshold).astype(np.int32)
|
||||
if a.sum(1).max() == 1 and a.sum(0).max() == 1:
|
||||
matched_indices = np.stack(np.where(a), axis=1)
|
||||
else:
|
||||
matched_indices = linear_assignment(-iou_matrix)
|
||||
else:
|
||||
matched_indices = np.empty(shape=(0, 2))
|
||||
|
||||
unmatched_detections = []
|
||||
for d, det in enumerate(detections):
|
||||
if d not in matched_indices[:, 0]:
|
||||
unmatched_detections.append(d)
|
||||
unmatched_trackers = []
|
||||
for t, trk in enumerate(trackers):
|
||||
if t not in matched_indices[:, 1]:
|
||||
unmatched_trackers.append(t)
|
||||
|
||||
matches = []
|
||||
for m in matched_indices:
|
||||
if iou_matrix[m[0], m[1]] < iou_threshold:
|
||||
unmatched_detections.append(m[0])
|
||||
unmatched_trackers.append(m[1])
|
||||
else:
|
||||
matches.append(m.reshape(1, 2))
|
||||
if len(matches) == 0:
|
||||
matches = np.empty((0, 2), dtype=int)
|
||||
else:
|
||||
matches = np.concatenate(matches, axis=0)
|
||||
|
||||
return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
|
||||
Reference in New Issue
Block a user