Files
Test_AI/benchmark/config.py

142 lines
4.6 KiB
Python
Raw Permalink Normal View History

2026-01-20 10:54:30 +08:00
"""
配置管理模块
"""
import itertools
import os
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import List, Dict, Any, Tuple

import yaml
@dataclass
class BenchmarkConfig:
    """Settings for one benchmark sweep.

    Groups the TensorRT model/engine settings, the parameter grid that is
    swept (resolution x batch size x camera count x target FPS), runtime,
    video-source and output options, and the thresholds used to judge GPU
    saturation. Use ``validate()`` to check the values and
    ``get_test_combinations()`` to expand the grid.
    """

    # --- TensorRT settings ---
    model_path: str = ""
    engine_dir: str = "./engines"
    precision: str = "fp16"  # validate() accepts "fp16" or "fp32"

    # --- Test parameter grid (expanded by get_test_combinations) ---
    resolutions: List[int] = field(default_factory=lambda: [320, 480])
    batch_sizes: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16])
    camera_counts: List[int] = field(default_factory=lambda: [1, 2, 5, 10, 15, 30])
    target_fps_list: List[float] = field(default_factory=lambda: [5.0, 10.0, 15.0, 20.0])

    # --- Runtime settings ---
    num_cuda_streams: int = 2
    frame_queue_size: int = 2
    test_duration_sec: int = 60
    warmup_sec: int = 5
    cooldown_sec: int = 2

    # --- Video source settings ---
    # NOTE(review): the expected keys of each video_sources entry are not
    # defined in this module; confirm against the code that consumes them.
    video_sources: List[Dict[str, str]] = field(default_factory=list)
    use_synthetic: bool = True
    synthetic_resolution: Tuple[int, int] = (640, 480)  # presumably (width, height); confirm

    # --- Output settings ---
    output_dir: str = "./benchmark_results"
    save_json: bool = True
    save_csv: bool = True
    generate_report: bool = True

    # --- Metrics collection ---
    metrics_sample_interval_ms: int = 100

    # --- GPU saturation thresholds (presumably percentages; confirm with the consumer) ---
    gpu_saturation_threshold: float = 85.0
    memory_saturation_threshold: float = 90.0
    latency_increase_threshold: float = 50.0

    # --- Device ---
    device_id: int = 0

    # Values accepted by validate(). Deliberately left unannotated so that
    # @dataclass treats them as plain class attributes rather than fields
    # (they therefore do not appear in __init__ or to_dict()).
    _VALID_PRECISIONS = ("fp16", "fp32")
    _VALID_RESOLUTIONS = (320, 480)
    _VALID_BATCH_SIZES = (1, 2, 4, 8, 16)

    def validate(self) -> List[str]:
        """Check the configuration and return human-readable error messages.

        An empty list means the configuration is valid. Checks performed:
        ``model_path`` is non-empty and exists on disk, ``precision`` is a
        supported mode, and every entry of ``resolutions`` and ``batch_sizes``
        is one of the supported values.
        """
        errors: List[str] = []
        if not self.model_path:
            errors.append("model_path 不能为空")
        elif not os.path.exists(self.model_path):
            errors.append(f"模型文件不存在: {self.model_path}")
        if self.precision not in self._VALID_PRECISIONS:
            errors.append(f"无效的精度模式: {self.precision}")
        errors.extend(
            f"无效的分辨率: {res}"
            for res in self.resolutions
            if res not in self._VALID_RESOLUTIONS
        )
        errors.extend(
            f"无效的 Batch Size: {bs}"
            for bs in self.batch_sizes
            if bs not in self._VALID_BATCH_SIZES
        )
        return errors

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dict (via ``dataclasses.asdict``).

        ``synthetic_resolution`` stays a tuple in the result.
        """
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BenchmarkConfig":
        """Build a config from a plain mapping such as parsed YAML.

        The caller's mapping is not modified. A list-valued
        ``synthetic_resolution`` (YAML has no tuple type) is converted to a
        tuple.

        Raises:
            TypeError: If ``data`` contains a key that is not a field of
                this class (raised by the dataclass ``__init__``).
        """
        kwargs = dict(data)  # shallow copy so the caller's dict is left untouched
        resolution = kwargs.get("synthetic_resolution")
        if isinstance(resolution, list):
            kwargs["synthetic_resolution"] = tuple(resolution)
        return cls(**kwargs)

    def get_test_combinations(self) -> List[Dict[str, Any]]:
        """Return every combination of the swept test parameters.

        This is the Cartesian product of ``resolutions``, ``batch_sizes``,
        ``camera_counts`` and ``target_fps_list``, with resolution varying
        slowest and target FPS fastest. Each item is a dict with the keys
        ``"resolution"``, ``"batch_size"``, ``"camera_count"`` and
        ``"target_fps"``.
        """
        grid = itertools.product(
            self.resolutions, self.batch_sizes, self.camera_counts, self.target_fps_list
        )
        return [
            {
                "resolution": resolution,
                "batch_size": batch_size,
                "camera_count": camera_count,
                "target_fps": target_fps,
            }
            for resolution, batch_size, camera_count, target_fps in grid
        ]

    def get_total_test_count(self) -> int:
        """Return how many items ``get_test_combinations()`` would produce,
        without building them."""
        return (
            len(self.resolutions)
            * len(self.batch_sizes)
            * len(self.camera_counts)
            * len(self.target_fps_list)
        )
def load_config(config_path: str) -> BenchmarkConfig:
    """Load a BenchmarkConfig from a YAML file and validate it.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration, which has passed ``validate()``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the YAML top level is not a mapping, or if
            ``BenchmarkConfig.validate()`` reports any errors (all messages
            are joined into the exception text).
        TypeError: If the mapping contains keys that are not config fields.
    """
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"配置文件不存在: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # An empty YAML document parses to None. Treat it as "no settings" so the
    # problem surfaces as a readable validation error instead of an opaque
    # TypeError from from_dict.
    if data is None:
        data = {}
    if not isinstance(data, dict):
        raise ValueError(f"配置文件顶层必须是键值映射: {config_path}")
    config = BenchmarkConfig.from_dict(data)
    errors = config.validate()
    if errors:
        raise ValueError("配置验证失败:\n" + "\n".join(f" - {e}" for e in errors))
    return config
def save_config(config: BenchmarkConfig, config_path: str):
    """Write ``config`` to ``config_path`` as YAML.

    Parent directories are created if missing. ``synthetic_resolution`` is
    written as a list: PyYAML's default dumper would emit a tuple with a
    ``!!python/tuple`` tag, which ``yaml.safe_load`` cannot read back. Keys
    keep the dataclass field order (``sort_keys=False``), and non-ASCII text
    is written as-is (``allow_unicode=True``).
    """
    target = Path(config_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = config.to_dict()
    if "synthetic_resolution" in payload:
        payload["synthetic_resolution"] = list(payload["synthetic_resolution"])

    with target.open("w", encoding="utf-8") as stream:
        yaml.dump(
            payload,
            stream,
            default_flow_style=False,
            allow_unicode=True,
            sort_keys=False,
        )
def save_default_config(config_path: str = "benchmark_config.yaml"):
    """Write a default configuration file and return the config written.

    Every field keeps its dataclass default except ``model_path``, which is
    set to ``"./yolo11n.pt"``.

    NOTE: the config is not validated here. ``load_config`` will reject the
    file unless ``./yolo11n.pt`` exists relative to the working directory
    at load time.
    """
    default_config = BenchmarkConfig(model_path="./yolo11n.pt")
    save_config(default_config, config_path)
    return default_config