"""
Generate performance comparison charts (YOLO11n person / vehicle detection).

Reads ``<results_dir>/all_results.json`` (produced by parse_results.py) and
writes a 2x3 grid of bar charts to ``<results_dir>/benchmark_charts.png``.
"""
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

from matplotlib import font_manager

# Fonts that contain CJK glyphs, tried in order. Matplotlib's default font
# (DejaVu Sans) has none, so the Chinese titles would otherwise render as boxes.
_CJK_FONT_CANDIDATES = (
    "SimHei",
    "Microsoft YaHei",
    "PingFang SC",
    "Heiti SC",
    "Noto Sans CJK SC",
    "Source Han Sans SC",
    "WenQuanYi Zen Hei",
)

# One color per configuration; matplotlib cycles the list if there are more bars.
_BAR_COLORS = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']


def _cjk_rc():
    """Return rcParams overrides that put any installed CJK font first."""
    # Only list fonts that are actually installed; naming missing families
    # makes matplotlib log a "findfont: Font family ... not found" warning each.
    installed = {f.name for f in font_manager.fontManager.ttflist}
    available = [name for name in _CJK_FONT_CANDIDATES if name in installed]
    return {
        "font.sans-serif": available + list(plt.rcParams["font.sans-serif"]),
        # CJK fonts often lack the U+2212 minus sign; use an ASCII hyphen.
        "axes.unicode_minus": False,
    }


def _metric(entry, *keys):
    """Look up a (possibly nested) numeric metric; missing/null/non-numeric -> 0."""
    for key in keys:
        if not isinstance(entry, dict):
            return 0
        entry = entry.get(key)
    return entry if isinstance(entry, (int, float)) else 0


def _bar_panel(ax, labels, values, *, title, ylabel, fmt, offset):
    """Draw one bar chart on ``ax`` with a value label above each non-zero bar."""
    bars = ax.bar(labels, values, color=_BAR_COLORS)
    ax.set_title(title, fontsize=12)
    ax.set_ylabel(ylabel)
    top = max(values, default=0)
    # 30% headroom keeps the value labels inside the axes.
    ax.set_ylim(0, top * 1.3 if top > 0 else 1)
    for bar, val in zip(bars, values):
        if val > 0:
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + offset,
                    fmt.format(val), ha='center', va='bottom', fontsize=9)


def generate_charts(results_dir="vehicle_person_benchmark"):
    """Render the benchmark comparison charts to a PNG file.

    Args:
        results_dir: Directory containing ``all_results.json``. The chart is
            written next to it as ``benchmark_charts.png``. Defaults to
            ``"vehicle_person_benchmark"`` relative to the working directory.

    The JSON is read as a mapping from configuration name to a dict with
    optional numeric ``mAP50_95``, ``mAP50`` and ``inference_ms`` keys, plus
    optional ``person`` / ``all_vehicles`` sub-dicts holding ``ap50_95``.
    Missing or null values are plotted as 0 (and left unlabeled).
    Prints a message and returns early if the file is missing or empty.
    """
    results_dir = Path(results_dir)
    results_file = results_dir / "all_results.json"
    if not results_file.exists():
        print("请先运行 parse_results.py 生成结果数据")
        return

    with open(results_file, 'r', encoding='utf-8') as f:
        all_results = json.load(f)

    if not all_results:
        # Nothing to plot; also avoids max() on empty sequences.
        print(f"结果文件为空: {results_file}")
        return

    # Prepare the data, one value per configuration, in file order.
    configs = list(all_results)
    entries = [all_results[c] for c in configs]
    map50_95 = [_metric(e, 'mAP50_95') for e in entries]
    map50 = [_metric(e, 'mAP50') for e in entries]
    inference = [_metric(e, 'inference_ms') for e in entries]
    # Per-image latency in ms -> frames per second; 0 marks "no data".
    fps = [1000 / ms if ms > 0 else 0 for ms in inference]
    person_ap5095 = [_metric(e, 'person', 'ap50_95') for e in entries]
    vehicle_ap5095 = [_metric(e, 'all_vehicles', 'ap50_95') for e in entries]

    # (values, title, y-label, label format, label offset) in row-major grid order.
    panels = [
        (map50_95, '整体 mAP50-95', 'mAP50-95', '{:.3f}', 0.01),
        (map50, '整体 mAP50', 'mAP50', '{:.3f}', 0.01),
        (inference, '推理速度 (ms)', '推理时间 (ms)', '{:.1f}ms', 0.1),
        (person_ap5095, 'Person mAP50-95', 'AP50-95', '{:.3f}', 0.01),
        (vehicle_ap5095, 'Vehicles mAP50-95', 'AP50-95', '{:.3f}', 0.01),
        (fps, 'FPS 对比', 'FPS', '{:.1f}', 1),
    ]

    chart_file = results_dir / "benchmark_charts.png"
    # Scope the font overrides to this figure; fonts resolve at draw time (savefig).
    with plt.rc_context(_cjk_rc()):
        fig, axes = plt.subplots(2, 3, figsize=(16, 10))
        try:
            fig.suptitle('YOLO11n 人和车辆检测性能对比', fontsize=14, fontweight='bold')
            for ax, (values, title, ylabel, fmt, offset) in zip(axes.flat, panels):
                _bar_panel(ax, configs, values, title=title, ylabel=ylabel,
                           fmt=fmt, offset=offset)
            fig.tight_layout()
            fig.savefig(chart_file, dpi=150, bbox_inches='tight')
        finally:
            # Release the figure so repeated calls don't accumulate open figures.
            plt.close(fig)
    print(f"图表已保存: {chart_file}")


if __name__ == "__main__":
    generate_charts()