Files
Test_AI/benchmark/decode_thread.py

300 lines
9.6 KiB
Python
Raw Normal View History

2026-01-20 10:54:30 +08:00
"""
Video decoding module.
"""
import os
import time
import threading
import queue
from typing import Optional, Dict, Tuple
from dataclasses import dataclass
from enum import Enum
import cv2
import numpy as np
from .utils import setup_logging, RateCounter
# Module-level logger used by the decode classes in this file.
logger = setup_logging()
class SourceType(Enum):
    """Kinds of video source a decode thread can read from.

    The member values are the strings accepted by ``SourceType(value)``,
    which is how ``DecodeThread`` converts its ``source_type`` argument.
    """

    RTSP = "rtsp"            # network stream opened through OpenCV/FFmpeg
    FILE = "file"            # local video file, looped when it reaches the end
    SYNTHETIC = "synthetic"  # randomly generated frames, no external input
@dataclass
class DecodeStats:
    """Decode statistics for a single source.

    All fields start at zero. The counters are incremented by the decode
    thread; the two derived figures are refreshed by
    ``DecodeThread.get_stats``.
    """

    total_frames: int = 0            # frames decoded and handed to the queue
    dropped_frames: int = 0          # frames discarded because the queue was full
    decode_errors: int = 0           # exceptions caught by the thread's main loop
    reconnect_count: int = 0         # RTSP reconnect attempts
    avg_decode_time_ms: float = 0.0  # mean of the recent decode-time samples
    current_fps: float = 0.0         # rate reported by the thread's rate counter
