174 lines
5.5 KiB
Python
174 lines
5.5 KiB
Python
|
|
"""
|
||
|
|
Batch 组装和 GPU 预处理模块
|
||
|
|
"""
|
||
|
|
|
||
|
|
import time
|
||
|
|
import queue
|
||
|
|
from typing import Dict, List, Optional, Tuple, Any
|
||
|
|
from dataclasses import dataclass
|
||
|
|
|
||
|
|
import cv2
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from .utils import setup_logging
|
||
|
|
|
||
|
|
logger = setup_logging()
|
||
|
|
|
||
|
|
|
||
|
|
class GPUPreprocessor:
    """GPU image preprocessor.

    Converts BGR frames into model-ready CHW float32 tensors: resize to
    ``target_size``, BGR→RGB, scale to [0, 1], transpose HWC→CHW.
    Uses OpenCV CUDA when available and falls back to CPU otherwise.
    """

    def __init__(self, target_size: Tuple[int, int], device_id: int = 0, use_gpu: bool = True):
        """
        Args:
            target_size: (height, width) of the model input.
            device_id: CUDA device index. NOTE(review): stored but never
                passed to OpenCV here (no cv2.cuda.setDevice call) — confirm
                whether the caller selects the device.
            use_gpu: set False to force CPU preprocessing.
        """
        self.target_size = target_size
        self.device_id = device_id
        self.use_gpu = use_gpu
        self._init_gpu()

    def _init_gpu(self) -> None:
        """Probe OpenCV CUDA support and set ``self._gpu_available``."""
        self._gpu_available = False

        if not self.use_gpu:
            logger.info("GPU 预处理已禁用,使用 CPU")
            return

        try:
            # Query the device count once (the original queried it twice:
            # once for the check and again inside the log message).
            device_count = cv2.cuda.getCudaEnabledDeviceCount()
            if device_count > 0:
                self._gpu_available = True
                logger.info(f"OpenCV CUDA 可用,设备数: {device_count}")
            else:
                logger.warning("OpenCV CUDA 不可用,使用 CPU 预处理")
        except Exception as e:
            # cv2 may be built without the cuda module; any probe failure
            # simply means CPU preprocessing.
            logger.warning(f"GPU 初始化失败: {e},使用 CPU 预处理")

    def preprocess_single(self, frame: np.ndarray) -> np.ndarray:
        """Preprocess one BGR frame.

        Returns:
            float32 array of shape (3, h, w) with values in [0, 1].
        """
        h, w = self.target_size

        if self._gpu_available:
            return self._preprocess_gpu(frame, h, w)
        return self._preprocess_cpu(frame, h, w)

    def _preprocess_cpu(self, frame: np.ndarray, h: int, w: int) -> np.ndarray:
        """CPU path: resize → BGR2RGB → [0, 1] float32 → CHW."""
        resized = cv2.resize(frame, (w, h), interpolation=cv2.INTER_LINEAR)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        normalized = rgb.astype(np.float32) / 255.0
        return np.transpose(normalized, (2, 0, 1))

    def _preprocess_gpu(self, frame: np.ndarray, h: int, w: int) -> np.ndarray:
        """GPU path (resize + color convert on device); CPU fallback on error."""
        try:
            # cv2.cuda.GpuMat is the current name for the legacy
            # cv2.cuda_GpuMat alias used before.
            gpu_frame = cv2.cuda.GpuMat()
            gpu_frame.upload(frame)
            gpu_resized = cv2.cuda.resize(gpu_frame, (w, h), interpolation=cv2.INTER_LINEAR)
            gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
            rgb = gpu_rgb.download()
            normalized = rgb.astype(np.float32) / 255.0
            return np.transpose(normalized, (2, 0, 1))
        except Exception as e:
            # BUGFIX: permanently fall back to CPU after the first runtime
            # failure. The original retried the GPU path (paying exception
            # overhead and logging a warning) for every subsequent frame
            # when CUDA was broken.
            self._gpu_available = False
            logger.warning(f"GPU 预处理失败: {e},回退到 CPU")
            return self._preprocess_cpu(frame, h, w)

    def preprocess_batch(self, frames: List[np.ndarray], stream: Any = None) -> np.ndarray:
        """Preprocess a list of frames into one (N, 3, h, w) batch.

        Args:
            frames: BGR frames; an empty list yields an empty array.
            stream: reserved for a CUDA stream; currently unused.
        """
        if not frames:
            return np.array([])
        processed = [self.preprocess_single(frame) for frame in frames]
        return np.stack(processed, axis=0)
