"""
视频解码模块 (Video decoding module).

Provides per-source decode threads (RTSP stream, local video file, or
synthetic frames) and a manager that owns their frame queues.
"""

import os
import queue
import threading
import time
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Dict, Tuple

import cv2
import numpy as np

from .utils import setup_logging, RateCounter

# Module-level logger obtained from the project's logging helper in .utils.
logger = setup_logging()
|
|
|
|
|
|
class SourceType(Enum):
    """Kind of video input a ``DecodeThread`` reads from.

    The member values are the lowercase strings accepted as the
    ``source_type`` argument.
    """

    RTSP = "rtsp"            # network RTSP stream
    FILE = "file"            # local video file, played in a loop
    SYNTHETIC = "synthetic"  # randomly generated test frames

@dataclass
class DecodeStats:
    """Decode statistics for a single source (counters plus derived metrics)."""

    # Frames produced by the decode loop and offered to the output queue.
    total_frames: int = 0
    # Frames discarded because the output queue was full.
    dropped_frames: int = 0
    # Exceptions caught by the decode thread's main loop.
    decode_errors: int = 0
    # RTSP errors that led to a reconnect attempt.
    reconnect_count: int = 0
    # Rolling mean of per-frame decode time, in milliseconds.
    avg_decode_time_ms: float = 0.0
    # Most recently measured output frame rate.
    current_fps: float = 0.0

