# - Engine build scripts (FP16/INT8)
# - Benchmark validation scripts
# - Result parsing and analysis tools
# - COCO dataset configuration
# File info: 109 lines, 4.0 KiB, Python
"""
生成性能对比图表

Generate performance comparison charts from the parsed benchmark results.
"""

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np


def _metric(value):
    """Return *value* if it is a number, otherwise 0.

    ``dict.get(key, 0)`` only covers *missing* keys; a key present with a JSON
    ``null`` yields None, which would make ``max()`` and the FPS division raise.
    """
    return value if isinstance(value, (int, float)) else 0


def _collect_metrics(all_results):
    """Extract the per-configuration series plotted by ``generate_charts``.

    Args:
        all_results: mapping of configuration name -> metrics dict, as loaded
            from ``all_results.json``.

    Returns:
        ``(configs, series)``: the configuration names in file order, and a
        dict mapping a series name to a list of numbers aligned with ``configs``.
    """
    configs = list(all_results.keys())
    # A configuration entry (or a nested per-class dict) may be null; use {}.
    entries = [all_results[c] or {} for c in configs]

    inference = [_metric(e.get('inference_ms')) for e in entries]
    series = {
        'map50_95': [_metric(e.get('mAP50_95')) for e in entries],
        'map50': [_metric(e.get('mAP50')) for e in entries],
        'inference': inference,
        # Non-positive latency means "not measured": plot 0 FPS instead of dividing.
        'fps': [1000 / ms if ms > 0 else 0 for ms in inference],
        'person': [_metric((e.get('person') or {}).get('ap50_95')) for e in entries],
        'vehicles': [_metric((e.get('all_vehicles') or {}).get('ap50_95'))
                     for e in entries],
    }
    return configs, series


def _draw_bar_panel(ax, labels, values, colors, title, ylabel, label_fmt):
    """Draw one bar sub-chart and annotate every positive bar with its value.

    The y-axis gets 30% headroom above the tallest bar (or spans [0, 1] when no
    value is positive) so the value labels stay inside the axes; the label
    offset is proportional to that range rather than a fixed constant.
    """
    bars = ax.bar(labels, values, color=colors)
    ax.set_title(title, fontsize=12)
    ax.set_ylabel(ylabel)

    peak = max(values, default=0)
    top = peak * 1.3 if peak > 0 else 1
    ax.set_ylim(0, top)

    offset = top * 0.01
    for bar, val in zip(bars, values):
        if val > 0:
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + offset,
                    label_fmt.format(val), ha='center', va='bottom', fontsize=9)


def generate_charts(results_dir="vehicle_person_benchmark"):
    """Render the 2x3 benchmark comparison chart from ``all_results.json``.

    Panels: overall mAP50-95, overall mAP50, inference latency (ms),
    person AP50-95, all-vehicles AP50-95 and FPS (derived as 1000 / latency).

    Args:
        results_dir: directory containing ``all_results.json`` (produced by
            parse_results.py); the PNG is written into the same directory.
            Defaults to the previously hard-coded ``vehicle_person_benchmark``.

    Returns:
        Path of the saved ``benchmark_charts.png``, or None when the results
        file is missing or holds no configurations.
    """
    results_dir = Path(results_dir)
    results_file = results_dir / "all_results.json"

    if not results_file.exists():
        print("请先运行 parse_results.py 生成结果数据")
        return None

    with open(results_file, 'r', encoding='utf-8') as f:
        all_results = json.load(f)

    if not all_results:
        # Nothing to plot (max() over the empty series would raise).
        print(f"结果数据为空: {results_file}")
        return None

    configs, series = _collect_metrics(all_results)
    # matplotlib's bar() cycles this palette when there are more than 4 configs.
    colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']

    # (grid position, values, title, y-label, bar-label format)
    panels = (
        ((0, 0), series['map50_95'], '整体 mAP50-95', 'mAP50-95', '{:.3f}'),
        ((0, 1), series['map50'], '整体 mAP50', 'mAP50', '{:.3f}'),
        ((0, 2), series['inference'], '推理速度 (ms)', '推理时间 (ms)', '{:.1f}ms'),
        ((1, 0), series['person'], 'Person mAP50-95', 'AP50-95', '{:.3f}'),
        ((1, 1), series['vehicles'], 'Vehicles mAP50-95', 'AP50-95', '{:.3f}'),
        ((1, 2), series['fps'], 'FPS 对比', 'FPS', '{:.1f}'),
    )

    # Matplotlib's default sans-serif font has no CJK glyphs, so the Chinese
    # titles/labels would render as empty boxes. Prefer common CJK fonts when
    # installed (keeping the current list as fallback) and use an ASCII minus
    # sign, since many CJK fonts lack U+2212. Scoped so global rcParams are untouched.
    cjk_fonts = ['Microsoft YaHei', 'SimHei', 'Noto Sans CJK SC',
                 'WenQuanYi Micro Hei', 'Arial Unicode MS']
    rc = {
        'font.sans-serif': cjk_fonts + list(plt.rcParams['font.sans-serif']),
        'axes.unicode_minus': False,
    }

    chart_file = results_dir / "benchmark_charts.png"
    with plt.rc_context(rc):
        fig, axes = plt.subplots(2, 3, figsize=(16, 10))
        try:
            fig.suptitle('YOLO11n 人和车辆检测性能对比', fontsize=14, fontweight='bold')
            for (row, col), values, title, ylabel, label_fmt in panels:
                _draw_bar_panel(axes[row, col], configs, values, colors,
                                title, ylabel, label_fmt)
            fig.tight_layout()
            fig.savefig(chart_file, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure so repeated calls don't accumulate open figures.
            plt.close(fig)

    print(f"图表已保存: {chart_file}")
    return chart_file


if __name__ == '__main__':
    # Script entry point: build the charts from the default results directory.
    generate_charts()