217 lines
7.4 KiB
Python
217 lines
7.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
批量测试不同配置的性能
|
||
自动测试不同批次大小、输入尺寸的组合
|
||
"""
|
||
|
||
import subprocess
|
||
import json
|
||
import os
|
||
import time
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
from datetime import datetime
|
||
|
||
# Configure matplotlib so the Chinese labels and titles used in the charts
# render: try the CJK-capable fonts first, then fall back to DejaVu Sans.
# Minus signs are drawn as ASCII '-', presumably because the CJK fonts may
# lack the Unicode minus glyph.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
|
||
|
||
|
||
def run_test(batch_size, target_size, max_cameras, duration=30,
             script='optimized_multi_camera_tensorrt.py',
             result_dir='multi_camera_results'):
    """Run one benchmark configuration in a child process and load its results.

    The benchmark script is launched with the given settings. Afterwards, the
    ``results_*.json`` file it wrote into ``result_dir`` is parsed and
    returned. Only files created or modified *during this run* are
    considered, so a run that crashes or writes nothing can never be
    credited with the metrics of an earlier configuration.

    Args:
        batch_size: value passed as ``--batch-size``.
        target_size: value passed as ``--target-size``.
        max_cameras: value passed as ``--max-cameras``.
        duration: test length in seconds, passed as ``--duration``. The child
            is killed if it is still running after ``duration + 30`` seconds.
        script: path of the benchmark script to run.
        result_dir: directory in which the script writes its results files.

    Returns:
        The parsed JSON from the newest fresh results file, or ``None`` if
        the run timed out, no fresh results file appeared, or anything else
        went wrong (the error is printed).
    """
    import sys  # only needed here, to launch the child with this interpreter

    print(f"\n{'='*60}")
    print(f"测试配置: batch_size={batch_size}, target_size={target_size}, cameras={max_cameras}")
    print(f"{'='*60}\n")

    cmd = [
        # sys.executable rather than a bare 'python': the child then runs
        # under the same interpreter / virtualenv as this script.
        sys.executable, script,
        '--batch-size', str(batch_size),
        '--target-size', str(target_size),
        '--max-cameras', str(max_cameras),
        '--duration', str(duration),
    ]

    try:
        # Remember which result files already exist (with their mtimes) so
        # that only a file written by *this* run can be returned.
        before = _result_file_mtimes(result_dir)

        # Test duration plus a 30 s grace period for start-up and shutdown;
        # subprocess.run kills the child when the timeout expires.
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=duration + 30)

        if result.returncode != 0:
            print(f"⚠️ 测试进程异常退出 (返回码 {result.returncode})")
            # Output is captured, so surface the tail of stderr for diagnosis.
            for line in result.stderr.strip().splitlines()[-10:]:
                print(f"    {line}")

        after = _result_file_mtimes(result_dir)
        fresh = {name: mtime for name, mtime in after.items() if before.get(name) != mtime}
        if not fresh:
            print("⚠️ 未找到本次测试生成的结果文件")
            return None

        # Results may exist even after a non-zero exit (e.g. a crash during
        # shutdown); they are fresh, so they belong to this configuration.
        latest_file = max(fresh, key=fresh.get)
        with open(os.path.join(result_dir, latest_file), 'r', encoding='utf-8') as f:
            return json.load(f)

    except subprocess.TimeoutExpired:
        print("⚠️ 测试超时")
        return None
    except Exception as e:
        # Best effort: one broken configuration must not abort the whole sweep.
        print(f"❌ 测试失败: {e}")
        return None


def _result_file_mtimes(result_dir):
    """Map each ``results_*.json`` file name in *result_dir* to its mtime.

    Returns an empty dict when the directory does not exist (yet).
    """
    if not os.path.isdir(result_dir):
        return {}
    return {
        name: os.path.getmtime(os.path.join(result_dir, name))
        for name in os.listdir(result_dir)
        if name.startswith('results_') and name.endswith('.json')
    }
|
||
|
||
|
||
def main(test_configs=None, test_duration=30):
    """Run each benchmark configuration in turn, then report the results.

    Args:
        test_configs: iterable of ``(batch_size, target_size, max_cameras)``
            tuples. ``None`` (the default) runs the built-in sweep below.
        test_duration: seconds each configuration runs for (default 30).

    Configurations that fail, or whose results lack an expected metric, are
    reported and skipped. The remaining sweep and the final report still run.
    """
    print("批量配置性能测试")
    print("=" * 60)

    if test_configs is None:
        test_configs = [
            # (batch_size, target_size, max_cameras)
            (2, 640, 5),
            (4, 640, 5),
            (8, 640, 5),
            (4, 480, 5),
            (4, 640, 10),
            (8, 640, 10),
        ]
    else:
        # Materialize so len() works for any iterable (progress display).
        test_configs = list(test_configs)

    total = len(test_configs)
    results = []

    for i, (batch_size, target_size, max_cameras) in enumerate(test_configs, 1):
        print(f"\n进度: {i}/{total}")

        data = run_test(batch_size, target_size, max_cameras, test_duration)
        if data:
            row = _build_result_row(data, batch_size, target_size, max_cameras)
            if row is not None:
                results.append(row)

        # Pause between runs to let the system settle before the next config.
        if i < total:
            print("\n⏳ 等待系统稳定...")
            time.sleep(5)

    if results:
        generate_report(results)
    else:
        print("\n❌ 没有成功的测试结果")


# Metrics copied from each run's results JSON into the summary table,
# in the column order they appear there.
_METRIC_KEYS = (
    'avg_fps',
    'avg_inference_ms',
    'p95_inference_ms',
    'p99_inference_ms',
    'total_frames',
)


def _build_result_row(data, batch_size, target_size, max_cameras):
    """Return one summary row for a configuration, or None if unusable.

    Missing metrics are printed and the configuration is skipped. A KeyError
    here would otherwise escape ``main`` and discard every earlier result.
    """
    if not isinstance(data, dict):
        print("⚠️ 结果文件格式无效，跳过该配置")
        return None
    missing = [key for key in _METRIC_KEYS if key not in data]
    if missing:
        print(f"⚠️ 结果文件缺少字段 {missing}，跳过该配置")
        return None
    row = {'batch_size': batch_size, 'target_size': target_size, 'max_cameras': max_cameras}
    row.update((key, data[key]) for key in _METRIC_KEYS)
    return row
|
||
|
||
|
||
def generate_report(results):
    """Summarize the collected benchmark rows.

    Prints them as a table, saves them as
    ``batch_test_results/batch_test_<YYYYmmdd_HHMMSS>.csv`` (encoded
    ``utf-8-sig``), then passes the table on to ``generate_visualizations``.

    Args:
        results: list of per-configuration dicts, one table row each.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print("批量测试结果汇总")
    print(f"{banner}\n")

    summary = pd.DataFrame(results)
    print(summary.to_string(index=False))

    report_dir = 'batch_test_results'
    os.makedirs(report_dir, exist_ok=True)

    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    csv_path = os.path.join(report_dir, f"batch_test_{stamp}.csv")
    summary.to_csv(csv_path, index=False, encoding='utf-8-sig')
    print(f"\n✅ CSV已保存: {csv_path}")

    generate_visualizations(summary, report_dir)