class DecodeThread(threading.Thread):
    """Video decode thread: reads frames from one source and feeds a queue.

    Depending on ``source_type`` the thread generates synthetic frames, plays
    a local video file in a loop, or reads an RTSP stream (switching to the
    synthetic source after repeated connection failures). Every emitted item
    is a ``(frame, timestamp, source_id)`` tuple, where ``timestamp`` is
    ``time.time()`` at enqueue. When the queue is full the oldest item is
    discarded, so consumers always see the freshest frames.
    """

    # Number of recent per-frame decode times kept for the rolling average.
    _DECODE_TIME_WINDOW = 100
    # Consecutive failed RTSP connection attempts before falling back.
    _RTSP_MAX_RETRIES = 3
    # Back-off (seconds) after a failed RTSP connection attempt.
    _RTSP_RETRY_DELAY_S = 5.0
    # Back-off (seconds) after an error that would otherwise be retried at once.
    _ERROR_BACKOFF_S = 1.0

    def __init__(
        self,
        source_id: str,
        source: str,
        frame_queue: queue.Queue,
        target_fps: float,
        source_type: str = "rtsp",
        resolution: Tuple[int, int] = (640, 480),
    ):
        """Create (but do not start) a decode thread.

        Args:
            source_id: Identifier attached to every queued frame and log line.
            source: RTSP URL or video-file path (unused for synthetic sources).
            frame_queue: Destination queue for ``(frame, timestamp, source_id)``.
            target_fps: Maximum output frame rate; must be positive.
            source_type: A ``SourceType`` value: "rtsp", "file" or "synthetic".
            resolution: ``(width, height)`` of generated synthetic frames.

        Raises:
            ValueError: If ``target_fps`` is not positive or ``source_type``
                is not a ``SourceType`` value.
        """
        super().__init__(daemon=True)
        if target_fps <= 0:
            raise ValueError(f"target_fps must be positive, got {target_fps!r}")

        self.source_id = source_id
        self.source = source
        self.frame_queue = frame_queue
        self.target_fps = target_fps
        self.source_type = SourceType(source_type)
        self.resolution = resolution

        self._running = False
        # Set by stop(). Waiting on it (see _wait) makes pacing and back-off
        # sleeps interruptible, so stop() takes effect promptly.
        self._stop_event = threading.Event()
        # Guards self._cap.
        self._lock = threading.Lock()
        # Guards self._decode_times, which get_stats() reads from other threads.
        self._stats_lock = threading.Lock()
        self._cap: Optional[cv2.VideoCapture] = None

        self.stats = DecodeStats()
        # Clamp to >= 1 so a target_fps below 0.5 still yields a usable window.
        self._rate_counter = RateCounter(window_size=max(1, int(target_fps * 2)))
        self._decode_times = deque(maxlen=self._DECODE_TIME_WINDOW)

        self._frame_interval = 1.0 / target_fps
        # time.monotonic() of the last synthetic frame (immune to clock jumps).
        self._last_frame_time = 0.0

    def run(self):
        """Thread body: dispatch to the per-source decode loop until stopped.

        Unexpected exceptions are logged, counted in ``stats.decode_errors``,
        and followed by a short interruptible back-off before retrying.
        """
        self._running = True
        # Honour a stop() that happened before this thread was scheduled;
        # stop() sets the event before clearing _running, so checking the
        # event after setting _running cannot miss it.
        if self._stop_event.is_set():
            self._running = False
        logger.info(f"[{self.source_id}] 解码线程启动,目标帧率: {self.target_fps} FPS")

        while self._running:
            try:
                if self.source_type == SourceType.SYNTHETIC:
                    self._decode_synthetic()
                elif self.source_type == SourceType.FILE:
                    self._decode_file()
                else:
                    self._decode_rtsp()
            except Exception as e:
                logger.exception(f"[{self.source_id}] 解码异常: {e}")
                self.stats.decode_errors += 1
                self._wait(self._ERROR_BACKOFF_S)

        self._cleanup()
        logger.info(f"[{self.source_id}] 解码线程停止")

    def _wait(self, seconds: float) -> bool:
        """Sleep up to ``seconds``; return True early if stop() was requested."""
        if seconds <= 0:
            return self._stop_event.is_set()
        return self._stop_event.wait(seconds)

    def _decode_synthetic(self):
        """Emit generated frames at up to ``target_fps`` until stopped."""
        while self._running:
            elapsed = time.monotonic() - self._last_frame_time
            if elapsed < self._frame_interval:
                if self._wait(self._frame_interval - elapsed):
                    break

            start_time = time.perf_counter()
            frame = self._generate_synthetic_frame()
            self._emit(frame, (time.perf_counter() - start_time) * 1000)
            self._last_frame_time = time.monotonic()

    def _generate_synthetic_frame(self) -> np.ndarray:
        """Generate a random-noise frame containing 0-4 filled rectangles.

        ``self.resolution`` is read as ``(width, height)``, the OpenCV
        convention, so the result has shape ``(height, width, 3)`` and dtype
        uint8.
        """
        w, h = self.resolution
        frame = np.random.randint(50, 200, (h, w, 3), dtype=np.uint8)

        for _ in range(np.random.randint(0, 5)):
            # max(1, ...) keeps randint's range non-empty for tiny frames;
            # rectangles extending past the border are clipped by OpenCV.
            x1 = np.random.randint(0, max(1, w - 50))
            y1 = np.random.randint(0, max(1, h - 100))
            x2 = x1 + np.random.randint(30, 80)
            y2 = y1 + np.random.randint(60, 150)
            color = tuple(np.random.randint(0, 255, 3).tolist())
            cv2.rectangle(frame, (x1, y1), (x2, y2), color, -1)

        return frame

    def _decode_file(self):
        """Play a local video file in a loop at up to ``target_fps``.

        When the file's FPS exceeds the target, only every
        ``int(source_fps / target_fps)``-th frame is emitted. If the file is
        missing or cannot be opened, or yields no frame even right after a
        rewind, the error is logged and the method returns after a short
        back-off, so ``run`` retries without spinning.
        """
        if not os.path.exists(self.source):
            logger.error(f"[{self.source_id}] 视频文件不存在: {self.source}")
            self._wait(self._ERROR_BACKOFF_S)
            return

        cap = self._open_capture(self.source)
        try:
            if not cap.isOpened():
                logger.error(f"[{self.source_id}] 无法打开视频文件: {self.source}")
                self._wait(self._ERROR_BACKOFF_S)
                return

            source_fps = cap.get(cv2.CAP_PROP_FPS)
            # Some containers report 0 or NaN; "> 0" is False for both.
            frame_skip = max(1, int(source_fps / self.target_fps)) if source_fps > 0 else 1
            frame_count = 0
            frames_since_rewind = 0

            while self._running:
                start_time = time.perf_counter()

                ret, frame = cap.read()
                if not ret:
                    if frames_since_rewind == 0:
                        # Nothing readable even from the start: give up instead
                        # of rewinding forever on an empty/corrupt file.
                        logger.error(f"[{self.source_id}] 视频文件无可读帧: {self.source}")
                        self._wait(self._ERROR_BACKOFF_S)
                        return
                    cap.set(cv2.CAP_PROP_POS_FRAMES, 0)  # loop playback
                    frames_since_rewind = 0
                    continue

                frames_since_rewind += 1
                frame_count += 1
                if frame_count % frame_skip != 0:
                    continue

                self._emit(frame, (time.perf_counter() - start_time) * 1000)

                sleep_time = self._frame_interval - (time.perf_counter() - start_time)
                if sleep_time > 0 and self._wait(sleep_time):
                    break
        finally:
            self._release_capture()

    def _decode_rtsp(self):
        """Read an RTSP stream, reconnecting on errors.

        A read failure closes the connection and reconnects immediately. A
        connection error counts towards ``_RTSP_MAX_RETRIES`` consecutive
        failures (reset on each successful connect) and is followed by an
        interruptible back-off. Once the limit is reached, the thread switches
        to the synthetic source.
        """
        retry_count = 0

        while self._running and retry_count < self._RTSP_MAX_RETRIES:
            failed = False
            try:
                cap = self._open_capture(self.source, cv2.CAP_FFMPEG)
                if not cap.isOpened():
                    raise RuntimeError(f"无法连接 RTSP: {self.source}")

                # Request a 1-frame internal buffer to limit latency (some
                # backends may ignore this property).
                cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
                logger.info(f"[{self.source_id}] RTSP 连接成功")
                retry_count = 0

                self._read_stream(cap)
            except Exception as e:
                logger.error(f"[{self.source_id}] RTSP 错误: {e}")
                retry_count += 1
                self.stats.reconnect_count += 1
                failed = True
            finally:
                self._release_capture()

            # Back off only after a failure, and only once the capture is released.
            if failed:
                self._wait(self._RTSP_RETRY_DELAY_S)

        if retry_count >= self._RTSP_MAX_RETRIES:
            logger.error(f"[{self.source_id}] RTSP 重连失败,切换到合成源")
            self.source_type = SourceType.SYNTHETIC

    def _read_stream(self, cap):
        """Emit frames from an opened capture until a read fails or stop()."""
        while self._running:
            start_time = time.perf_counter()

            ret, frame = cap.read()
            if not ret:
                logger.warning(f"[{self.source_id}] RTSP 读取失败")
                return

            self._emit(frame, (time.perf_counter() - start_time) * 1000)

            sleep_time = self._frame_interval - (time.perf_counter() - start_time)
            if sleep_time > 0 and self._wait(sleep_time):
                return

    def _emit(self, frame: np.ndarray, decode_ms: float):
        """Record one frame's decode time, queue it, and update the counters."""
        with self._stats_lock:
            self._decode_times.append(decode_ms)  # deque(maxlen) drops the oldest
        self._put_frame(frame)
        self.stats.total_frames += 1
        self._rate_counter.tick()

    def _put_frame(self, frame: np.ndarray):
        """Queue ``(frame, timestamp, source_id)``, evicting the oldest item if full.

        Every evicted frame, and every frame that still cannot be queued,
        increments ``stats.dropped_frames``.
        """
        timestamp = time.time()

        try:
            if self.frame_queue.full():
                try:
                    self.frame_queue.get_nowait()
                    self.stats.dropped_frames += 1
                except queue.Empty:
                    pass  # a consumer emptied it in the meantime

            self.frame_queue.put_nowait((frame, timestamp, self.source_id))
        except queue.Full:
            self.stats.dropped_frames += 1

    def stop(self):
        """Request the thread to stop; wakes it from any pacing/back-off wait."""
        # Set the event first: run() relies on this order (see its comment).
        self._stop_event.set()
        self._running = False

    def _open_capture(self, *args):
        """Release any held capture, open ``cv2.VideoCapture(*args)`` and store it."""
        self._release_capture()
        cap = cv2.VideoCapture(*args)
        with self._lock:
            self._cap = cap
        return cap

    def _release_capture(self):
        """Release and forget the current capture, if any."""
        with self._lock:
            cap, self._cap = self._cap, None
        if cap is not None:
            cap.release()

    def _cleanup(self):
        """Release resources held by the thread once its loop has ended."""
        self._release_capture()

    def get_stats(self) -> DecodeStats:
        """Refresh the derived fields and return the thread's ``DecodeStats``.

        The returned object is the live instance (not a copy); the decode
        loop keeps updating its counters.
        """
        self.stats.current_fps = self._rate_counter.get_rate()
        with self._stats_lock:
            samples = list(self._decode_times)
        if samples:
            self.stats.avg_decode_time_ms = sum(samples) / len(samples)
        return self.stats

