Files
Test_AI/batch_test_configurations.py

217 lines
7.4 KiB
Python
Raw Normal View History

2026-01-20 11:14:10 +08:00
#!/usr/bin/env python3
"""
批量测试不同配置的性能
自动测试不同批次大小、输入尺寸与摄像头数量的组合
"""
import json
import os
import subprocess
import sys
import time
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
# Font setup: try CJK-capable fonts first so the Chinese chart labels render,
# falling back to DejaVu Sans. Also draw minus signs with the ASCII hyphen-minus
# rather than U+2212 (presumably because the CJK fonts may lack that glyph).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
def run_test(batch_size, target_size, max_cameras, duration=30, *,
             script='optimized_multi_camera_tensorrt.py',
             result_dir='multi_camera_results'):
    """Run one benchmark configuration in a subprocess and load its JSON result.

    Launches ``script`` with the current interpreter, then reads the newest
    ``results_*.json`` file in ``result_dir`` that THIS run created or
    rewrote. Result files that already existed unchanged before the run are
    ignored, so a run that fails to write results is never credited with an
    earlier configuration's numbers. A non-zero exit code is reported
    (with the tail of stderr) but a freshly written result is still used.

    Args:
        batch_size: value passed as ``--batch-size``.
        target_size: value passed as ``--target-size``.
        max_cameras: value passed as ``--max-cameras``.
        duration: value passed as ``--duration``; the subprocess is killed
            after ``duration + 30`` seconds.
        script: benchmark script to execute.
        result_dir: directory the script writes ``results_*.json`` into.

    Returns:
        The parsed JSON from this run's result file, or ``None`` on timeout,
        on any error, or when the run produced no new result file.
    """
    print(f"\n{'='*60}")
    print(f"测试配置: batch_size={batch_size}, target_size={target_size}, cameras={max_cameras}")
    print(f"{'='*60}\n")
    cmd = [
        # Same interpreter as this driver; a bare 'python' may resolve to another install.
        sys.executable, script,
        '--batch-size', str(batch_size),
        '--target-size', str(target_size),
        '--max-cameras', str(max_cameras),
        '--duration', str(duration),
    ]

    def snapshot():
        """Map each results_*.json name in result_dir to its mtime ({} if the dir is absent)."""
        if not os.path.isdir(result_dir):
            return {}
        return {
            name: os.path.getmtime(os.path.join(result_dir, name))
            for name in os.listdir(result_dir)
            if name.startswith('results_') and name.endswith('.json')
        }

    try:
        before = snapshot()
        # errors='replace': undecodable child output must not turn a finished run into a failure.
        proc = subprocess.run(cmd, capture_output=True, text=True, errors='replace',
                              timeout=duration + 30)
        if proc.returncode != 0:
            print(f"⚠️ 测试进程退出码: {proc.returncode}")
            tail = (proc.stderr or '').strip().splitlines()[-10:]
            if tail:
                print('\n'.join(tail))
        after = snapshot()
        # Only files created or rewritten by this run count; an unchanged file
        # belongs to an earlier configuration and must not be reused.
        fresh = [name for name, mtime in after.items() if before.get(name) != mtime]
        if not fresh:
            print("⚠️ 未找到本次测试生成的结果文件")
            return None
        latest_file = max(fresh, key=after.__getitem__)
        with open(os.path.join(result_dir, latest_file), 'r', encoding='utf-8') as f:
            return json.load(f)
    except subprocess.TimeoutExpired:
        print("⚠️ 测试超时")
        return None
    except Exception as e:
        # Best-effort per configuration: report and let the batch continue.
        print(f"❌ 测试失败: {e}")
        return None
