Files
Test_AI/capture_dataset.py
16337 942244bd88 Add YOLO11 TensorRT quantization benchmark scripts
- Engine build scripts (FP16/INT8)
- Benchmark validation scripts
- Result parsing and analysis tools
- COCO dataset configuration
2026-01-29 13:59:42 +08:00

287 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import time
import os
import threading
import logging
from datetime import datetime
# Logging setup: every record is written both to capture.log (UTF-8) and to
# the console stream, using one shared timestamped format.
logging.basicConfig(
    handlers=[
        logging.FileHandler("capture.log", encoding='utf-8'),
        logging.StreamHandler(),
    ],
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S',
    format='%(asctime)s [%(levelname)s] %(message)s',
)
# Global configuration
RTSP_FILE = 'rtsp.txt'
DATA_DIR = 'data'
TARGET_SIZE = (640, 640)
FILL_COLOR = (114, 114, 114)
JPEG_QUALITY = 95
# Capture strategy configuration
TIME_INTERVAL = 30 * 60 # interval between scheduled captures (seconds)
CHANGE_THRESHOLD_AREA = 0.05 # changed-area ratio threshold (5%)
CHANGE_DURATION_THRESHOLD = 0.5 # how long a change must persist before triggering (seconds)
COOLDOWN_TIME = 60 # cooldown after a change-triggered capture (seconds)
RECONNECT_DELAY = 10 # delay before reconnecting (seconds)
# NOTE(review): not referenced by the code visible in this file; the stuck
# check uses its own 60-second limit — confirm whether this is still needed.
STUCK_CHECK_INTERVAL = 10 # time window for stuck detection (seconds)
def ensure_dir(path):
    """Create directory *path* (and any missing parents) if it is absent.

    Uses ``exist_ok=True`` instead of a separate existence check, so two
    callers racing to create the same directory cannot make one of them fail
    with ``FileExistsError``.

    Raises:
        FileExistsError: if *path* already exists but is not a directory.
        OSError: for any other failure reported by ``os.makedirs``.
    """
    os.makedirs(path, exist_ok=True)
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    """Resize an image to fit *new_shape* while keeping its aspect ratio, then pad.

    The image is scaled by a single factor so that it fits inside
    ``new_shape`` and the remaining area is filled with *color*, split as
    evenly as possible between the two opposite sides.

    Args:
        img: Image array whose first two dimensions are (height, width).
        new_shape: Target ``(height, width)``; a single int means a square.
        color: Border fill value passed to ``cv2.copyMakeBorder``.
        auto: Accepted for signature compatibility; ignored by this function.
        scaleFill: Accepted for signature compatibility; ignored by this function.
        scaleup: When False the image is only ever shrunk, never enlarged.

    Returns:
        The resized and padded image.

    Raises:
        ValueError: if *img* is None or has a zero-sized height or width.
    """
    if img is None or img.ndim < 2 or 0 in img.shape[:2]:
        # Fail early with a clear message instead of a ZeroDivisionError or
        # AttributeError deeper in the computation.
        raise ValueError("letterbox() requires a non-empty image")

    src_h, src_w = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old): the smaller ratio makes both sides fit.
    r = min(new_shape[0] / src_h, new_shape[1] / src_w)
    if not scaleup:  # only scale down, do not scale up
        r = min(r, 1.0)

    # Size after scaling as (width, height). Clamp to one pixel so a very
    # elongated input cannot round down to a zero dimension, which
    # cv2.resize rejects.
    new_unpad = max(1, int(round(src_w * r))), max(1, int(round(src_h * r)))

    # Total padding per axis, halved so it is shared by both sides.
    dw = (new_shape[1] - new_unpad[0]) / 2
    dh = (new_shape[0] - new_unpad[1]) / 2

    if (src_w, src_h) != new_unpad:  # resize only when the size changes
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

    # The +/-0.1 nudges make an odd padding total split as n and n+1 so the
    # padded result matches new_shape exactly.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img
