"""
配置管理模块

Configuration management module: defines ``BenchmarkConfig`` and helpers to
load, validate and save it as YAML.
"""

import itertools
import os
from dataclasses import asdict, dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, List, Tuple

import yaml


@dataclass
class BenchmarkConfig:
    """Benchmark configuration (Benchmark 配置数据类).

    Groups TensorRT engine settings, the parameter sweep
    (resolution x batch size x camera count x target FPS), runtime knobs,
    video-source settings, output options, metrics sampling and saturation
    thresholds. Call :meth:`validate` to collect configuration errors.
    """

    # TensorRT settings
    model_path: str = ""
    engine_dir: str = "./engines"
    precision: str = "fp16"  # validate() accepts "fp16" or "fp32"

    # Sweep parameters; get_test_combinations() yields their cartesian product
    resolutions: List[int] = field(default_factory=lambda: [320, 480])
    batch_sizes: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16])
    camera_counts: List[int] = field(default_factory=lambda: [1, 2, 5, 10, 15, 30])
    target_fps_list: List[float] = field(default_factory=lambda: [5.0, 10.0, 15.0, 20.0])

    # Runtime settings
    num_cuda_streams: int = 2
    frame_queue_size: int = 2
    test_duration_sec: int = 60
    warmup_sec: int = 5
    cooldown_sec: int = 2

    # Video sources
    video_sources: List[Dict[str, str]] = field(default_factory=list)
    use_synthetic: bool = True
    synthetic_resolution: Tuple[int, int] = (640, 480)

    # Output
    output_dir: str = "./benchmark_results"
    save_json: bool = True
    save_csv: bool = True
    generate_report: bool = True

    # Metrics collection
    metrics_sample_interval_ms: int = 100

    # GPU saturation thresholds
    gpu_saturation_threshold: float = 85.0
    memory_saturation_threshold: float = 90.0
    latency_increase_threshold: float = 50.0

    # Device
    device_id: int = 0

    @staticmethod
    def _is_number(value: Any) -> bool:
        """Return True for int/float values; bool is excluded on purpose."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _is_int(value: Any) -> bool:
        """Return True for int values; bool is excluded on purpose."""
        return isinstance(value, int) and not isinstance(value, bool)

    def validate(self) -> List[str]:
        """Check the configuration and return a list of human-readable errors.

        An empty list means the configuration is valid; this method never
        raises for bad values, it only reports them.
        """
        errors: List[str] = []

        if not self.model_path:
            errors.append("model_path 不能为空")
        elif not os.path.exists(self.model_path):
            errors.append(f"模型文件不存在: {self.model_path}")

        if self.precision not in ("fp16", "fp32"):
            errors.append(f"无效的精度模式: {self.precision}")

        # An empty (or null-from-YAML) sweep list would silently yield zero
        # test combinations, so report it explicitly.
        for name in ("resolutions", "batch_sizes", "camera_counts", "target_fps_list"):
            if not getattr(self, name):
                errors.append(f"{name} 不能为空")

        # `or ()` keeps iteration safe when a list was set to None.
        valid_resolutions = (320, 480)
        errors.extend(
            f"无效的分辨率: {res}"
            for res in self.resolutions or ()
            if res not in valid_resolutions
        )

        valid_batch_sizes = (1, 2, 4, 8, 16)
        errors.extend(
            f"无效的 Batch Size: {bs}"
            for bs in self.batch_sizes or ()
            if bs not in valid_batch_sizes
        )

        errors.extend(
            f"无效的摄像头数量: {count}"
            for count in self.camera_counts or ()
            if not (self._is_int(count) and count > 0)
        )

        errors.extend(
            f"无效的目标帧率: {fps}"
            for fps in self.target_fps_list or ()
            if not (self._is_number(fps) and fps > 0)
        )

        # Scalars that must be strictly positive numbers.
        for name in ("test_duration_sec", "metrics_sample_interval_ms"):
            value = getattr(self, name)
            if not (self._is_number(value) and value > 0):
                errors.append(f"{name} 必须大于 0: {value}")

        # Counts that must be integers >= 1.
        for name in ("num_cuda_streams", "frame_queue_size"):
            value = getattr(self, name)
            if not (self._is_int(value) and value >= 1):
                errors.append(f"{name} 必须为正整数: {value}")

        # Phase durations that may be zero but not negative.
        for name in ("warmup_sec", "cooldown_sec"):
            value = getattr(self, name)
            if not (self._is_number(value) and value >= 0):
                errors.append(f"{name} 不能为负数: {value}")

        size = self.synthetic_resolution
        if not (
            isinstance(size, (tuple, list))
            and len(size) == 2
            and all(self._is_int(v) and v > 0 for v in size)
        ):
            errors.append(f"无效的合成分辨率: {size}")

        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (``dataclasses.asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BenchmarkConfig":
        """Build a config from a plain mapping, e.g. parsed YAML.

        The input mapping is not modified. A list-valued
        ``synthetic_resolution`` (as YAML produces) is converted to a tuple.

        Raises:
            TypeError: if ``data`` contains keys that are not config fields.
        """
        kwargs = dict(data)  # copy so the caller's mapping is left untouched

        known = {f.name for f in fields(cls)}
        unknown = sorted((k for k in kwargs if k not in known), key=str)
        if unknown:
            raise TypeError(f"未知的配置项: {', '.join(map(str, unknown))}")

        if isinstance(kwargs.get("synthetic_resolution"), list):
            kwargs["synthetic_resolution"] = tuple(kwargs["synthetic_resolution"])
        return cls(**kwargs)

    def get_test_combinations(self) -> List[Dict[str, Any]]:
        """Return every sweep combination as a dict.

        Order matches nested loops over resolutions > batch_sizes >
        camera_counts > target_fps_list (last one varies fastest).
        """
        keys = ("resolution", "batch_size", "camera_count", "target_fps")
        return [
            dict(zip(keys, combo))
            for combo in itertools.product(
                self.resolutions,
                self.batch_sizes,
                self.camera_counts,
                self.target_fps_list,
            )
        ]

    def get_total_test_count(self) -> int:
        """Return the number of combinations get_test_combinations() yields."""
        return (
            len(self.resolutions)
            * len(self.batch_sizes)
            * len(self.camera_counts)
            * len(self.target_fps_list)
        )


def load_config(config_path: str) -> BenchmarkConfig:
    """Load a BenchmarkConfig from a YAML file, validate it and return it.

    An empty YAML file is treated as an empty mapping, so every field takes
    its default and validation then reports what is missing.

    Raises:
        FileNotFoundError: if ``config_path`` does not exist.
        ValueError: if the YAML top level is not a mapping, or if validation
            reports any errors.
        TypeError: propagated from ``BenchmarkConfig.from_dict`` when the file
            contains keys that are not config fields.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件不存在: {config_path}")

    with open(config_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document; treat it as {} instead
    # of crashing inside from_dict with an unrelated TypeError.
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise ValueError(f"配置文件顶层必须是映射 (mapping): {config_path}")

    config = BenchmarkConfig.from_dict(data)
    errors = config.validate()
    if errors:
        raise ValueError("配置验证失败:\n" + "\n".join(f" - {e}" for e in errors))

    return config


def save_config(config: BenchmarkConfig, config_path: str):
    """Write *config* to *config_path* as YAML, creating parent directories.

    Keys are written in field-declaration order (``sort_keys=False``) and
    non-ASCII text is emitted as-is (``allow_unicode=True``).
    """
    target = Path(config_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = config.to_dict()
    if "synthetic_resolution" in payload:
        # Store the tuple as a plain list so the YAML stays free of
        # Python-specific tuple tags.
        payload["synthetic_resolution"] = list(payload["synthetic_resolution"])

    with target.open("w", encoding="utf-8") as stream:
        yaml.dump(
            payload,
            stream,
            default_flow_style=False,
            allow_unicode=True,
            sort_keys=False,
        )


def save_default_config(config_path: str = "benchmark_config.yaml"):
    """Create a default BenchmarkConfig, save it to *config_path*, return it.

    Only ``model_path`` is set explicitly (to ``"./yolo11n.pt"``); every other
    field keeps its dataclass default.
    """
    default_cfg = BenchmarkConfig(model_path="./yolo11n.pt")
    save_config(default_cfg, config_path)
    return default_cfg