Files
Test_AI/benchmark/cli.py
2026-01-20 10:54:30 +08:00

279 lines
6.7 KiB
Python

"""
命令行接口模块
"""
import argparse
import sys
from typing import Optional, Sequence

from .config import BenchmarkConfig, load_config, save_default_config
from .benchmark_runner import BenchmarkRunner
from .utils import setup_logging
# Module-level logger obtained from the project helper setup_logging().
# This call runs once, at import time; main() uses the returned logger for
# its error/warning messages.
logger = setup_logging()
# Options that can come either from the command line or from the YAML config.
# Each entry is (argparse dest, BenchmarkConfig attribute, fallback value).
# The parser leaves these options as None when they are not given, so
# build_config can distinguish an explicit flag from "not specified" and does
# not clobber values loaded via --config. The fallbacks reproduce the
# "(默认: ...)" values shown in --help. They are applied only when no config
# file is used, so a bare invocation behaves exactly as documented.
_OVERRIDABLE_OPTIONS = (
    ("output", "output_dir", "./benchmark_results"),
    ("resolutions", "resolutions", [320, 480]),
    ("batch_sizes", "batch_sizes", [1, 2, 4, 8, 16]),
    ("camera_counts", "camera_counts", [1, 2, 5, 10, 15, 30]),
    ("fps_list", "target_fps_list", [5.0, 10.0, 15.0, 20.0]),
    ("duration", "test_duration_sec", 60),
    ("warmup", "warmup_sec", 5),
    ("streams", "num_cuda_streams", 2),
    ("device", "device_id", 0),
)


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line arguments.

    Args:
        argv: Arguments to parse, excluding the program name. When None
            (the default), argparse reads ``sys.argv[1:]``.

    Returns:
        The parsed namespace. Every option listed in ``_OVERRIDABLE_OPTIONS``
        is None unless it was given explicitly on the command line.
    """
    parser = argparse.ArgumentParser(
        description="FP16 性能评估 Benchmark 框架",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 使用默认配置运行
python -m benchmark --model yolo11n.pt
# 指定配置文件
python -m benchmark --config benchmark_config.yaml
# 快速测试模式
python -m benchmark --model yolo11n.pt --quick
# 生成默认配置文件
python -m benchmark --generate-config
""",
    )

    # Basic options
    parser.add_argument(
        "--config", "-c",
        type=str,
        help="配置文件路径 (YAML)",
    )
    parser.add_argument(
        "--model", "-m",
        type=str,
        help="YOLOv11n 模型路径 (.pt)",
    )
    # default=None: the advertised default is applied in build_config only
    # when no config file is given (see _OVERRIDABLE_OPTIONS).
    parser.add_argument(
        "--output", "-o",
        type=str,
        default=None,
        help="输出目录 (默认: ./benchmark_results)",
    )

    # Test-matrix options
    parser.add_argument(
        "--resolutions",
        type=int,
        nargs="+",
        default=None,
        help="测试分辨率列表 (默认: 320 480)",
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=None,
        help="Batch Size 列表 (默认: 1 2 4 8 16)",
    )
    parser.add_argument(
        "--camera-counts",
        type=int,
        nargs="+",
        default=None,
        help="摄像头数量列表 (默认: 1 2 5 10 15 30)",
    )
    parser.add_argument(
        "--fps-list",
        type=float,
        nargs="+",
        default=None,
        help="目标帧率列表 (默认: 5 10 15 20)",
    )

    # Runtime options
    parser.add_argument(
        "--duration", "-d",
        type=int,
        default=None,
        help="每次测试持续时间(秒) (默认: 60)",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=None,
        help="预热时间(秒) (默认: 5)",
    )
    parser.add_argument(
        "--streams",
        type=int,
        default=None,
        help="CUDA Stream 数量 (默认: 2)",
    )
    parser.add_argument(
        "--device",
        type=int,
        default=None,
        help="GPU 设备 ID (默认: 0)",
    )

    # Preset modes
    parser.add_argument(
        "--quick", "-q",
        action="store_true",
        help="快速测试模式 (减少测试组合)",
    )
    parser.add_argument(
        "--single",
        action="store_true",
        help="单次测试模式",
    )

    # Utility commands
    parser.add_argument(
        "--generate-config",
        action="store_true",
        help="生成默认配置文件",
    )
    parser.add_argument(
        "--list-tests",
        action="store_true",
        help="列出所有测试组合",
    )

    # Output control
    parser.add_argument(
        "--no-json",
        action="store_true",
        help="不输出 JSON 结果",
    )
    parser.add_argument(
        "--no-csv",
        action="store_true",
        help="不输出 CSV 结果",
    )
    parser.add_argument(
        "--no-report",
        action="store_true",
        help="不生成 Markdown 报告",
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="详细输出",
    )
    return parser.parse_args(argv)