class CameraCapture(threading.Thread):
    """Thread that reads one RTSP stream and saves selected frames as JPEGs.

    Frames are saved on two triggers: a fixed time interval (suffix ``T``)
    and a sustained change between consecutive frames (suffix ``TM``).
    Setting ``running`` to False makes ``run`` leave its loop.
    """

    # Seconds of identical frame fingerprints after which the stream is
    # considered frozen.
    STUCK_TIMEOUT = 60

    def __init__(self, camera_id, rtsp_url):
        """Prepare the output directory and initial state for one camera.

        Args:
            camera_id: Integer id, used in the folder and file names.
            rtsp_url: Stream URL; surrounding whitespace is stripped.
        """
        super().__init__()
        self.camera_id = camera_id
        self.rtsp_url = rtsp_url.strip()
        self.save_dir = os.path.join(DATA_DIR, f"cam_{camera_id:02d}")
        ensure_dir(self.save_dir)
        self.running = True
        self.cap = None
        # Capture bookkeeping (time.time() values; 0 means "never").
        self.last_time_capture = 0
        self.last_trigger_capture = 0
        self.change_start_time = None
        # Previous blurred grayscale frame used as the motion baseline.
        self.prev_gray = None
        # Stuck detection: last frame fingerprint and when it last changed.
        self.last_frame_content_hash = None
        self.last_frame_change_time = time.time()

    def _reset_stream_state(self):
        """Drop per-connection baselines so a new connection starts clean."""
        self.prev_gray = None
        self.change_start_time = None
        self.last_frame_content_hash = None
        self.last_frame_change_time = time.time()

    def connect(self):
        """(Re)open the RTSP stream.

        Any existing capture is released first. Motion and stuck-detection
        baselines are reset so frames from the previous connection are never
        compared with frames from the new one.

        Returns:
            True if the stream opened, False otherwise.
        """
        if self.cap is not None:
            self.cap.release()
        self._reset_stream_state()
        logging.info(f"[Cam {self.camera_id}] Connecting to RTSP stream...")
        self.cap = cv2.VideoCapture(self.rtsp_url)
        if not self.cap.isOpened():
            logging.error(f"[Cam {self.camera_id}] Failed to open stream.")
            return False
        return True

    def save_image(self, frame, trigger_type):
        """Letterbox *frame* and write it as a JPEG into the camera folder.

        Args:
            frame: Image to store.
            trigger_type: 'T' for Time-based, 'TM' for Motion-based; used as
                the file-name suffix.

        Returns:
            True if the file was written, False on any failure.
        """
        try:
            # Preprocess to the dataset's target size.
            processed_img = letterbox(frame, new_shape=TARGET_SIZE, color=FILL_COLOR)
            # Build the file name from camera id, timestamp and trigger.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"cam_{self.camera_id:02d}_{timestamp}_{trigger_type}.jpg"
            filepath = os.path.join(self.save_dir, filename)
            written = cv2.imwrite(filepath, processed_img, [int(cv2.IMWRITE_JPEG_QUALITY), JPEG_QUALITY])
        except Exception as e:
            logging.error(f"[Cam {self.camera_id}] Error saving image: {e}")
            return False
        if not written:
            # cv2.imwrite signals failure through its return value rather
            # than an exception, so it must be checked explicitly.
            logging.error(f"[Cam {self.camera_id}] Error saving image: cv2.imwrite failed for {filepath}")
            return False
        logging.info(f"[Cam {self.camera_id}] Saved {trigger_type}: {filename}")
        return True

    def check_motion(self, frame_gray):
        """Decide whether a sustained change should trigger a capture.

        Args:
            frame_gray: Current grayscale frame.

        Returns:
            True when the changed-area ratio has exceeded
            CHANGE_THRESHOLD_AREA for at least CHANGE_DURATION_THRESHOLD
            seconds outside the cooldown period; otherwise False.
        """
        if self.prev_gray is None or self.prev_gray.shape != frame_gray.shape:
            # No usable baseline (first frame, or the frame size changed):
            # start over instead of handing mismatched arrays to absdiff.
            self.prev_gray = frame_gray
            self.change_start_time = None
            return False
        # Frame difference; pixels differing by more than 25 count as changed.
        frame_delta = cv2.absdiff(self.prev_gray, frame_gray)
        thresh = cv2.threshold(frame_delta, 25, 255, cv2.THRESH_BINARY)[1]
        # Dilate to fill holes in the changed regions.
        thresh = cv2.dilate(thresh, None, iterations=2)
        # Fraction of the frame that changed.
        height, width = thresh.shape
        change_ratio = cv2.countNonZero(thresh) / (height * width)
        self.prev_gray = frame_gray

        now = time.time()
        # Still cooling down after the last triggered capture.
        if now - self.last_trigger_capture < COOLDOWN_TIME:
            self.change_start_time = None  # restart the duration timer
            return False
        if change_ratio > CHANGE_THRESHOLD_AREA:
            if self.change_start_time is None:
                self.change_start_time = now
            elif now - self.change_start_time >= CHANGE_DURATION_THRESHOLD:
                # The change lasted long enough: trigger and reset.
                self.change_start_time = None
                return True
        else:
            self.change_start_time = None
        return False

    def is_stuck(self, frame):
        """Report whether the stream appears frozen.

        The sum of all pixel values serves as a cheap frame fingerprint.

        Returns:
            True once the fingerprint has stayed identical for more than
            STUCK_TIMEOUT seconds, otherwise False.
        """
        current_hash = np.sum(frame)
        if self.last_frame_content_hash is None:
            self.last_frame_content_hash = current_hash
            self.last_frame_change_time = time.time()
            return False
        # An exactly equal fingerprint is taken to mean the stream keeps
        # delivering the same buffer.
        if current_hash == self.last_frame_content_hash:
            if time.time() - self.last_frame_change_time > self.STUCK_TIMEOUT:
                return True
        else:
            self.last_frame_content_hash = current_hash
            self.last_frame_change_time = time.time()
        return False

    def run(self):
        """Thread body: read frames until ``running`` becomes False.

        Each iteration reconnects if needed, skips frozen streams, then
        applies the time-based and change-based capture rules.
        """
        try:
            while self.running:
                if self.cap is None or not self.cap.isOpened():
                    if not self.connect():
                        time.sleep(RECONNECT_DELAY)
                        continue
                ret, frame = self.cap.read()
                if not ret or frame is None:
                    logging.warning(f"[Cam {self.camera_id}] Stream disconnected/empty. Reconnecting in {RECONNECT_DELAY}s...")
                    self.cap.release()
                    time.sleep(RECONNECT_DELAY)
                    continue
                now = time.time()
                # 1. Robustness: skip frozen streams (stuck detection).
                if self.is_stuck(frame):
                    logging.warning(f"[Cam {self.camera_id}] Frame stuck detected. Skipping storage.")
                    # Release so the next iteration reconnects.
                    self.cap.release()
                    continue
                # Blurred grayscale copy for motion detection.
                try:
                    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                    gray = cv2.GaussianBlur(gray, (21, 21), 0)
                except Exception as e:
                    logging.error(f"[Cam {self.camera_id}] Image processing error: {e}")
                    continue
                # 2. Time-based capture.
                if now - self.last_time_capture >= TIME_INTERVAL:
                    if self.save_image(frame, "T"):
                        self.last_time_capture = now
                # 3. Change-based capture.
                if self.check_motion(gray):
                    if self.save_image(frame, "TM"):
                        self.last_trigger_capture = now
                # Short pause to limit CPU usage.
                time.sleep(0.01)
        finally:
            # Release the capture even if an unexpected exception ends the loop.
            if self.cap:
                self.cap.release()
