diff --git a/api/sync.py b/api/sync.py new file mode 100644 index 0000000..a8a22c8 --- /dev/null +++ b/api/sync.py @@ -0,0 +1,116 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from typing import List, Optional + +from db.models import get_db +from services.sync_service import get_sync_service + +router = APIRouter(prefix="/api/sync", tags=["同步"]) + + +@router.get("/status") +def get_sync_status(db: Session = Depends(get_db)): + """获取同步服务状态""" + from sqlalchemy import text + + service = get_sync_service() + status = service.get_status() + + pending_cameras = db.execute( + text("SELECT COUNT(*) FROM cameras WHERE pending_sync = 1") + ).scalar() or 0 + + pending_rois = db.execute( + text("SELECT COUNT(*) FROM rois WHERE pending_sync = 1") + ).scalar() or 0 + + pending_alarms = db.execute( + text("SELECT COUNT(*) FROM alarms WHERE upload_status = 'pending' OR upload_status = 'retry'") + ).scalar() or 0 + + return { + "running": status["running"], + "cloud_enabled": status["cloud_enabled"], + "network_status": status["network_status"], + "device_id": status["device_id"], + "pending_sync": pending_cameras + pending_rois, + "pending_alarms": pending_alarms, + "details": { + "pending_cameras": pending_cameras, + "pending_rois": pending_rois, + "pending_alarms": pending_alarms + } + } + + +@router.get("/pending") +def get_pending_syncs(db: Session = Depends(get_db)): + """获取待同步项列表""" + from sqlalchemy import text + from db.models import Camera, ROI, Alarm + + pending_cameras = db.query(Camera).filter(Camera.pending_sync == True).all() + pending_rois = db.query(ROI).filter(ROI.pending_sync == True).all() + + from db.session import SessionLocal + temp_db = SessionLocal() + try: + pending_alarms = temp_db.query(Alarm).filter( + Alarm.upload_status.in_(['pending', 'retry']) + ).limit(100).all() + finally: + temp_db.close() + + return { + "cameras": [{"id": c.id, "name": c.name} for c in pending_cameras], + "rois": [{"id": r.id, "name": r.name, "camera_id": r.camera_id} for r in pending_rois], + "alarms": [{"id": a.id, "camera_id": a.camera_id, "type": a.event_type} for a in pending_alarms] + } + + +@router.post("/trigger") +def trigger_sync(): + """手动触发同步""" + service = get_sync_service() + from db.session import SessionLocal + from db.crud import get_all_cameras, get_all_rois + from db.models import Camera, ROI + + db = SessionLocal() + try: + cameras = get_all_cameras(db) + for camera in cameras: + service.queue_camera_sync(camera.id, 'update', { + 'name': camera.name, + 'rtsp_url': camera.rtsp_url, + 'enabled': camera.enabled + }) + db.query(Camera).filter(Camera.id == camera.id).update({'pending_sync': False}) + db.commit() + + rois = get_all_rois(db) + for roi in rois: + service.queue_roi_sync(roi.id, 'update', { + 'name': roi.name, + 'roi_type': roi.roi_type, + 'points': roi.points, + 'enabled': roi.enabled + }) + db.query(ROI).filter(ROI.id == roi.id).update({'pending_sync': False}) + db.commit() + + return {"message": "同步任务已加入队列", "count": len(cameras) + len(rois)} + finally: + db.close() + + +@router.post("/clear-failed") +def clear_failed_syncs(db: Session = Depends(get_db)): + """清除失败的同步标记""" + from sqlalchemy import text + + db.execute(text("UPDATE cameras SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0")) + db.execute(text("UPDATE rois SET pending_sync = 0, sync_failed_at = NULL, sync_retry_count = 0")) + db.commit() + + return {"message": "已清除所有失败的同步标记"} diff --git a/services/sync_service.py b/services/sync_service.py new file mode 100644 index 0000000..1c74208 --- /dev/null +++ b/services/sync_service.py @@ -0,0 +1,461 @@ +""" +云端同步服务 + +实现"云端为主、本地为辅"的双层数据存储架构: +- 配置双向同步 +- 报警单向上报 +- 设备状态上报 +- 断网容错机制 +""" + +import os +import sys +import time +import threading +import logging +from datetime import datetime +from typing import Optional, List, Dict, Any +from queue import Queue, Empty +from dataclasses import dataclass +from enum import Enum + +import requests +from sqlalchemy.orm import Session + +# 添加项目根目录到路径 +project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.insert(0, project_root) + +from config import get_config + +logger = logging.getLogger(__name__) + + +class SyncStatus(Enum): + """同步状态""" + PENDING = "pending" + SYNCING = "syncing" + SUCCESS = "success" + FAILED = "failed" + RETRY = "retry" + + +class EntityType(Enum): + """实体类型""" + CAMERA = "camera" + ROI = "roi" + ALARM = "alarm" + STATUS = "status" + + +@dataclass +class SyncTask: + """同步任务""" + entity_type: EntityType + entity_id: int + operation: str # create, update, delete + data: Optional[Dict[str, Any]] = None + status: SyncStatus = SyncStatus.PENDING + retry_count: int = 0 + error_message: Optional[str] = None + created_at: datetime = None + + def __post_init__(self): + if self.created_at is None: + self.created_at = datetime.utcnow() + + +class CloudAPIClient: + """云端 API 客户端""" + + def __init__(self, base_url: str, api_key: str, device_id: str): + self.base_url = base_url.rstrip('/') + self.api_key = api_key + self.device_id = device_id + self.session = requests.Session() + self.session.headers.update({ + 'Authorization': f'Bearer {api_key}', + 'Content-Type': 'application/json', + 'X-Device-ID': device_id + }) + + def request(self, method: str, path: str, **kwargs) -> requests.Response: + """发送 API 请求""" + url = f"{self.base_url}{path}" + response = self.session.request(method, url, **kwargs) + response.raise_for_status() + return response + + def get(self, path: str, **kwargs): + return self.request('GET', path, **kwargs) + + def post(self, path: str, **kwargs): + return self.request('POST', path, **kwargs) + + def put(self, path: str, **kwargs): + return self.request('PUT', path, **kwargs) + + def delete(self, path: str, **kwargs): + return self.request('DELETE', path, **kwargs) + + +class SyncService: + """云端同步服务""" + + def __init__(self): + config = get_config() + self.config = config + + # 云端配置 + self.cloud_enabled = config.cloud.enabled + self.cloud_url = config.cloud.api_url + self.api_key = config.cloud.api_key + self.device_id = config.cloud.device_id + + # 同步配置 + self.sync_interval = config.cloud.sync_interval + self.alarm_retry_interval = config.cloud.alarm_retry_interval + self.status_report_interval = config.cloud.status_report_interval + self.max_retries = config.cloud.max_retries + + # 客户端 + self.client: Optional[CloudAPIClient] = None + if self.cloud_enabled: + self.client = CloudAPIClient( + self.cloud_url, + self.api_key, + self.device_id + ) + + # 任务队列 + self.sync_queue: Queue = Queue() + self.alarm_queue: Queue = Queue() + + # 状态 + self.running = False + self.threads: List[threading.Thread] = [] + self.network_status = "disconnected" + + # 重试配置 + self.retry_delays = [60, 300, 900, 3600] # 1分钟, 5分钟, 15分钟, 1小时 + + def start(self): + """启动同步服务""" + if self.running: + logger.warning("同步服务已在运行") + return + + self.running = True + + if self.cloud_enabled: + logger.info(f"启动云端同步服务,设备ID: {self.device_id}") + else: + logger.info("云端同步已禁用,使用本地模式") + + # 启动工作线程 + self.threads.append(threading.Thread(target=self._sync_worker, daemon=True)) + self.threads.append(threading.Thread(target=self._alarm_worker, daemon=True)) + self.threads.append(threading.Thread(target=self._status_worker, daemon=True)) + + for thread in self.threads: + thread.start() + + logger.info("同步服务已启动") + + def stop(self): + """停止同步服务""" + self.running = False + + for thread in self.threads: + if thread.is_alive(): + thread.join(timeout=5) + + logger.info("同步服务已停止") + + def _sync_worker(self): + """配置同步工作线程""" + while self.running: + try: + task = self.sync_queue.get(timeout=1) + self._execute_sync(task) + except Empty: + self._check_network_status() + except Exception as e: + logger.error(f"同步工作线程异常: {e}") + + def _alarm_worker(self): + """报警上报工作线程""" + while self.running: + try: + alarm_id = self.alarm_queue.get(timeout=1) + self._upload_alarm(alarm_id) + except Empty: + continue + except Exception as e: + logger.error(f"报警上报工作线程异常: {e}") + + def _status_worker(self): + """状态上报工作线程""" + while self.running: + try: + if self.network_status == "connected": + self._report_status() + except Exception as e: + logger.error(f"状态上报失败: {e}") + time.sleep(self.status_report_interval) + + def _check_network_status(self): + """检查网络状态""" + if not self.cloud_enabled: + self.network_status = "disabled" + return + + try: + self.client.get('/health') + self.network_status = "connected" + except: + self.network_status = "disconnected" + + def _execute_sync(self, task: SyncTask): + """执行同步任务""" + logger.info(f"执行同步任务: {task.entity_type.value}/{task.entity_id} ({task.operation})") + + task.status = SyncStatus.SYNCING + + try: + if task.entity_type == EntityType.CAMERA: + self._sync_camera(task) + elif task.entity_type == EntityType.ROI: + self._sync_roi(task) + + task.status = SyncStatus.SUCCESS + logger.info(f"同步成功: {task.entity_type.value}/{task.entity_id}") + + except requests.exceptions.RequestException as e: + self._handle_sync_error(task, str(e)) + except Exception as e: + task.status = SyncStatus.FAILED + task.error_message = str(e) + logger.error(f"同步失败: {task.entity_type.value}/{task.entity_id}: {e}") + + def _handle_sync_error(self, task: SyncTask, error: str): + """处理同步错误""" + task.retry_count += 1 + + if task.retry_count < self.max_retries: + task.status = SyncStatus.RETRY + delay = self.retry_delays[task.retry_count - 1] + task.error_message = f"第{task.retry_count}次失败: {error}" + logger.warning(f"同步重试 ({task.retry_count}/{self.max_retries}): {task.entity_type.value}/{task.entity_id}") + # 重新入队 + time.sleep(delay) + self.sync_queue.put(task) + else: + task.status = SyncStatus.FAILED + task.error_message = f"已超过最大重试次数: {error}" + logger.error(f"同步失败,已达最大重试次数: {task.entity_type.value}/{task.entity_id}") + + def _sync_camera(self, task: SyncTask): + """同步摄像头配置""" + if task.operation == 'update': + self.client.put(f"/api/v1/cameras/{task.entity_id}", json=task.data) + elif task.operation == 'delete': + self.client.delete(f"/api/v1/cameras/{task.entity_id}") + + def _sync_roi(self, task: SyncTask): + """同步 ROI 配置""" + if task.operation == 'update': + self.client.put(f"/api/v1/rois/{task.entity_id}", json=task.data) + elif task.operation == 'delete': + self.client.delete(f"/api/v1/rois/{task.entity_id}") + + def _upload_alarm(self, alarm_id: int): + """上传报警记录""" + from db.crud import get_alarm_by_id, update_alarm_status + from db.models import get_session_factory + + SessionLocal = get_session_factory() + db = SessionLocal() + + try: + alarm = get_alarm_by_id(db, alarm_id) + if not alarm: + logger.warning(f"报警记录不存在: {alarm_id}") + return + + # 准备数据 + alarm_data = { + 'device_id': self.device_id, + 'camera_id': alarm.camera_id, + 'alarm_type': alarm.event_type, + 'confidence': alarm.confidence, + 'timestamp': alarm.created_at.isoformat() if alarm.created_at else None, + 'region': alarm.region_data + } + + # 上传图片 + if alarm.snapshot_path and os.path.exists(alarm.snapshot_path): + with open(alarm.snapshot_path, 'rb') as f: + files = {'file': f} + response = self.client.post('/api/v1/alarms/images', files=files) + alarm_data['image_url'] = response.json().get('data', {}).get('url') + + # 上报报警 + response = self.client.post('/api/v1/alarms/report', json=alarm_data) + cloud_id = response.json().get('data', {}).get('alarm_id') + + # 更新本地状态 + update_alarm_status(db, alarm_id, status='uploaded', cloud_id=cloud_id) + logger.info(f"报警上报成功: {alarm_id} -> 云端ID: {cloud_id}") + + except requests.exceptions.RequestException as e: + update_alarm_status(db, alarm_id, status='retry', error_message=str(e)) + self.alarm_queue.put(alarm_id) # 重试 + except Exception as e: + update_alarm_status(db, alarm_id, status='failed', error_message=str(e)) + logger.error(f"报警处理失败: {alarm_id}: {e}") + finally: + db.close() + + def _report_status(self): + """上报设备状态""" + import psutil + from db.crud import get_active_camera_count + + try: + metrics = { + 'device_id': self.device_id, + 'cpu_percent': psutil.cpu_percent(), + 'memory_percent': psutil.virtual_memory().percent, + 'disk_usage': psutil.disk_usage('/').percent, + 'timestamp': datetime.utcnow().isoformat() + } + + self.client.post('/api/v1/devices/status', json=metrics) + logger.debug(f"设备状态上报成功: CPU={metrics['cpu_percent']}%") + except requests.exceptions.RequestException as e: + logger.warning(f"设备状态上报失败: {e}") + + # 公共接口 + + def queue_camera_sync(self, camera_id: int, operation: str = 'update', data: Dict[str, Any] = None): + """将摄像头同步加入队列""" + task = SyncTask( + entity_type=EntityType.CAMERA, + entity_id=camera_id, + operation=operation, + data=data + ) + self.sync_queue.put(task) + + def queue_roi_sync(self, roi_id: int, operation: str = 'update', data: Dict[str, Any] = None): + """将 ROI 同步加入队列""" + task = SyncTask( + entity_type=EntityType.ROI, + entity_id=roi_id, + operation=operation, + data=data + ) + self.sync_queue.put(task) + + def queue_alarm_upload(self, alarm_id: int): + """将报警上传加入队列""" + self.alarm_queue.put(alarm_id) + + def sync_config_from_cloud(self, db: Session) -> Dict[str, int]: + """从云端拉取配置""" + result = {'cameras': 0, 'rois': 0} + + if not self.cloud_enabled: + logger.info("云端同步已禁用,跳过配置拉取") + return result + + try: + logger.info("从云端拉取配置...") + + # 拉取设备配置 + response = self.client.get(f"/api/v1/devices/{self.device_id}/config") + config = response.json().get('data', {}) + + # 处理摄像头 + cameras = config.get('cameras', []) + for cloud_cam in cameras: + self._merge_camera(db, cloud_cam) + result['cameras'] += 1 + + logger.info(f"从云端拉取配置完成: {result['cameras']} 个摄像头") + + except requests.exceptions.RequestException as e: + logger.error(f"从云端拉取配置失败: {e}") + except Exception as e: + logger.error(f"处理云端配置时出错: {e}") + + return result + + def _merge_camera(self, db: Session, cloud_data: Dict[str, Any]): + """合并摄像头配置""" + from db.crud import get_camera_by_cloud_id, create_camera, update_camera + from db.models import Camera + + cloud_id = cloud_data.get('id') + existing = get_camera_by_cloud_id(db, cloud_id) + + if existing: + # 更新现有记录 + if not existing.pending_sync: + update_camera(db, existing.id, { + 'name': cloud_data.get('name'), + 'rtsp_url': cloud_data.get('rtsp_url'), + 'enabled': cloud_data.get('enabled', True), + 'fps_limit': cloud_data.get('fps_limit', 30), + 'process_every_n_frames': cloud_data.get('process_every_n_frames', 3), + }) + else: + # 创建新记录 + camera = create_camera(db, { + 'name': cloud_data.get('name'), + 'rtsp_url': cloud_data.get('rtsp_url'), + 'enabled': cloud_data.get('enabled', True), + 'fps_limit': cloud_data.get('fps_limit', 30), + 'process_every_n_frames': cloud_data.get('process_every_n_frames', 3), + }) + # 更新 cloud_id + camera.cloud_id = cloud_id + db.commit() + + def get_status(self) -> Dict[str, Any]: + """获取同步服务状态""" + return { + 'running': self.running, + 'cloud_enabled': self.cloud_enabled, + 'network_status': self.network_status, + 'device_id': self.device_id, + 'pending_sync': self.sync_queue.qsize(), + 'pending_alarms': self.alarm_queue.qsize(), + } + + +# 单例 +_sync_service: Optional[SyncService] = None + + +def get_sync_service() -> SyncService: + """获取同步服务单例""" + global _sync_service + if _sync_service is None: + _sync_service = SyncService() + return _sync_service + + +def start_sync_service(): + """启动同步服务""" + service = get_sync_service() + service.start() + + +def stop_sync_service(): + """停止同步服务""" + global _sync_service + if _sync_service: + _sync_service.stop() + _sync_service = None diff --git a/sort.py b/sort.py deleted file mode 100644 index fb41847..0000000 --- a/sort.py +++ /dev/null @@ -1,214 +0,0 @@ -import numpy as np -from scipy.optimize import linear_sum_assignment -from filterpy.kalman import KalmanFilter - -def linear_assignment(cost_matrix): - x, y = linear_sum_assignment(cost_matrix) - return np.array(list(zip(x, y))) - -def iou_batch(bb_test, bb_gt): - """ - From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2] - """ - bb_test = np.expand_dims(bb_test, 1) - bb_gt = np.expand_dims(bb_gt, 0) - xx1 = np.maximum(bb_test[..., 0], bb_gt[..., 0]) - yy1 = np.maximum(bb_test[..., 1], bb_gt[..., 1]) - xx2 = np.minimum(bb_test[..., 2], bb_gt[..., 2]) - yy2 = np.minimum(bb_test[..., 3], bb_gt[..., 3]) - w = np.maximum(0., xx2 - xx1) - h = np.maximum(0., yy2 - yy1) - wh = w * h - o = wh / ((bb_test[..., 2] - bb_test[..., 0]) * (bb_test[..., 3] - bb_test[..., 1]) - + (bb_gt[..., 2] - bb_gt[..., 0]) * (bb_gt[..., 3] - bb_gt[..., 1]) - wh) - return o - -def convert_bbox_to_z(bbox): - """ - Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form - [x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is - the aspect ratio - """ - w = bbox[2] - bbox[0] - h = bbox[3] - bbox[1] - x = bbox[0] + w / 2. - y = bbox[1] + h / 2. - s = w * h # scale is just area - r = w / float(h) - return np.array([x, y, s, r]).reshape((4, 1)) - -def convert_x_to_bbox(x, score=None): - """ - Takes a bounding box in the centre form [x,y,s,r] and returns it in the form - [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right - """ - w = np.sqrt(x[2] * x[3]) - h = x[2] / w - if score is None: - return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2.]).reshape((1, 4)) - else: - return np.array([x[0] - w / 2., x[1] - h / 2., x[0] + w / 2., x[1] + h / 2., score]).reshape((1, 5)) - -class KalmanBoxTracker(object): - """ - This class represents the internal state of individual tracked objects observed as bbox. - """ - count = 0 - - def __init__(self, bbox): - """ - Initialises a tracker using initial bounding box. - """ - self.kf = KalmanFilter(dim_x=7, dim_z=4) - self.kf.F = np.array([[1, 0, 0, 0, 1, 0, 0], - [0, 1, 0, 0, 0, 1, 0], - [0, 0, 1, 0, 0, 0, 1], - [0, 0, 0, 1, 0, 0, 0], - [0, 0, 0, 0, 1, 0, 0], - [0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 0, 0, 0, 1]]) - self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0], - [0, 1, 0, 0, 0, 0, 0], - [0, 0, 1, 0, 0, 0, 0], - [0, 0, 0, 1, 0, 0, 0]]) - - self.kf.R[2:, 2:] *= 10. - self.kf.P[4:, 4:] *= 1000. # give high uncertainty to the unobservable initial velocities - self.kf.P *= 10. - self.kf.Q[-1, -1] *= 0.01 - self.kf.Q[4:, 4:] *= 0.01 - - self.kf.x[:4] = convert_bbox_to_z(bbox) - self.time_since_update = 0 - self.id = KalmanBoxTracker.count - KalmanBoxTracker.count += 1 - self.history = [] - self.hits = 0 - self.hit_streak = 0 - self.age = 0 - - def update(self, bbox): - """ - Updates the state vector with observed bbox. - """ - self.time_since_update = 0 - self.history = [] - self.hits += 1 - self.hit_streak += 1 - self.kf.update(convert_bbox_to_z(bbox)) - - def predict(self): - """ - Advances the state vector and returns the predicted bounding box estimate. - """ - if (self.kf.x[6] + self.kf.x[2]) <= 0: - self.kf.x[6] *= 0.0 - self.kf.predict() - self.age += 1 - if self.time_since_update > 0: - self.hit_streak = 0 - self.time_since_update += 1 - self.history.append(convert_x_to_bbox(self.kf.x)) - return self.history[-1] - - def get_state(self): - """ - Returns the current bounding box estimate. - """ - return convert_x_to_bbox(self.kf.x) - -class Sort(object): - def __init__(self, max_age=5, min_hits=2, iou_threshold=0.3): - """ - Sets key parameters for SORT - """ - self.max_age = max_age - self.min_hits = min_hits - self.iou_threshold = iou_threshold - self.trackers = [] - self.frame_count = 0 - - def update(self, dets=np.empty((0, 5))): - """ - Params: - dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],...] - Requires: this method must be called once for each frame even with empty detections. - Returns the a similar array, where the last column is the object ID. - - NOTE: The number of objects returned may differ from the number of detections provided. - """ - self.frame_count += 1 - trks = np.zeros((len(self.trackers), 5)) - to_del = [] - ret = [] - for t, trk in enumerate(trks): - pos = self.trackers[t].predict()[0] - trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] - if np.any(np.isnan(pos)): - to_del.append(t) - trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) - for t in reversed(to_del): - self.trackers.pop(t) - matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold) - - # update matched trackers with assigned detections - for m in matched: - self.trackers[m[1]].update(dets[m[0], :]) - - # create and initialise new trackers for unmatched detections - for i in unmatched_dets: - trk = KalmanBoxTracker(dets[i, :]) - self.trackers.append(trk) - i = len(self.trackers) - for trk in reversed(self.trackers): - d = trk.get_state()[0] - if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): - ret.append(np.concatenate((d, [trk.id + 1])).reshape(1, -1)) # +1 as MOT benchmark requires positive - i -= 1 - # remove dead tracklet - if trk.time_since_update > self.max_age: - self.trackers.pop(i) - if len(ret) > 0: - return np.concatenate(ret) - return np.empty((0, 5)) - -def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3): - """ - Assigns detections to tracked object (both represented as bounding boxes) - Returns 3 lists of matches, unmatched_detections, unmatched_trackers - """ - if len(trackers) == 0: - return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 1), dtype=int) - iou_matrix = iou_batch(detections, trackers) - - if min(iou_matrix.shape) > 0: - a = (iou_matrix > iou_threshold).astype(np.int32) - if a.sum(1).max() == 1 and a.sum(0).max() == 1: - matched_indices = np.stack(np.where(a), axis=1) - else: - matched_indices = linear_assignment(-iou_matrix) - else: - matched_indices = np.empty(shape=(0, 2)) - - unmatched_detections = [] - for d, det in enumerate(detections): - if d not in matched_indices[:, 0]: - unmatched_detections.append(d) - unmatched_trackers = [] - for t, trk in enumerate(trackers): - if t not in matched_indices[:, 1]: - unmatched_trackers.append(t) - - matches = [] - for m in matched_indices: - if iou_matrix[m[0], m[1]] < iou_threshold: - unmatched_detections.append(m[0]) - unmatched_trackers.append(m[1]) - else: - matches.append(m.reshape(1, 2)) - if len(matches) == 0: - matches = np.empty((0, 2), dtype=int) - else: - matches = np.concatenate(matches, axis=0) - - return matches, np.array(unmatched_detections), np.array(unmatched_trackers) \ No newline at end of file