class FrameQueueManager:
    """Frame queue manager: one bounded queue and one ``DecodeThread`` per source."""

    def __init__(self, queue_size: int = 2):
        """
        Args:
            queue_size: ``maxsize`` of every per-source queue. A small value
                keeps latency low, since producers evict the oldest frame
                when a queue is full.
        """
        self.queue_size = queue_size
        self.queues: Dict[str, queue.Queue] = {}
        self.decode_threads: Dict[str, DecodeThread] = {}

    def add_source(
        self,
        source_id: str,
        source: str,
        target_fps: float,
        source_type: str = "synthetic",
        resolution: Tuple[int, int] = (640, 480),
    ) -> queue.Queue:
        """Register a source and return the queue its frames will be put into.

        The decode thread is created but not started; call ``start_all``.
        Note that the default ``source_type`` here is "synthetic".

        If ``source_id`` is already registered, the previous thread is stopped
        (and joined if it was started) before being replaced, instead of being
        left running unobserved.

        Raises:
            ValueError: Propagated from ``DecodeThread`` for invalid arguments;
                in that case nothing is registered.
        """
        frame_queue = queue.Queue(maxsize=self.queue_size)

        # Build the thread before touching the registry, so invalid arguments
        # leave no orphan queue behind.
        decode_thread = DecodeThread(
            source_id=source_id,
            source=source,
            frame_queue=frame_queue,
            target_fps=target_fps,
            source_type=source_type,
            resolution=resolution,
        )

        previous = self.decode_threads.get(source_id)
        if previous is not None:
            previous.stop()
            if previous.ident is not None:  # join() raises on never-started threads
                previous.join(timeout=3.0)

        self.queues[source_id] = frame_queue
        self.decode_threads[source_id] = decode_thread
        return frame_queue

    def start_all(self):
        """Start every registered thread that has not been started yet.

        Safe to call repeatedly, e.g. after adding sources to a running
        manager (``Thread.start`` may only be called once per thread).
        """
        for thread in self.decode_threads.values():
            if thread.ident is None:  # ident stays None until start() is called
                thread.start()

    def stop_all(self, timeout: float = 3.0):
        """Signal every thread to stop, then join each started one.

        Args:
            timeout: Maximum seconds to wait for each individual thread.
        """
        # Signal all first so the threads shut down concurrently.
        for thread in self.decode_threads.values():
            thread.stop()

        for source_id, thread in self.decode_threads.items():
            if thread.ident is None:
                continue  # never started: nothing to join
            thread.join(timeout=timeout)
            if thread.is_alive():
                logger.warning(f"[{source_id}] 解码线程未在 {timeout}s 内退出")

    def get_all_stats(self) -> Dict[str, DecodeStats]:
        """Return refreshed ``DecodeStats`` keyed by source id."""
        return {source_id: thread.get_stats() for source_id, thread in self.decode_threads.items()}

    def get_total_dropped_frames(self) -> int:
        """Return the sum of ``dropped_frames`` across all sources."""
        return sum(thread.stats.dropped_frames for thread in self.decode_threads.values())
|