def load_rtsp_list(filepath):
    """Read camera stream URLs from a text file, one URL per line.

    Blank lines and lines starting with ``#`` are skipped. A camera's id is
    the 1-based number of the line its URL is on, so skipped lines leave
    gaps in the ids.

    Args:
        filepath: Path of the list file.

    Returns:
        A list of ``{'id': int, 'url': str}`` dicts in file order; an empty
        list (after logging an error) when the file does not exist.
    """
    cameras = []
    try:
        # utf-8-sig drops a leading byte-order mark that would otherwise be
        # glued to the first URL (or hide a leading '#'). errors="replace"
        # keeps stray non-UTF-8 bytes from aborting the load.
        with open(filepath, 'r', encoding='utf-8-sig', errors='replace') as f:
            for line_no, raw_line in enumerate(f, start=1):
                url = raw_line.strip()
                if url and not url.startswith('#'):
                    cameras.append({'id': line_no, 'url': url})
    except FileNotFoundError:
        logging.error(f"RTSP file not found: {filepath}")
    return cameras
def main():
    """Start one capture thread per configured camera and supervise them.

    Loads the camera list from RTSP_FILE, starts a CameraCapture thread for
    each entry, then polls once per second until interrupted with Ctrl+C, at
    which point every thread is asked to stop and joined.
    """
    ensure_dir(DATA_DIR)
    rtsp_list = load_rtsp_list(RTSP_FILE)
    if not rtsp_list:
        logging.error("No RTSP URLs found.")
        return
    logging.info(f"Loaded {len(rtsp_list)} cameras.")

    threads = []
    for cam_info in rtsp_list:
        t = CameraCapture(cam_info['id'], cam_info['url'])
        t.start()
        threads.append(t)

    # Threads already reported as dead. Without this the same warning would
    # be written every second for as long as the process runs.
    reported_dead = set()
    try:
        while True:
            time.sleep(1)
            # Watch for capture threads that have exited.
            newly_dead = [t for t in threads if not t.is_alive() and t not in reported_dead]
            if newly_dead:
                reported_dead.update(newly_dead)
                logging.warning(f"Found {len(newly_dead)} dead threads. (Should be handled inside run loop)")
    except KeyboardInterrupt:
        logging.info("Stopping all tasks...")
        # Signal every thread first so they wind down concurrently, then wait.
        for t in threads:
            t.running = False
        for t in threads:
            t.join()
        logging.info("All tasks stopped.")
# Run the capture service only when this file is executed as a script.
if __name__ == "__main__":
    main()