@dataclass
class FrameInfo:
    """Metadata bundle for a single captured frame.

    Carries the raw frame together with its capture timestamp and the id
    of the source it came from, so results can be routed back to the
    right stream after batched inference.
    """
    # Raw frame as taken from the source queue (assumed BGR — the cv2
    # convention GPUPreprocessor in this module relies on).
    frame: np.ndarray
    # Capture timestamp supplied by the producer — presumably epoch
    # seconds from time.time(); confirm against the capture side.
    timestamp: float
    # Key identifying the originating source/queue.
    source_id: str
class BatchAssembler:
    """Round-robin batch assembler.

    Pulls frames from per-source queues in round-robin order, packs them
    into one preprocessed batch tensor, and keeps simple throughput
    statistics.
    """

    def __init__(
        self,
        frame_queues: Dict[str, queue.Queue],
        batch_size: int,
        imgsz: Tuple[int, int],
        use_gpu_preprocess: bool = True,
        device_id: int = 0
    ):
        """
        Args:
            frame_queues: source_id -> queue of (frame, timestamp, source_id) tuples.
            batch_size: maximum number of frames per assembled batch.
            imgsz: (height, width) target size handed to the preprocessor.
            use_gpu_preprocess: forwarded to GPUPreprocessor.
            device_id: CUDA device index forwarded to GPUPreprocessor.
        """
        self.frame_queues = frame_queues
        self.batch_size = batch_size
        self.imgsz = imgsz

        self.preprocessor = GPUPreprocessor(
            target_size=imgsz,
            device_id=device_id,
            use_gpu=use_gpu_preprocess
        )

        # Throughput counters.
        # NOTE(review): _dropped_frames is never incremented anywhere in
        # this class, so get_drop_rate() always reports 0 as written —
        # presumably drops are meant to be recorded by the capture side;
        # confirm.
        self._total_frames = 0
        self._dropped_frames = 0
        self._incomplete_batches = 0

        # Round-robin cursor over the source queues.
        self._queue_keys = list(frame_queues.keys())
        self._current_index = 0

    def assemble_batch(self, timeout: float = 0.1) -> Optional[Tuple[np.ndarray, List[FrameInfo]]]:
        """Collect up to ``batch_size`` frames within ``timeout`` seconds.

        Visits source queues round-robin (cursor persists across calls so
        no source is starved). Returns a (batch, frame_infos) pair, or
        None when no frame arrived before the deadline. The batch may be
        smaller than ``batch_size`` on timeout.
        """
        collected: List[np.ndarray] = []
        infos: List[FrameInfo] = []
        deadline = time.time() + timeout

        while len(collected) < self.batch_size and time.time() <= deadline:
            progressed = False

            # One full sweep over all sources, advancing the cursor each try.
            for _ in self._queue_keys:
                src_key = self._queue_keys[self._current_index]
                self._current_index = (self._current_index + 1) % len(self._queue_keys)

                try:
                    frame, ts, origin = self.frame_queues[src_key].get_nowait()
                except queue.Empty:
                    continue

                collected.append(frame)
                infos.append(FrameInfo(frame=frame, timestamp=ts, source_id=origin))
                progressed = True

                if len(collected) >= self.batch_size:
                    break

            if not progressed:
                # Every queue was empty this sweep; yield briefly.
                time.sleep(0.001)

        if not collected:
            return None

        self._total_frames += len(collected)
        if len(collected) < self.batch_size:
            self._incomplete_batches += 1

        return self.preprocessor.preprocess_batch(collected), infos

    def get_drop_rate(self) -> float:
        """Dropped-frame percentage (0.0 before any frame is processed)."""
        if not self._total_frames:
            return 0.0
        return self._dropped_frames / self._total_frames * 100

    def get_stats(self) -> Dict[str, Any]:
        """Snapshot of the assembler counters."""
        return dict(
            total_frames=self._total_frames,
            dropped_frames=self._dropped_frames,
            incomplete_batches=self._incomplete_batches,
            drop_rate=self.get_drop_rate(),
        )

    def reset_stats(self):
        """Zero all counters."""
        self._total_frames = self._dropped_frames = self._incomplete_batches = 0