class DecodeThread(threading.Thread):
    """Daemon thread that decodes one video source into a bounded frame queue.

    Depending on ``source_type`` the thread reads an RTSP stream, loops a
    local video file, or generates random synthetic frames. Every produced
    frame is pushed to ``frame_queue`` as ``(frame, timestamp, source_id)``;
    when the queue is full the oldest entry is discarded so that consumers
    always see the most recent frame.

    Call ``stop()`` to request shutdown. All internal delays are
    interruptible, so the thread reacts to ``stop()`` without waiting for a
    pending back-off to elapse.
    """

    # Number of recent decode-time samples kept for the running average.
    _MAX_DECODE_SAMPLES = 100
    # Delay (seconds) before retrying after a generic decode failure.
    _ERROR_BACKOFF_S = 1.0
    # Delay (seconds) before retrying a failed RTSP connection.
    _RTSP_RETRY_DELAY_S = 5.0
    # Consecutive failed RTSP connections tolerated before falling back.
    _RTSP_MAX_RETRIES = 3

    def __init__(
        self,
        source_id: str,
        source: str,
        frame_queue: queue.Queue,
        target_fps: float,
        source_type: str = "rtsp",
        resolution: Tuple[int, int] = (640, 480)
    ):
        """Configure the thread; decoding starts when ``start()`` is called.

        Args:
            source_id: Label attached to every queued frame and log line.
            source: RTSP URL or file path (unused for synthetic sources).
            frame_queue: Queue that receives ``(frame, timestamp, source_id)``.
            target_fps: Desired output frame rate; must be positive.
            source_type: One of the ``SourceType`` values.
            resolution: Size of synthetic frames (see the review note in
                ``_generate_synthetic_frame`` about the axis order).

        Raises:
            ValueError: If ``target_fps`` is not positive, or ``source_type``
                is not a valid ``SourceType`` value.
        """
        super().__init__(daemon=True)
        # Reject zero/negative/NaN rates up front instead of failing later
        # with ZeroDivisionError or a negative frame interval.
        if not target_fps > 0:
            raise ValueError(f"target_fps 必须为正数: {target_fps!r}")
        self.source_id = source_id
        self.source = source
        self.frame_queue = frame_queue
        self.target_fps = target_fps
        self.source_type = SourceType(source_type)
        self.resolution = resolution
        self._running = False
        # Set by stop(); lets back-off and pacing delays end early.
        self._stop_requested = threading.Event()
        self._lock = threading.Lock()
        self._cap: Optional[cv2.VideoCapture] = None
        self.stats = DecodeStats()
        # Keep the window at least one sample wide for rates below 0.5 FPS.
        self._rate_counter = RateCounter(window_size=max(1, int(target_fps * 2)))
        self._decode_times = []
        self._frame_interval = 1.0 / target_fps
        self._last_frame_time = 0.0

    def run(self):
        """Thread body: keep decoding until ``stop()`` is requested."""
        # Honour a stop() that arrived before the thread body began.
        self._running = not self._stop_requested.is_set()
        logger.info(f"[{self.source_id}] 解码线程启动,目标帧率: {self.target_fps} FPS")
        while self._running:
            try:
                if self.source_type == SourceType.SYNTHETIC:
                    self._decode_synthetic()
                elif self.source_type == SourceType.FILE:
                    self._decode_file()
                else:
                    self._decode_rtsp()
            except Exception as e:
                logger.error(f"[{self.source_id}] 解码异常: {e}")
                self.stats.decode_errors += 1
                self._wait(self._ERROR_BACKOFF_S)
        self._cleanup()
        logger.info(f"[{self.source_id}] 解码线程停止")

    def _wait(self, seconds: float) -> None:
        """Sleep for up to ``seconds``, returning early once stop() is called."""
        if seconds > 0:
            self._stop_requested.wait(seconds)

    def _record_decode_time(self, start_time: float) -> None:
        """Store the milliseconds elapsed since ``start_time`` (perf_counter)."""
        self._decode_times.append((time.perf_counter() - start_time) * 1000)
        if len(self._decode_times) > self._MAX_DECODE_SAMPLES:
            self._decode_times.pop(0)

    def _decode_synthetic(self):
        """Produce synthetic frames paced to the target frame rate."""
        while self._running:
            # Monotonic clock: wall-clock adjustments must not stall pacing.
            elapsed = time.monotonic() - self._last_frame_time
            if elapsed < self._frame_interval:
                self._wait(self._frame_interval - elapsed)
                if not self._running:
                    break
            start_time = time.perf_counter()
            frame = self._generate_synthetic_frame()
            self._record_decode_time(start_time)
            self._put_frame(frame)
            self._last_frame_time = time.monotonic()
            self.stats.total_frames += 1
            self._rate_counter.tick()

    def _generate_synthetic_frame(self) -> np.ndarray:
        """Return a random-noise frame with up to four filled rectangles."""
        # NOTE(review): resolution is unpacked as (rows, cols), so the default
        # (640, 480) yields a 640-row by 480-column image. If resolution is
        # meant to be (width, height), this unpacking is swapped - confirm.
        h, w = self.resolution
        frame = np.random.randint(50, 200, (h, w, 3), dtype=np.uint8)
        num_objects = np.random.randint(0, 5)
        for _ in range(num_objects):
            # Clamp the upper bounds so small resolutions do not make
            # randint raise ValueError (low >= high).
            x1 = np.random.randint(0, max(1, w - 50))
            y1 = np.random.randint(0, max(1, h - 100))
            x2 = x1 + np.random.randint(30, 80)
            y2 = y1 + np.random.randint(60, 150)
            color = tuple(np.random.randint(0, 255, 3).tolist())
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, -1)
        return frame

    def _decode_file(self):
        """Decode a local video file in a loop, rewinding at end of file.

        Raises:
            RuntimeError: If no frame can be read even right after rewinding.
        """
        if not os.path.exists(self.source):
            logger.error(f"[{self.source_id}] 视频文件不存在: {self.source}")
            # run() calls this method again immediately; back off so a
            # missing file does not become a busy loop.
            self._wait(self._ERROR_BACKOFF_S)
            return
        cap = cv2.VideoCapture(self.source)
        self._cap = cap
        try:
            if not cap.isOpened():
                logger.error(f"[{self.source_id}] 无法打开视频文件: {self.source}")
                self._wait(self._ERROR_BACKOFF_S)
                return
            source_fps = cap.get(cv2.CAP_PROP_FPS)
            # Unknown (0 or NaN) source rate: emit every frame.
            if source_fps > 0:
                frame_skip = max(1, int(source_fps / self.target_fps))
            else:
                frame_skip = 1
            frame_count = 0
            rewound = False  # True when the previous read failed and we rewound
            while self._running:
                start_time = time.perf_counter()
                ret, frame = cap.read()
                if not ret:
                    if rewound:
                        # Failing again right after a rewind means the file
                        # has no decodable frames; do not spin forever.
                        raise RuntimeError(f"视频文件没有可解码的帧: {self.source}")
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                    rewound = True
                    continue
                rewound = False
                frame_count += 1
                if frame_count % frame_skip != 0:
                    continue
                self._record_decode_time(start_time)
                self._put_frame(frame)
                self.stats.total_frames += 1
                self._rate_counter.tick()
                elapsed = time.perf_counter() - start_time
                self._wait(self._frame_interval - elapsed)
        finally:
            # Release the capture on every exit path, including exceptions.
            self._cleanup()

    def _decode_rtsp(self):
        """Decode an RTSP stream, reconnecting when the connection fails.

        After ``_RTSP_MAX_RETRIES`` consecutive failed connections the
        thread switches itself to the synthetic source.
        """
        retry_count = 0
        while self._running and retry_count < self._RTSP_MAX_RETRIES:
            delay = 0.0
            try:
                self._cap = cv2.VideoCapture(self.source, cv2.CAP_FFMPEG)
                if not self._cap.isOpened():
                    raise RuntimeError(f"无法连接 RTSP: {self.source}")
                self._cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
                logger.info(f"[{self.source_id}] RTSP 连接成功")
                retry_count = 0
                while self._running:
                    start_time = time.perf_counter()
                    ret, frame = self._cap.read()
                    if not ret:
                        logger.warning(f"[{self.source_id}] RTSP 读取失败")
                        break
                    self._record_decode_time(start_time)
                    self._put_frame(frame)
                    self.stats.total_frames += 1
                    self._rate_counter.tick()
                    elapsed = time.perf_counter() - start_time
                    self._wait(self._frame_interval - elapsed)
            except Exception as e:
                logger.error(f"[{self.source_id}] RTSP 错误: {e}")
                retry_count += 1
                self.stats.reconnect_count += 1
                delay = self._RTSP_RETRY_DELAY_S
            else:
                if self._running:
                    # The stream dropped after a successful connect: count
                    # the reconnect and pause briefly instead of hammering
                    # the server in a tight loop.
                    self.stats.reconnect_count += 1
                    delay = self._ERROR_BACKOFF_S
            finally:
                self._cleanup()
            # Back off only after the capture has been released.
            self._wait(delay)
        if retry_count >= self._RTSP_MAX_RETRIES:
            logger.error(f"[{self.source_id}] RTSP 重连失败,切换到合成源")
            self.source_type = SourceType.SYNTHETIC

    def _put_frame(self, frame: np.ndarray):
        """Queue ``(frame, timestamp, source_id)``, dropping the oldest if full."""
        timestamp = time.time()
        try:
            if self.frame_queue.full():
                try:
                    self.frame_queue.get_nowait()
                    self.stats.dropped_frames += 1
                except queue.Empty:
                    # A consumer emptied the queue in the meantime.
                    pass
            self.frame_queue.put_nowait((frame, timestamp, self.source_id))
        except queue.Full:
            # Another producer refilled the slot; count this frame as dropped.
            self.stats.dropped_frames += 1

    def stop(self):
        """Ask the thread to finish; any pending delay is cut short."""
        self._running = False
        self._stop_requested.set()

    def _cleanup(self):
        """Release the current capture object, if any."""
        with self._lock:
            if self._cap is not None:
                self._cap.release()
                self._cap = None

    def get_stats(self) -> "DecodeStats":
        """Refresh the derived fields and return the stats object."""
        self.stats.current_fps = self._rate_counter.get_rate()
        # Snapshot first: the decode thread may be appending concurrently.
        samples = list(self._decode_times)
        if samples:
            self.stats.avg_decode_time_ms = sum(samples) / len(samples)
        return self.stats
