GPU测试
This commit is contained in:
173
benchmark/batch_assembler.py
Normal file
173
benchmark/batch_assembler.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Batch 组装和 GPU 预处理模块
|
||||
"""
|
||||
|
||||
import time
|
||||
import queue
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
class GPUPreprocessor:
    """GPU image preprocessor.

    Converts raw BGR frames into normalized CHW float32 tensors ready for
    model input: resize to ``target_size`` -> BGR->RGB -> scale to [0, 1]
    -> transpose HWC->CHW. Uses OpenCV CUDA when available and enabled,
    otherwise falls back to a pure-CPU pipeline.
    """

    def __init__(self, target_size: Tuple[int, int], device_id: int = 0, use_gpu: bool = True):
        """
        Args:
            target_size: output size as (height, width).
            device_id: CUDA device index. NOTE(review): stored but never used
                to select a device in this module — confirm intended use.
            use_gpu: set False to force CPU preprocessing.
        """
        self.target_size = target_size
        self.device_id = device_id
        self.use_gpu = use_gpu
        self._init_gpu()

    def _init_gpu(self):
        """Probe OpenCV CUDA support and set ``self._gpu_available``."""
        self._gpu_available = False

        if not self.use_gpu:
            logger.info("GPU 预处理已禁用,使用 CPU")
            return

        try:
            # Query the device count once (the original queried CUDA a second
            # time just to build the log message).
            device_count = cv2.cuda.getCudaEnabledDeviceCount()
            if device_count > 0:
                self._gpu_available = True
                logger.info(f"OpenCV CUDA 可用,设备数: {device_count}")
            else:
                logger.warning("OpenCV CUDA 不可用,使用 CPU 预处理")
        except Exception as e:
            # cv2 builds without the CUDA module can raise here; degrade to CPU.
            logger.warning(f"GPU 初始化失败: {e},使用 CPU 预处理")

    def preprocess_single(self, frame: np.ndarray) -> np.ndarray:
        """Preprocess one BGR frame into a (3, h, w) float32 array."""
        h, w = self.target_size

        if self._gpu_available:
            return self._preprocess_gpu(frame, h, w)
        else:
            return self._preprocess_cpu(frame, h, w)

    def _preprocess_cpu(self, frame: np.ndarray, h: int, w: int) -> np.ndarray:
        """CPU pipeline: resize -> BGR2RGB -> [0, 1] float32 -> CHW."""
        resized = cv2.resize(frame, (w, h), interpolation=cv2.INTER_LINEAR)
        rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
        normalized = rgb.astype(np.float32) / 255.0
        return np.transpose(normalized, (2, 0, 1))

    def _preprocess_gpu(self, frame: np.ndarray, h: int, w: int) -> np.ndarray:
        """GPU pipeline; falls back to the CPU pipeline on any CUDA failure."""
        try:
            gpu_frame = cv2.cuda_GpuMat()
            gpu_frame.upload(frame)
            gpu_resized = cv2.cuda.resize(gpu_frame, (w, h), interpolation=cv2.INTER_LINEAR)
            gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
            # Normalization and transpose happen on CPU after download.
            rgb = gpu_rgb.download()
            normalized = rgb.astype(np.float32) / 255.0
            return np.transpose(normalized, (2, 0, 1))
        except Exception as e:
            logger.warning(f"GPU 预处理失败: {e},回退到 CPU")
            return self._preprocess_cpu(frame, h, w)

    def preprocess_batch(self, frames: List[np.ndarray], stream: Any = None) -> np.ndarray:
        """Preprocess a list of frames into an (N, 3, h, w) batch.

        Returns an empty 1-D array when ``frames`` is empty.

        NOTE(review): ``stream`` is accepted but unused — presumably a
        placeholder for CUDA-stream pipelining; confirm before relying on it.
        """
        if not frames:
            return np.array([])
        processed = [self.preprocess_single(frame) for frame in frames]
        return np.stack(processed, axis=0)
|
||||
|
||||
|
||||
@dataclass
class FrameInfo:
    """Metadata for a single captured frame within an assembled batch."""
    frame: np.ndarray   # raw frame as pulled from the source queue
    timestamp: float    # capture timestamp supplied by the producer — units/epoch set upstream, verify
    source_id: str      # identifier of the originating stream/queue
|
||||
|
||||
|
||||
class BatchAssembler:
    """Round-robin batch assembler.

    Pulls frames from multiple per-source queues in round-robin order,
    groups them into batches of up to ``batch_size`` frames, and runs the
    batch through a GPUPreprocessor.
    """

    def __init__(
        self,
        frame_queues: Dict[str, queue.Queue],
        batch_size: int,
        imgsz: Tuple[int, int],
        use_gpu_preprocess: bool = True,
        device_id: int = 0
    ):
        """
        Args:
            frame_queues: source_id -> queue of (frame, timestamp, source_id)
                tuples (shape inferred from assemble_batch's unpacking —
                verify against the producers).
            batch_size: target number of frames per batch.
            imgsz: (height, width) passed to the preprocessor.
            use_gpu_preprocess: enable OpenCV CUDA preprocessing.
            device_id: CUDA device index forwarded to the preprocessor.
        """
        self.frame_queues = frame_queues
        self.batch_size = batch_size
        self.imgsz = imgsz

        self.preprocessor = GPUPreprocessor(
            target_size=imgsz,
            device_id=device_id,
            use_gpu=use_gpu_preprocess
        )

        # Stats counters. NOTE(review): _dropped_frames is never incremented
        # anywhere in this module, so get_drop_rate() always reports 0 —
        # presumably drops should be counted where producers enqueue; confirm.
        self._total_frames = 0
        self._dropped_frames = 0
        self._incomplete_batches = 0

        # Round-robin cursor over the source queues.
        self._queue_keys = list(frame_queues.keys())
        self._current_index = 0

    def assemble_batch(self, timeout: float = 0.1) -> Optional[Tuple[np.ndarray, List["FrameInfo"]]]:
        """Collect up to ``batch_size`` frames within ``timeout`` seconds.

        Returns:
            (batch, frame_infos) where ``batch`` is the preprocessed batch
            array, or None if no frame arrived before the deadline. A partial
            batch is returned when the timeout expires with at least one
            frame collected (counted in ``_incomplete_batches``).
        """
        frames = []
        frame_infos = []

        # monotonic() is immune to wall-clock adjustments; the original used
        # time.time(), which can jump and corrupt the deadline computation.
        start_time = time.monotonic()
        num_sources = len(self._queue_keys)

        while len(frames) < self.batch_size:
            if time.monotonic() - start_time > timeout:
                break

            got_frame = False
            # One full round-robin sweep over all sources.
            for _ in range(num_sources):
                source_id = self._queue_keys[self._current_index]
                self._current_index = (self._current_index + 1) % num_sources

                q = self.frame_queues[source_id]
                try:
                    frame, timestamp, src_id = q.get_nowait()
                except queue.Empty:
                    continue
                frames.append(frame)
                frame_infos.append(FrameInfo(frame=frame, timestamp=timestamp, source_id=src_id))
                got_frame = True

                if len(frames) >= self.batch_size:
                    break

            if not got_frame:
                # Every queue was empty this sweep; back off briefly.
                time.sleep(0.001)

        if not frames:
            return None

        self._total_frames += len(frames)

        if len(frames) < self.batch_size:
            self._incomplete_batches += 1

        batch = self.preprocessor.preprocess_batch(frames)
        return batch, frame_infos

    def get_drop_rate(self) -> float:
        """Return dropped/total as a percentage (0.0 when nothing processed)."""
        if self._total_frames == 0:
            return 0.0
        return self._dropped_frames / self._total_frames * 100

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of the assembler counters."""
        return {
            "total_frames": self._total_frames,
            "dropped_frames": self._dropped_frames,
            "incomplete_batches": self._incomplete_batches,
            "drop_rate": self.get_drop_rate(),
        }

    def reset_stats(self):
        """Zero all counters."""
        self._total_frames = 0
        self._dropped_frames = 0
        self._incomplete_batches = 0
|
||||
Reference in New Issue
Block a user