Files
Test_AI/batch_test_configurations.py
2026-01-20 11:14:10 +08:00

217 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
批量测试不同配置的性能
自动测试不同批次大小、输入尺寸的组合
"""
import json
import os
import subprocess
import sys
import time
from datetime import datetime

import matplotlib.pyplot as plt
import pandas as pd
# Configure matplotlib for the Chinese chart labels: prefer CJK-capable fonts
# (falling back to DejaVu Sans), and render minus signs with the ASCII hyphen
# instead of the Unicode minus glyph (axes.unicode_minus=False).
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'Microsoft YaHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
})
def _snapshot_result_files(result_dir):
    """Map every ``results_*.json`` file in *result_dir* to its modification time.

    Returns an empty dict when *result_dir* does not exist.
    """
    if not os.path.isdir(result_dir):
        return {}
    return {
        name: os.path.getmtime(os.path.join(result_dir, name))
        for name in os.listdir(result_dir)
        if name.startswith('results_') and name.endswith('.json')
    }


def run_test(batch_size, target_size, max_cameras, duration=30, *,
             script='optimized_multi_camera_tensorrt.py',
             result_dir='multi_camera_results'):
    """Run the benchmark script once for one configuration and load its results.

    The script is expected to write a ``results_*.json`` file into *result_dir*.
    Only a file created or modified by THIS run is accepted, so a failed run can
    never be reported with numbers left behind by an earlier configuration.

    Args:
        batch_size: value passed as ``--batch-size``.
        target_size: value passed as ``--target-size``.
        max_cameras: value passed as ``--max-cameras``.
        duration: value passed as ``--duration``; the subprocess is killed if
            it runs longer than ``duration + 30`` seconds.
        script: path of the benchmark script to execute.
        result_dir: directory the script writes its JSON results into.

    Returns:
        The parsed JSON from the newest results file written by this run, or
        ``None`` if the run timed out, no fresh results file appeared, or any
        other error occurred (the error is printed).
    """
    print(f"\n{'='*60}")
    print(f"测试配置: batch_size={batch_size}, target_size={target_size}, cameras={max_cameras}")
    print(f"{'='*60}\n")
    # sys.executable runs the benchmark with the same interpreter/venv as this
    # script; a bare 'python' on PATH may resolve to a different installation.
    cmd = [
        sys.executable, script,
        '--batch-size', str(batch_size),
        '--target-size', str(target_size),
        '--max-cameras', str(max_cameras),
        '--duration', str(duration),
    ]
    try:
        # Remember which result files already exist so leftovers from earlier
        # runs are never mistaken for this run's output.
        before = _snapshot_result_files(result_dir)
        result = subprocess.run(cmd, capture_output=True, text=True,
                                timeout=duration + 30)
        if result.returncode != 0:
            # Report the failure, but still look for a fresh file below: the
            # script may have saved its results before failing.
            print(f"⚠️ 测试进程异常退出 (退出码 {result.returncode})")
            for line in result.stderr.strip().splitlines()[-10:]:
                print(f"    {line}")
        after = _snapshot_result_files(result_dir)
        fresh = [name for name, mtime in after.items()
                 if mtime > before.get(name, float('-inf'))]
        if not fresh:
            print("⚠️ 未找到本次测试生成的结果文件")
            return None
        latest_file = max(fresh, key=after.get)
        with open(os.path.join(result_dir, latest_file), 'r', encoding='utf-8') as f:
            return json.load(f)
    except subprocess.TimeoutExpired:
        print("⚠️ 测试超时")
        return None
    except Exception as e:
        # Best-effort: one broken configuration must not abort the batch.
        print(f"❌ 测试失败: {e}")
        return None
def main():
    """Run every benchmark configuration in turn and report the results.

    Each configuration is executed via ``run_test``.  Runs that return no data,
    or whose JSON lacks one of the required metrics, are skipped with a message
    instead of aborting the whole batch, and are listed once all runs finish.
    """
    print("批量配置性能测试")
    print("=" * 60)
    # Each entry: (batch_size, target_size, max_cameras)
    test_configs = [
        (2, 640, 5),
        (4, 640, 5),
        (8, 640, 5),
        (4, 480, 5),
        (4, 640, 10),
        (8, 640, 10),
    ]
    test_duration = 30  # seconds per run
    # Metrics copied from each run's JSON into the report table (in this order).
    metric_keys = ('avg_fps', 'avg_inference_ms', 'p95_inference_ms',
                   'p99_inference_ms', 'total_frames')
    results = []
    failed = []
    for i, (batch_size, target_size, max_cameras) in enumerate(test_configs, 1):
        print(f"\n进度: {i}/{len(test_configs)}")
        config = (batch_size, target_size, max_cameras)
        data = run_test(batch_size, target_size, max_cameras, test_duration)
        if not data:
            failed.append(config)
        else:
            missing = [key for key in metric_keys if key not in data]
            if missing:
                # A malformed result must not crash the batch and lose the
                # results already collected.
                print(f"⚠️ 结果缺少字段 {missing}，跳过该配置")
                failed.append(config)
            else:
                results.append({
                    'batch_size': batch_size,
                    'target_size': target_size,
                    'max_cameras': max_cameras,
                    **{key: data[key] for key in metric_keys},
                })
        # Give the system time to settle before the next run.
        if i < len(test_configs):
            print("\n⏳ 等待系统稳定...")
            time.sleep(5)
    if failed:
        print(f"\n⚠️ 失败的配置 (batch_size, target_size, max_cameras): {failed}")
    if results:
        generate_report(results)
    else:
        print("\n❌ 没有成功的测试结果")
def generate_report(results):
    """Print, save and chart a summary of the collected benchmark runs.

    *results* is a list of per-configuration dicts, one table row each.  The
    table is printed to stdout, written to a timestamped CSV (``utf-8-sig``
    encoding) under ``batch_test_results/``, and then passed on to
    ``generate_visualizations``.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print("批量测试结果汇总")
    print(f"{banner}\n")

    table = pd.DataFrame(results)
    print(table.to_string(index=False))

    report_dir = 'batch_test_results'
    os.makedirs(report_dir, exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    csv_path = os.path.join(report_dir, f"batch_test_{stamp}.csv")
    table.to_csv(csv_path, index=False, encoding='utf-8-sig')
    print(f"\n✅ CSV已保存: {csv_path}")

    generate_visualizations(table, report_dir)
def _label_points(ax, xs, ys):
    """Write each y value (one decimal place) just above its point on *ax*."""
    for x, y in zip(xs, ys):
        ax.text(x, y + 1, f'{y:.1f}', ha='center', va='bottom', fontweight='bold')


def generate_visualizations(df, output_dir):
    """Draw a 2x2 comparison figure from the batch results and save it as PNG.

    Panels (each drawn only when *df* has matching rows):
      * top-left: average FPS vs batch size (target_size=640, 5 cameras)
      * top-right: average and P95 latency vs batch size (same subset)
      * bottom-left: average FPS vs camera count (batch_size=4, target_size=640)
      * bottom-right: FPS and average latency per input size
        (batch_size=4, 5 cameras)

    Each subset is sorted by its x-axis column so line plots run left to right
    even when the configurations were executed in a different order.

    Args:
        df: DataFrame with the columns built by ``main`` (batch_size,
            target_size, max_cameras, avg_fps, avg_inference_ms,
            p95_inference_ms, ...).
        output_dir: directory the PNG is saved into (must already exist).
    """
    print("\n🎨 生成可视化图表...")
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))

    # Panels 1-2: batch-size sweep at 640x640 with 5 cameras.
    df_640_5 = df[(df['target_size'] == 640) & (df['max_cameras'] == 5)].sort_values('batch_size')
    if not df_640_5.empty:
        ax1 = axes[0, 0]
        ax1.plot(df_640_5['batch_size'], df_640_5['avg_fps'], 'o-', linewidth=2, markersize=10)
        ax1.set_xlabel('批次大小', fontsize=12, fontweight='bold')
        ax1.set_ylabel('平均FPS', fontsize=12, fontweight='bold')
        ax1.set_title('批次大小 vs FPS (640x640, 5摄像头)', fontsize=14, fontweight='bold')
        ax1.grid(True, alpha=0.3)
        _label_points(ax1, df_640_5['batch_size'], df_640_5['avg_fps'])

        ax2 = axes[0, 1]
        ax2.plot(df_640_5['batch_size'], df_640_5['avg_inference_ms'], 'o-',
                 linewidth=2, markersize=10, label='平均延迟')
        ax2.plot(df_640_5['batch_size'], df_640_5['p95_inference_ms'], 's-',
                 linewidth=2, markersize=10, label='P95延迟')
        ax2.set_xlabel('批次大小', fontsize=12, fontweight='bold')
        ax2.set_ylabel('延迟 (ms)', fontsize=12, fontweight='bold')
        ax2.set_title('批次大小 vs 延迟 (640x640, 5摄像头)', fontsize=14, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

    # Panel 3: camera-count sweep with batch_size=4 at 640x640.
    df_4_640 = df[(df['batch_size'] == 4) & (df['target_size'] == 640)].sort_values('max_cameras')
    if not df_4_640.empty:
        ax3 = axes[1, 0]
        ax3.plot(df_4_640['max_cameras'], df_4_640['avg_fps'], 'o-', linewidth=2, markersize=10)
        ax3.set_xlabel('摄像头数量', fontsize=12, fontweight='bold')
        ax3.set_ylabel('平均FPS', fontsize=12, fontweight='bold')
        ax3.set_title('摄像头数量 vs FPS (batch=4, 640x640)', fontsize=14, fontweight='bold')
        ax3.grid(True, alpha=0.3)
        _label_points(ax3, df_4_640['max_cameras'], df_4_640['avg_fps'])

    # Panel 4: input-size comparison with batch_size=4 and 5 cameras.
    df_4_5 = df[(df['batch_size'] == 4) & (df['max_cameras'] == 5)].sort_values('target_size')
    if not df_4_5.empty:
        ax4 = axes[1, 1]
        x = range(len(df_4_5))
        width = 0.35
        # NOTE: FPS and latency (ms) share one y axis, so the bars are only a
        # rough side-by-side comparison of magnitudes.
        ax4.bar([i - width/2 for i in x], df_4_5['avg_fps'], width, label='FPS', alpha=0.8)
        ax4.bar([i + width/2 for i in x], df_4_5['avg_inference_ms'], width, label='延迟(ms)', alpha=0.8)
        ax4.set_xlabel('输入尺寸', fontsize=12, fontweight='bold')
        ax4.set_ylabel('数值', fontsize=12, fontweight='bold')
        ax4.set_title('输入尺寸对比 (batch=4, 5摄像头)', fontsize=14, fontweight='bold')
        ax4.set_xticks(x)
        ax4.set_xticklabels([f"{size}x{size}" for size in df_4_5['target_size']])
        ax4.legend()
        ax4.grid(True, alpha=0.3, axis='y')

    fig.tight_layout()
    chart_file = os.path.join(output_dir, f"batch_test_charts_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png")
    fig.savefig(chart_file, dpi=300, bbox_inches='tight')
    print(f"✅ 图表已保存: {chart_file}")
    plt.show()
    # Release the figure so pyplot does not keep it alive after display.
    plt.close(fig)
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
        # 130 = 128 + SIGINT(2): the conventional shell status for Ctrl-C.
        raise SystemExit(130)
    except Exception as e:
        print(f"\n❌ 测试异常: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells/CI can tell that the batch failed.
        raise SystemExit(1)