""" 配置管理模块 """ import os from dataclasses import dataclass, field, asdict from typing import List, Dict, Any, Tuple from pathlib import Path import yaml @dataclass class BenchmarkConfig: """Benchmark 配置数据类""" # TensorRT 配置 model_path: str = "" engine_dir: str = "./engines" precision: str = "fp16" # 测试参数 resolutions: List[int] = field(default_factory=lambda: [320, 480]) batch_sizes: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16]) camera_counts: List[int] = field(default_factory=lambda: [1, 2, 5, 10, 15, 30]) target_fps_list: List[float] = field(default_factory=lambda: [5.0, 10.0, 15.0, 20.0]) # 运行时配置 num_cuda_streams: int = 2 frame_queue_size: int = 2 test_duration_sec: int = 60 warmup_sec: int = 5 cooldown_sec: int = 2 # 视频源配置 video_sources: List[Dict[str, str]] = field(default_factory=list) use_synthetic: bool = True synthetic_resolution: Tuple[int, int] = (640, 480) # 输出配置 output_dir: str = "./benchmark_results" save_json: bool = True save_csv: bool = True generate_report: bool = True # 指标采集配置 metrics_sample_interval_ms: int = 100 # GPU 饱和判定阈值 gpu_saturation_threshold: float = 85.0 memory_saturation_threshold: float = 90.0 latency_increase_threshold: float = 50.0 # 设备配置 device_id: int = 0 def validate(self) -> List[str]: """验证配置参数有效性""" errors = [] if not self.model_path: errors.append("model_path 不能为空") elif not os.path.exists(self.model_path): errors.append(f"模型文件不存在: {self.model_path}") if self.precision not in ["fp16", "fp32"]: errors.append(f"无效的精度模式: {self.precision}") valid_resolutions = [320, 480] for res in self.resolutions: if res not in valid_resolutions: errors.append(f"无效的分辨率: {res}") valid_batch_sizes = [1, 2, 4, 8, 16] for bs in self.batch_sizes: if bs not in valid_batch_sizes: errors.append(f"无效的 Batch Size: {bs}") return errors def to_dict(self) -> Dict[str, Any]: return asdict(self) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "BenchmarkConfig": if "synthetic_resolution" in data and isinstance(data["synthetic_resolution"], list): data["synthetic_resolution"] = tuple(data["synthetic_resolution"]) return cls(**data) def get_test_combinations(self) -> List[Dict[str, Any]]: """生成所有测试参数组合""" combinations = [] for resolution in self.resolutions: for batch_size in self.batch_sizes: for camera_count in self.camera_counts: for target_fps in self.target_fps_list: combinations.append({ "resolution": resolution, "batch_size": batch_size, "camera_count": camera_count, "target_fps": target_fps, }) return combinations def get_total_test_count(self) -> int: return len(self.resolutions) * len(self.batch_sizes) * len(self.camera_counts) * len(self.target_fps_list) def load_config(config_path: str) -> BenchmarkConfig: """从 YAML 文件加载配置""" if not os.path.exists(config_path): raise FileNotFoundError(f"配置文件不存在: {config_path}") with open(config_path, "r", encoding="utf-8") as f: data = yaml.safe_load(f) config = BenchmarkConfig.from_dict(data) errors = config.validate() if errors: raise ValueError(f"配置验证失败:\n" + "\n".join(f" - {e}" for e in errors)) return config def save_config(config: BenchmarkConfig, config_path: str): """保存配置到 YAML 文件""" Path(config_path).parent.mkdir(parents=True, exist_ok=True) data = config.to_dict() if "synthetic_resolution" in data: data["synthetic_resolution"] = list(data["synthetic_resolution"]) with open(config_path, "w", encoding="utf-8") as f: yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False) def save_default_config(config_path: str = "benchmark_config.yaml"): """生成并保存默认配置文件""" config = BenchmarkConfig(model_path="./yolo11n.pt") save_config(config, config_path) return config