def main():
    """Run every benchmark configuration in turn, then build the summary report.

    Each configuration is executed through ``run_test``. Runs that return no
    data, or whose result JSON lacks any required metric, are skipped and
    listed at the end instead of aborting the batch, so results that were
    already collected still reach ``generate_report``.
    """
    print("批量配置性能测试")
    print("=" * 60)
    # Configurations to sweep: (batch_size, target_size, max_cameras)
    test_configs = [
        (2, 640, 5),
        (4, 640, 5),
        (8, 640, 5),
        (4, 480, 5),
        (4, 640, 10),
        (8, 640, 10),
    ]
    test_duration = 30  # seconds per configuration
    # Metrics copied from each run's result JSON, in report column order.
    metric_keys = ('avg_fps', 'avg_inference_ms', 'p95_inference_ms',
                   'p99_inference_ms', 'total_frames')
    results = []
    failed = []
    for i, (batch_size, target_size, max_cameras) in enumerate(test_configs, 1):
        print(f"\n进度: {i}/{len(test_configs)}")
        config = (batch_size, target_size, max_cameras)
        data = run_test(batch_size, target_size, max_cameras, test_duration)
        if not isinstance(data, dict) or not data:
            failed.append(config)
        else:
            missing = [key for key in metric_keys if key not in data]
            if missing:
                # Skip rather than raise KeyError, so one malformed result file
                # cannot abort the batch and discard the results gathered so far.
                print(f"⚠️ 结果缺少字段: {', '.join(missing)}，跳过该配置")
                failed.append(config)
            else:
                results.append({
                    'batch_size': batch_size,
                    'target_size': target_size,
                    'max_cameras': max_cameras,
                    **{key: data[key] for key in metric_keys},
                })
        # Pause between runs so the next configuration starts from an idle system.
        if i < len(test_configs):
            print("\n⏳ 等待系统稳定...")
            time.sleep(5)
    if failed:
        print(f"\n⚠️ 失败的配置 (batch_size, target_size, max_cameras): {failed}")
    if results:
        generate_report(results)
    else:
        print("\n❌ 没有成功的测试结果")
