GPU测试
This commit is contained in:
141
benchmark/config.py
Normal file
141
benchmark/config.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
配置管理模块
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkConfig:
|
||||
"""Benchmark 配置数据类"""
|
||||
|
||||
# TensorRT 配置
|
||||
model_path: str = ""
|
||||
engine_dir: str = "./engines"
|
||||
precision: str = "fp16"
|
||||
|
||||
# 测试参数
|
||||
resolutions: List[int] = field(default_factory=lambda: [320, 480])
|
||||
batch_sizes: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16])
|
||||
camera_counts: List[int] = field(default_factory=lambda: [1, 2, 5, 10, 15, 30])
|
||||
target_fps_list: List[float] = field(default_factory=lambda: [5.0, 10.0, 15.0, 20.0])
|
||||
|
||||
# 运行时配置
|
||||
num_cuda_streams: int = 2
|
||||
frame_queue_size: int = 2
|
||||
test_duration_sec: int = 60
|
||||
warmup_sec: int = 5
|
||||
cooldown_sec: int = 2
|
||||
|
||||
# 视频源配置
|
||||
video_sources: List[Dict[str, str]] = field(default_factory=list)
|
||||
use_synthetic: bool = True
|
||||
synthetic_resolution: Tuple[int, int] = (640, 480)
|
||||
|
||||
# 输出配置
|
||||
output_dir: str = "./benchmark_results"
|
||||
save_json: bool = True
|
||||
save_csv: bool = True
|
||||
generate_report: bool = True
|
||||
|
||||
# 指标采集配置
|
||||
metrics_sample_interval_ms: int = 100
|
||||
|
||||
# GPU 饱和判定阈值
|
||||
gpu_saturation_threshold: float = 85.0
|
||||
memory_saturation_threshold: float = 90.0
|
||||
latency_increase_threshold: float = 50.0
|
||||
|
||||
# 设备配置
|
||||
device_id: int = 0
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""验证配置参数有效性"""
|
||||
errors = []
|
||||
|
||||
if not self.model_path:
|
||||
errors.append("model_path 不能为空")
|
||||
elif not os.path.exists(self.model_path):
|
||||
errors.append(f"模型文件不存在: {self.model_path}")
|
||||
|
||||
if self.precision not in ["fp16", "fp32"]:
|
||||
errors.append(f"无效的精度模式: {self.precision}")
|
||||
|
||||
valid_resolutions = [320, 480]
|
||||
for res in self.resolutions:
|
||||
if res not in valid_resolutions:
|
||||
errors.append(f"无效的分辨率: {res}")
|
||||
|
||||
valid_batch_sizes = [1, 2, 4, 8, 16]
|
||||
for bs in self.batch_sizes:
|
||||
if bs not in valid_batch_sizes:
|
||||
errors.append(f"无效的 Batch Size: {bs}")
|
||||
|
||||
return errors
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "BenchmarkConfig":
|
||||
if "synthetic_resolution" in data and isinstance(data["synthetic_resolution"], list):
|
||||
data["synthetic_resolution"] = tuple(data["synthetic_resolution"])
|
||||
return cls(**data)
|
||||
|
||||
def get_test_combinations(self) -> List[Dict[str, Any]]:
|
||||
"""生成所有测试参数组合"""
|
||||
combinations = []
|
||||
for resolution in self.resolutions:
|
||||
for batch_size in self.batch_sizes:
|
||||
for camera_count in self.camera_counts:
|
||||
for target_fps in self.target_fps_list:
|
||||
combinations.append({
|
||||
"resolution": resolution,
|
||||
"batch_size": batch_size,
|
||||
"camera_count": camera_count,
|
||||
"target_fps": target_fps,
|
||||
})
|
||||
return combinations
|
||||
|
||||
def get_total_test_count(self) -> int:
|
||||
return len(self.resolutions) * len(self.batch_sizes) * len(self.camera_counts) * len(self.target_fps_list)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> BenchmarkConfig:
|
||||
"""从 YAML 文件加载配置"""
|
||||
if not os.path.exists(config_path):
|
||||
raise FileNotFoundError(f"配置文件不存在: {config_path}")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
config = BenchmarkConfig.from_dict(data)
|
||||
errors = config.validate()
|
||||
if errors:
|
||||
raise ValueError(f"配置验证失败:\n" + "\n".join(f" - {e}" for e in errors))
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def save_config(config: BenchmarkConfig, config_path: str):
|
||||
"""保存配置到 YAML 文件"""
|
||||
Path(config_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
data = config.to_dict()
|
||||
if "synthetic_resolution" in data:
|
||||
data["synthetic_resolution"] = list(data["synthetic_resolution"])
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
||||
|
||||
|
||||
def save_default_config(config_path: str = "benchmark_config.yaml"):
|
||||
"""生成并保存默认配置文件"""
|
||||
config = BenchmarkConfig(model_path="./yolo11n.pt")
|
||||
save_config(config, config_path)
|
||||
return config
|
||||
Reference in New Issue
Block a user