GPU测试
This commit is contained in:
299
benchmark/decode_thread.py
Normal file
299
benchmark/decode_thread.py
Normal file
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
视频解码模块
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import queue
|
||||
from typing import Optional, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from .utils import setup_logging, RateCounter
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
class SourceType(Enum):
|
||||
RTSP = "rtsp"
|
||||
FILE = "file"
|
||||
SYNTHETIC = "synthetic"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodeStats:
|
||||
"""解码统计信息"""
|
||||
total_frames: int = 0
|
||||
dropped_frames: int = 0
|
||||
decode_errors: int = 0
|
||||
reconnect_count: int = 0
|
||||
avg_decode_time_ms: float = 0.0
|
||||
current_fps: float = 0.0
|
||||
|
||||
|
||||
class DecodeThread(threading.Thread):
|
||||
"""视频解码线程"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source_id: str,
|
||||
source: str,
|
||||
frame_queue: queue.Queue,
|
||||
target_fps: float,
|
||||
source_type: str = "rtsp",
|
||||
resolution: Tuple[int, int] = (640, 480)
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
|
||||
self.source_id = source_id
|
||||
self.source = source
|
||||
self.frame_queue = frame_queue
|
||||
self.target_fps = target_fps
|
||||
self.source_type = SourceType(source_type)
|
||||
self.resolution = resolution
|
||||
|
||||
self._running = False
|
||||
self._lock = threading.Lock()
|
||||
self._cap: Optional[cv2.VideoCapture] = None
|
||||
|
||||
self.stats = DecodeStats()
|
||||
self._rate_counter = RateCounter(window_size=int(target_fps * 2))
|
||||
self._decode_times = []
|
||||
|
||||
self._frame_interval = 1.0 / target_fps
|
||||
self._last_frame_time = 0.0
|
||||
|
||||
def run(self):
|
||||
self._running = True
|
||||
logger.info(f"[{self.source_id}] 解码线程启动,目标帧率: {self.target_fps} FPS")
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
if self.source_type == SourceType.SYNTHETIC:
|
||||
self._decode_synthetic()
|
||||
elif self.source_type == SourceType.FILE:
|
||||
self._decode_file()
|
||||
else:
|
||||
self._decode_rtsp()
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.source_id}] 解码异常: {e}")
|
||||
self.stats.decode_errors += 1
|
||||
time.sleep(1.0)
|
||||
|
||||
self._cleanup()
|
||||
logger.info(f"[{self.source_id}] 解码线程停止")
|
||||
|
||||
def _decode_synthetic(self):
|
||||
"""合成图像源解码"""
|
||||
while self._running:
|
||||
current_time = time.time()
|
||||
elapsed = current_time - self._last_frame_time
|
||||
if elapsed < self._frame_interval:
|
||||
time.sleep(self._frame_interval - elapsed)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
frame = self._generate_synthetic_frame()
|
||||
|
||||
decode_time = (time.perf_counter() - start_time) * 1000
|
||||
self._decode_times.append(decode_time)
|
||||
if len(self._decode_times) > 100:
|
||||
self._decode_times.pop(0)
|
||||
|
||||
self._put_frame(frame)
|
||||
|
||||
self._last_frame_time = time.time()
|
||||
self.stats.total_frames += 1
|
||||
self._rate_counter.tick()
|
||||
|
||||
def _generate_synthetic_frame(self) -> np.ndarray:
|
||||
"""生成合成帧"""
|
||||
h, w = self.resolution
|
||||
frame = np.random.randint(50, 200, (h, w, 3), dtype=np.uint8)
|
||||
|
||||
num_objects = np.random.randint(0, 5)
|
||||
for _ in range(num_objects):
|
||||
x1 = np.random.randint(0, w - 50)
|
||||
y1 = np.random.randint(0, h - 100)
|
||||
x2 = x1 + np.random.randint(30, 80)
|
||||
y2 = y1 + np.random.randint(60, 150)
|
||||
color = tuple(np.random.randint(0, 255, 3).tolist())
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), color, -1)
|
||||
|
||||
return frame
|
||||
|
||||
def _decode_file(self):
|
||||
"""本地视频文件解码"""
|
||||
if not os.path.exists(self.source):
|
||||
logger.error(f"[{self.source_id}] 视频文件不存在: {self.source}")
|
||||
return
|
||||
|
||||
self._cap = cv2.VideoCapture(self.source)
|
||||
if not self._cap.isOpened():
|
||||
logger.error(f"[{self.source_id}] 无法打开视频文件: {self.source}")
|
||||
return
|
||||
|
||||
source_fps = self._cap.get(cv2.CAP_PROP_FPS)
|
||||
frame_skip = max(1, int(source_fps / self.target_fps))
|
||||
frame_count = 0
|
||||
|
||||
while self._running:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
ret, frame = self._cap.read()
|
||||
if not ret:
|
||||
self._cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||||
continue
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % frame_skip != 0:
|
||||
continue
|
||||
|
||||
decode_time = (time.perf_counter() - start_time) * 1000
|
||||
self._decode_times.append(decode_time)
|
||||
if len(self._decode_times) > 100:
|
||||
self._decode_times.pop(0)
|
||||
|
||||
self._put_frame(frame)
|
||||
|
||||
self.stats.total_frames += 1
|
||||
self._rate_counter.tick()
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
sleep_time = self._frame_interval - elapsed
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def _decode_rtsp(self):
|
||||
"""RTSP 流解码"""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
self._cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
|
||||
if not self._cap.isOpened():
|
||||
raise RuntimeError(f"无法连接 RTSP: {self.source}")
|
||||
|
||||
self._cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
|
||||
logger.info(f"[{self.source_id}] RTSP 连接成功")
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
ret, frame = self._cap.read()
|
||||
if not ret:
|
||||
logger.warning(f"[{self.source_id}] RTSP 读取失败")
|
||||
break
|
||||
|
||||
decode_time = (time.perf_counter() - start_time) * 1000
|
||||
self._decode_times.append(decode_time)
|
||||
if len(self._decode_times) > 100:
|
||||
self._decode_times.pop(0)
|
||||
|
||||
self._put_frame(frame)
|
||||
|
||||
self.stats.total_frames += 1
|
||||
self._rate_counter.tick()
|
||||
|
||||
elapsed = time.perf_counter() - start_time
|
||||
sleep_time = self._frame_interval - elapsed
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.source_id}] RTSP 错误: {e}")
|
||||
retry_count += 1
|
||||
self.stats.reconnect_count += 1
|
||||
time.sleep(5.0)
|
||||
finally:
|
||||
if self._cap:
|
||||
self._cap.release()
|
||||
self._cap = None
|
||||
|
||||
if retry_count >= max_retries:
|
||||
logger.error(f"[{self.source_id}] RTSP 重连失败,切换到合成源")
|
||||
self.source_type = SourceType.SYNTHETIC
|
||||
|
||||
def _put_frame(self, frame: np.ndarray):
|
||||
"""将帧放入队列"""
|
||||
timestamp = time.time()
|
||||
|
||||
try:
|
||||
if self.frame_queue.full():
|
||||
try:
|
||||
self.frame_queue.get_nowait()
|
||||
self.stats.dropped_frames += 1
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
self.frame_queue.put_nowait((frame, timestamp, self.source_id))
|
||||
except queue.Full:
|
||||
self.stats.dropped_frames += 1
|
||||
|
||||
def stop(self):
|
||||
self._running = False
|
||||
|
||||
def _cleanup(self):
|
||||
with self._lock:
|
||||
if self._cap:
|
||||
self._cap.release()
|
||||
self._cap = None
|
||||
|
||||
def get_stats(self) -> DecodeStats:
|
||||
self.stats.current_fps = self._rate_counter.get_rate()
|
||||
if self._decode_times:
|
||||
self.stats.avg_decode_time_ms = sum(self._decode_times) / len(self._decode_times)
|
||||
return self.stats
|
||||
|
||||
|
||||
class FrameQueueManager:
|
||||
"""帧队列管理器"""
|
||||
|
||||
def __init__(self, queue_size: int = 2):
|
||||
self.queue_size = queue_size
|
||||
self.queues: Dict[str, queue.Queue] = {}
|
||||
self.decode_threads: Dict[str, DecodeThread] = {}
|
||||
|
||||
def add_source(
|
||||
self,
|
||||
source_id: str,
|
||||
source: str,
|
||||
target_fps: float,
|
||||
source_type: str = "synthetic",
|
||||
resolution: Tuple[int, int] = (640, 480)
|
||||
) -> queue.Queue:
|
||||
frame_queue = queue.Queue(maxsize=self.queue_size)
|
||||
self.queues[source_id] = frame_queue
|
||||
|
||||
decode_thread = DecodeThread(
|
||||
source_id=source_id,
|
||||
source=source,
|
||||
frame_queue=frame_queue,
|
||||
target_fps=target_fps,
|
||||
source_type=source_type,
|
||||
resolution=resolution
|
||||
)
|
||||
self.decode_threads[source_id] = decode_thread
|
||||
|
||||
return frame_queue
|
||||
|
||||
def start_all(self):
|
||||
for thread in self.decode_threads.values():
|
||||
thread.start()
|
||||
|
||||
def stop_all(self):
|
||||
for thread in self.decode_threads.values():
|
||||
thread.stop()
|
||||
for thread in self.decode_threads.values():
|
||||
thread.join(timeout=3.0)
|
||||
|
||||
def get_all_stats(self) -> Dict[str, DecodeStats]:
|
||||
return {source_id: thread.get_stats() for source_id, thread in self.decode_threads.items()}
|
||||
|
||||
def get_total_dropped_frames(self) -> int:
|
||||
return sum(thread.stats.dropped_frames for thread in self.decode_threads.values())
|
||||
Reference in New Issue
Block a user