def generate_report(results):
    """Print, persist, and chart the collected per-configuration results.

    Prints the results as a table, writes them to a timestamped CSV
    (UTF-8 with BOM) under ``batch_test_results/``, then passes the same
    DataFrame and directory to ``generate_visualizations``.

    Args:
        results: list of dicts, one per successful configuration; each dict
            becomes one row and its keys become the columns.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print("批量测试结果汇总")
    print(f"{banner}\n")

    table = pd.DataFrame(results)
    print(table.to_string(index=False))

    report_dir = 'batch_test_results'
    os.makedirs(report_dir, exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    csv_path = os.path.join(report_dir, f"batch_test_{stamp}.csv")
    table.to_csv(csv_path, index=False, encoding='utf-8-sig')
    print(f"\n✅ CSV已保存: {csv_path}")

    generate_visualizations(table, report_dir)
def _style_axis(ax, xlabel, ylabel, title):
    """Apply the shared bold label/title styling used by every panel."""
    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold')


def _label_points(ax, xs, ys):
    """Write each y value (one decimal) a fixed 5 points above its marker."""
    for x, y in zip(xs, ys):
        # Offset in points, not data units, so the gap looks the same at any y scale.
        ax.annotate(f'{y:.1f}', (x, y), xytext=(0, 5), textcoords='offset points',
                    ha='center', va='bottom', fontweight='bold')


def _mark_empty(ax):
    """Hide the axes of a panel whose data subset is empty and leave a placeholder."""
    ax.set_axis_off()
    ax.text(0.5, 0.5, '无数据', ha='center', va='center', transform=ax.transAxes)


def generate_visualizations(df, output_dir):
    """Render a 2x2 comparison chart from the batch results and save it as PNG.

    Panels:
      [0,0] batch size vs. average FPS         (target_size=640, max_cameras=5)
      [0,1] batch size vs. average/P95 latency (target_size=640, max_cameras=5)
      [1,0] camera count vs. average FPS       (batch_size=4, target_size=640)
      [1,1] input size: FPS and latency bars   (batch_size=4, max_cameras=5)
    Each subset is sorted by its x variable; a panel whose subset is empty is
    hidden instead of being left as a blank grid.

    Args:
        df: DataFrame with columns batch_size, target_size, max_cameras,
            avg_fps, avg_inference_ms and p95_inference_ms.
        output_dir: existing directory that receives
            ``batch_test_charts_<YYYYmmdd_HHMMSS>.png``.
    """
    print("\n🎨 生成可视化图表...")
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Panels 1-2: batch-size sweep at target_size=640, cameras=5.
    # Sorting by x keeps the lines monotone left-to-right whatever the config order.
    df_640_5 = df[(df['target_size'] == 640) & (df['max_cameras'] == 5)].sort_values('batch_size')
    ax1, ax2 = axes[0, 0], axes[0, 1]
    if df_640_5.empty:
        _mark_empty(ax1)
        _mark_empty(ax2)
    else:
        ax1.plot(df_640_5['batch_size'], df_640_5['avg_fps'], 'o-', linewidth=2, markersize=10)
        _style_axis(ax1, '批次大小', '平均FPS', '批次大小 vs FPS (640x640, 5摄像头)')
        ax1.grid(True, alpha=0.3)
        _label_points(ax1, df_640_5['batch_size'], df_640_5['avg_fps'])

        ax2.plot(df_640_5['batch_size'], df_640_5['avg_inference_ms'], 'o-',
                 linewidth=2, markersize=10, label='平均延迟')
        ax2.plot(df_640_5['batch_size'], df_640_5['p95_inference_ms'], 's-',
                 linewidth=2, markersize=10, label='P95延迟')
        _style_axis(ax2, '批次大小', '延迟 (ms)', '批次大小 vs 延迟 (640x640, 5摄像头)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

    # Panel 3: camera-count sweep at batch_size=4, target_size=640.
    df_4_640 = df[(df['batch_size'] == 4) & (df['target_size'] == 640)].sort_values('max_cameras')
    ax3 = axes[1, 0]
    if df_4_640.empty:
        _mark_empty(ax3)
    else:
        ax3.plot(df_4_640['max_cameras'], df_4_640['avg_fps'], 'o-', linewidth=2, markersize=10)
        _style_axis(ax3, '摄像头数量', '平均FPS', '摄像头数量 vs FPS (batch=4, 640x640)')
        ax3.grid(True, alpha=0.3)
        _label_points(ax3, df_4_640['max_cameras'], df_4_640['avg_fps'])

    # Panel 4: input-size comparison at batch_size=4, cameras=5 (grouped bars).
    df_4_5 = df[(df['batch_size'] == 4) & (df['max_cameras'] == 5)].sort_values('target_size')
    ax4 = axes[1, 1]
    if df_4_5.empty:
        _mark_empty(ax4)
    else:
        positions = list(range(len(df_4_5)))
        width = 0.35
        ax4.bar([p - width / 2 for p in positions], df_4_5['avg_fps'], width,
                label='FPS', alpha=0.8)
        ax4.bar([p + width / 2 for p in positions], df_4_5['avg_inference_ms'], width,
                label='延迟(ms)', alpha=0.8)
        _style_axis(ax4, '输入尺寸', '数值', '输入尺寸对比 (batch=4, 5摄像头)')
        ax4.set_xticks(positions)
        ax4.set_xticklabels([f"{size}x{size}" for size in df_4_5['target_size']])
        ax4.legend()
        ax4.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    chart_file = os.path.join(output_dir, f"batch_test_charts_{timestamp}.png")
    fig.savefig(chart_file, dpi=300, bbox_inches='tight')
    print(f"✅ 图表已保存: {chart_file}")
    plt.show()
    # Release the figure; non-interactive backends would otherwise keep it open.
    plt.close(fig)
if __name__ == "__main__":
    # Script entry point: a Ctrl-C or an unexpected error is reported
    # (errors with a full traceback) instead of propagating out of the script.
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
    except Exception as exc:  # top-level boundary: report anything else
        import traceback
        print(f"\n❌ 测试异常: {exc}")
        traceback.print_exc()