def build_config(args: argparse.Namespace) -> "BenchmarkConfig":
    """Build the benchmark configuration from parsed command-line arguments.

    Precedence, highest first:
      1. options given explicitly on the command line;
      2. the ``--config`` YAML file (when given);
      3. the defaults advertised in ``--help`` (only when no config file).

    The ``--quick`` and ``--single`` presets are applied afterwards and
    replace the test matrix; ``--quick`` also replaces duration and warmup.

    Args:
        args: Namespace produced by :func:`parse_args`.

    Returns:
        The assembled configuration. It is not validated here.
    """
    if args.config:
        config = load_config(args.config)
    else:
        config = BenchmarkConfig()
        for _dest, field, fallback in _OVERRIDABLE_OPTIONS:
            # Copy list fallbacks so the config never aliases the module table.
            value = list(fallback) if isinstance(fallback, list) else fallback
            setattr(config, field, value)

    if args.model:
        config.model_path = args.model

    # Only options actually present on the command line override. The check
    # is `is not None` (not truthiness) so values like `--warmup 0` apply.
    for dest, field, _fallback in _OVERRIDABLE_OPTIONS:
        value = getattr(args, dest, None)
        if value is not None:
            setattr(config, field, value)

    # Quick mode: reduced test matrix and shorter runs.
    if args.quick:
        config.resolutions = [320]
        config.batch_sizes = [1, 8]
        config.camera_counts = [1, 10]
        config.target_fps_list = [10.0]
        config.test_duration_sec = 30
        config.warmup_sec = 3

    # Single-test mode: exactly one combination.
    if args.single:
        config.resolutions = [320]
        config.batch_sizes = [8]
        config.camera_counts = [10]
        config.target_fps_list = [10.0]

    # Output control
    config.save_json = not args.no_json
    config.save_csv = not args.no_csv
    config.generate_report = not args.no_report
    return config
def _print_traceback_if(verbose: bool) -> None:
    """Print the traceback of the exception currently being handled, if *verbose*."""
    if verbose:
        import traceback
        traceback.print_exc()


def main() -> int:
    """Command-line entry point.

    Returns:
        The process exit code: 0 on success (including --generate-config and
        --list-tests), 1 on a configuration or execution failure, and 130
        when interrupted with Ctrl+C.
    """
    args = parse_args()

    # Write the default config file and exit.
    if args.generate_config:
        config_path = "benchmark_config.yaml"
        save_default_config(config_path)
        print(f"默认配置已生成: {config_path}")
        return 0

    # Assemble the configuration.
    try:
        config = build_config(args)
    except Exception as e:
        logger.error(f"配置错误: {e}")
        _print_traceback_if(args.verbose)
        return 1

    # Validate the configuration.
    errors = config.validate()
    if errors:
        logger.error("配置验证失败:")
        for err in errors:
            logger.error(f" - {err}")
        return 1

    # List the test combinations and exit.
    if args.list_tests:
        combinations = config.get_test_combinations()
        print(f"\n{len(combinations)} 个测试组合:\n")
        for i, combo in enumerate(combinations, 1):
            print(f" {i:3d}. {combo['resolution']}x{combo['resolution']}, "
                  f"batch={combo['batch_size']}, "
                  f"cameras={combo['camera_count']}, "
                  f"fps={combo['target_fps']}")
        return 0

    # Run the benchmark. The runner is constructed inside the try so that a
    # failure or Ctrl+C raised by BenchmarkRunner(...) itself is reported
    # through the same exit codes as one raised by run().
    try:
        runner = BenchmarkRunner(config)
        runner.run()
        return 0
    except KeyboardInterrupt:
        logger.warning("用户中断")
        return 130
    except Exception as e:
        logger.error(f"执行失败: {e}")
        _print_traceback_if(args.verbose)
        return 1
if __name__ == "__main__":
    # Script entry: main()'s return value becomes the process exit status.
    # Raising SystemExit is exactly what sys.exit() does internally.
    raise SystemExit(main())