|
||
|
||
|
||
def generate_visualizations(df, output_dir):
    """Draw a 2x2 panel of comparison charts and save it as a PNG.

    Each panel plots a subset of *df* with two settings held fixed:

    * top-left:     batch size vs. average FPS (640x640, 5 cameras)
    * top-right:    batch size vs. average and P95 latency (same subset)
    * bottom-left:  camera count vs. average FPS (batch 4, 640x640)
    * bottom-right: FPS and average latency per input size (batch 4, 5 cameras)

    A panel whose subset is empty is hidden.

    Args:
        df: one row per configuration, with at least the columns
            ``batch_size``, ``target_size``, ``max_cameras``, ``avg_fps``,
            ``avg_inference_ms`` and ``p95_inference_ms``.
        output_dir: existing directory that receives
            ``batch_test_charts_<YYYYmmdd_HHMMSS>.png``.
    """
    print("\n🎨 生成可视化图表...")

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Every subset is sorted by its x column so the line plots run left to
    # right, even if the configurations were executed in another order.

    # Batch-size sweep at 640x640 with 5 cameras (top row).
    df_640_5 = df[(df['target_size'] == 640) & (df['max_cameras'] == 5)].sort_values('batch_size')

    if df_640_5.empty:
        axes[0, 0].set_visible(False)
        axes[0, 1].set_visible(False)
    else:
        # FPS vs batch size
        ax1 = axes[0, 0]
        ax1.plot(df_640_5['batch_size'], df_640_5['avg_fps'], 'o-', linewidth=2, markersize=10)
        ax1.set_xlabel('批次大小', fontsize=12, fontweight='bold')
        ax1.set_ylabel('平均FPS', fontsize=12, fontweight='bold')
        ax1.set_title('批次大小 vs FPS (640x640, 5摄像头)', fontsize=14, fontweight='bold')
        ax1.grid(True, alpha=0.3)
        _label_points(ax1, df_640_5['batch_size'], df_640_5['avg_fps'])

        # Latency vs batch size
        ax2 = axes[0, 1]
        ax2.plot(df_640_5['batch_size'], df_640_5['avg_inference_ms'], 'o-',
                 linewidth=2, markersize=10, label='平均延迟')
        ax2.plot(df_640_5['batch_size'], df_640_5['p95_inference_ms'], 's-',
                 linewidth=2, markersize=10, label='P95延迟')
        ax2.set_xlabel('批次大小', fontsize=12, fontweight='bold')
        ax2.set_ylabel('延迟 (ms)', fontsize=12, fontweight='bold')
        ax2.set_title('批次大小 vs 延迟 (640x640, 5摄像头)', fontsize=14, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

    # Camera-count sweep at batch 4, 640x640 (bottom-left).
    df_4_640 = df[(df['batch_size'] == 4) & (df['target_size'] == 640)].sort_values('max_cameras')
    ax3 = axes[1, 0]
    if df_4_640.empty:
        ax3.set_visible(False)
    else:
        ax3.plot(df_4_640['max_cameras'], df_4_640['avg_fps'], 'o-', linewidth=2, markersize=10)
        ax3.set_xlabel('摄像头数量', fontsize=12, fontweight='bold')
        ax3.set_ylabel('平均FPS', fontsize=12, fontweight='bold')
        ax3.set_title('摄像头数量 vs FPS (batch=4, 640x640)', fontsize=14, fontweight='bold')
        ax3.grid(True, alpha=0.3)
        _label_points(ax3, df_4_640['max_cameras'], df_4_640['avg_fps'])

    # Input-size comparison at batch 4, 5 cameras (bottom-right). FPS and
    # latency share one y axis here, so the bars compare raw magnitudes only.
    df_4_5 = df[(df['batch_size'] == 4) & (df['max_cameras'] == 5)].sort_values('target_size')
    ax4 = axes[1, 1]
    if df_4_5.empty:
        ax4.set_visible(False)
    else:
        positions = list(range(len(df_4_5)))
        width = 0.35

        ax4.bar([p - width / 2 for p in positions], df_4_5['avg_fps'], width, label='FPS', alpha=0.8)
        ax4.bar([p + width / 2 for p in positions], df_4_5['avg_inference_ms'], width, label='延迟(ms)', alpha=0.8)

        ax4.set_xlabel('输入尺寸', fontsize=12, fontweight='bold')
        ax4.set_ylabel('数值', fontsize=12, fontweight='bold')
        ax4.set_title('输入尺寸对比 (batch=4, 5摄像头)', fontsize=14, fontweight='bold')
        ax4.set_xticks(positions)
        ax4.set_xticklabels([f"{size}x{size}" for size in df_4_5['target_size']])
        ax4.legend()
        ax4.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()

    chart_file = os.path.join(output_dir, f"batch_test_charts_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
    fig.savefig(chart_file, dpi=300, bbox_inches='tight')
    print(f"✅ 图表已保存: {chart_file}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _label_points(ax, xs, ys):
    """Write each y value (one decimal place) just above its point on *ax*."""
    for x, y in zip(xs, ys):
        # Offset in screen points, not data units: a fixed "+1" in data units
        # depends on the FPS scale and can leave the label on top of the marker.
        ax.annotate(f'{y:.1f}', (x, y), xytext=(0, 8), textcoords='offset points',
                    ha='center', va='bottom', fontweight='bold')
|
||
|
||
|
||
if __name__ == "__main__":
    import sys

    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
        # Conventional status for an interrupt (128 + SIGINT), so shells/CI
        # can tell the sweep did not complete.
        sys.exit(130)
    except Exception as e:
        print(f"\n❌ 测试异常: {e}")
        import traceback
        traceback.print_exc()
        # Signal failure through the exit status instead of exiting 0.
        sys.exit(1)
|