class FrameQueueManager:
    """Creates and controls one frame queue and one decode thread per source."""

    def __init__(self, queue_size: int = 2):
        """Args:
            queue_size: ``maxsize`` of each per-source frame queue.
        """
        self.queue_size = queue_size
        self.queues: Dict[str, queue.Queue] = {}
        self.decode_threads: Dict[str, DecodeThread] = {}

    def add_source(
        self,
        source_id: str,
        source: str,
        target_fps: float,
        source_type: str = "synthetic",
        resolution: Tuple[int, int] = (640, 480)
    ) -> queue.Queue:
        """Register a source and return the queue its frames will arrive on.

        The decode thread is created but not started; call ``start_all``.
        Registering an existing ``source_id`` replaces the previous entry;
        the replaced thread is asked to stop so it is not left running
        against an orphaned queue.
        """
        previous = self.decode_threads.get(source_id)
        if previous is not None:
            previous.stop()
        frame_queue = queue.Queue(maxsize=self.queue_size)
        self.queues[source_id] = frame_queue
        decode_thread = DecodeThread(
            source_id=source_id,
            source=source,
            frame_queue=frame_queue,
            target_fps=target_fps,
            source_type=source_type,
            resolution=resolution
        )
        self.decode_threads[source_id] = decode_thread
        return frame_queue

    def start_all(self):
        """Start every registered thread that has not been started yet.

        Safe to call repeatedly, for example after adding more sources.
        """
        for thread in self.decode_threads.values():
            # ident stays None until start(); starting a thread twice
            # raises RuntimeError.
            if thread.ident is None:
                thread.start()

    def stop_all(self):
        """Signal every thread to stop, then wait up to 3 s for each one."""
        for thread in self.decode_threads.values():
            thread.stop()
        for thread in self.decode_threads.values():
            # Joining a thread that was never started raises RuntimeError.
            if thread.is_alive():
                thread.join(timeout=3.0)

    def get_all_stats(self) -> Dict[str, "DecodeStats"]:
        """Return the refreshed stats of every source, keyed by source id."""
        return {source_id: thread.get_stats() for source_id, thread in self.decode_threads.items()}

    def get_total_dropped_frames(self) -> int:
        """Return the number of dropped frames summed over all sources."""
        return sum(thread.stats.dropped_frames for thread in self.decode_threads.values())