GPU测试

This commit is contained in:
2026-01-20 10:54:30 +08:00
parent 8463f5a571
commit 8e9de9c858
59 changed files with 18934 additions and 0 deletions

View File

@@ -0,0 +1,173 @@
"""
Batch 组装和 GPU 预处理模块
"""
import time
import queue
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
import cv2
import numpy as np
from .utils import setup_logging
logger = setup_logging()
class GPUPreprocessor:
"""GPU 图像预处理器"""
def __init__(self, target_size: Tuple[int, int], device_id: int = 0, use_gpu: bool = True):
self.target_size = target_size
self.device_id = device_id
self.use_gpu = use_gpu
self._init_gpu()
def _init_gpu(self):
self._gpu_available = False
if not self.use_gpu:
logger.info("GPU 预处理已禁用,使用 CPU")
return
try:
if cv2.cuda.getCudaEnabledDeviceCount() > 0:
self._gpu_available = True
logger.info(f"OpenCV CUDA 可用,设备数: {cv2.cuda.getCudaEnabledDeviceCount()}")
else:
logger.warning("OpenCV CUDA 不可用,使用 CPU 预处理")
except Exception as e:
logger.warning(f"GPU 初始化失败: {e},使用 CPU 预处理")
def preprocess_single(self, frame: np.ndarray) -> np.ndarray:
h, w = self.target_size
if self._gpu_available:
return self._preprocess_gpu(frame, h, w)
else:
return self._preprocess_cpu(frame, h, w)
def _preprocess_cpu(self, frame: np.ndarray, h: int, w: int) -> np.ndarray:
resized = cv2.resize(frame, (w, h), interpolation=cv2.INTER_LINEAR)
rgb = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
normalized = rgb.astype(np.float32) / 255.0
transposed = np.transpose(normalized, (2, 0, 1))
return transposed
def _preprocess_gpu(self, frame: np.ndarray, h: int, w: int) -> np.ndarray:
try:
gpu_frame = cv2.cuda_GpuMat()
gpu_frame.upload(frame)
gpu_resized = cv2.cuda.resize(gpu_frame, (w, h), interpolation=cv2.INTER_LINEAR)
gpu_rgb = cv2.cuda.cvtColor(gpu_resized, cv2.COLOR_BGR2RGB)
rgb = gpu_rgb.download()
normalized = rgb.astype(np.float32) / 255.0
transposed = np.transpose(normalized, (2, 0, 1))
return transposed
except Exception as e:
logger.warning(f"GPU 预处理失败: {e},回退到 CPU")
return self._preprocess_cpu(frame, h, w)
def preprocess_batch(self, frames: List[np.ndarray], stream: Any = None) -> np.ndarray:
if not frames:
return np.array([])
processed = [self.preprocess_single(frame) for frame in frames]
return np.stack(processed, axis=0)
@dataclass
class FrameInfo:
"""帧信息"""
frame: np.ndarray
timestamp: float
source_id: str
class BatchAssembler:
"""Batch 组装器"""
def __init__(
self,
frame_queues: Dict[str, queue.Queue],
batch_size: int,
imgsz: Tuple[int, int],
use_gpu_preprocess: bool = True,
device_id: int = 0
):
self.frame_queues = frame_queues
self.batch_size = batch_size
self.imgsz = imgsz
self.preprocessor = GPUPreprocessor(
target_size=imgsz,
device_id=device_id,
use_gpu=use_gpu_preprocess
)
self._total_frames = 0
self._dropped_frames = 0
self._incomplete_batches = 0
self._queue_keys = list(frame_queues.keys())
self._current_index = 0
def assemble_batch(self, timeout: float = 0.1) -> Optional[Tuple[np.ndarray, List[FrameInfo]]]:
frames = []
frame_infos = []
start_time = time.time()
while len(frames) < self.batch_size:
if time.time() - start_time > timeout:
break
got_frame = False
for _ in range(len(self._queue_keys)):
source_id = self._queue_keys[self._current_index]
self._current_index = (self._current_index + 1) % len(self._queue_keys)
q = self.frame_queues[source_id]
try:
frame, timestamp, src_id = q.get_nowait()
frames.append(frame)
frame_infos.append(FrameInfo(frame=frame, timestamp=timestamp, source_id=src_id))
got_frame = True
if len(frames) >= self.batch_size:
break
except queue.Empty:
continue
if not got_frame:
time.sleep(0.001)
if not frames:
return None
self._total_frames += len(frames)
if len(frames) < self.batch_size:
self._incomplete_batches += 1
batch = self.preprocessor.preprocess_batch(frames)
return batch, frame_infos
def get_drop_rate(self) -> float:
if self._total_frames == 0:
return 0.0
return self._dropped_frames / self._total_frames * 100
def get_stats(self) -> Dict[str, Any]:
return {
"total_frames": self._total_frames,
"dropped_frames": self._dropped_frames,
"incomplete_batches": self._incomplete_batches,
"drop_rate": self.get_drop_rate(),
}
def reset_stats(self):
self._total_frames = 0
self._dropped_frames = 0
self._incomplete_batches = 0