Files
Test_AI/generate_charts.py
16337 942244bd88 Add YOLO11 TensorRT quantization benchmark scripts
- Engine build scripts (FP16/INT8)
- Benchmark validation scripts
- Result parsing and analysis tools
- COCO dataset configuration
2026-01-29 13:59:42 +08:00

109 lines
4.0 KiB
Python

"""
生成性能对比图表
"""
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
def _metric(entry, *keys):
    """Return the numeric value found at the nested ``keys`` path of ``entry``.

    A missing key, a JSON ``null``, a non-dict intermediate value, or a
    non-numeric leaf all yield 0, so one incomplete benchmark entry cannot
    abort the whole chart run.
    """
    value = entry
    for key in keys:
        if not isinstance(value, dict):
            return 0
        value = value.get(key)
    return value if isinstance(value, (int, float)) else 0


def _draw_bar_panel(ax, labels, values, colors, title, ylabel,
                    label_fmt, label_offset, scale_ylim=True):
    """Draw one bar chart on ``ax`` and annotate every positive bar.

    Args:
        ax: Matplotlib axes to draw on.
        labels: Bar labels (one per configuration).
        values: Bar heights, parallel to ``labels``.
        colors: Bar colors passed straight to ``ax.bar``.
        title: Panel title.
        ylabel: Y-axis label.
        label_fmt: ``str.format`` template used for the value above each bar.
        label_offset: Vertical gap between the bar top and its value label.
        scale_ylim: When true, set the y-range to 130% of the tallest bar
            (or to 1 when no bar is positive) to leave room for the labels.
    """
    bars = ax.bar(labels, values, color=colors)
    ax.set_title(title, fontsize=12)
    ax.set_ylabel(ylabel)
    if scale_ylim:
        peak = max(values, default=0)
        ax.set_ylim(0, peak * 1.3 if peak > 0 else 1)
    for bar, val in zip(bars, values):
        # Zero means "metric missing"; skip the label rather than print 0.
        if val > 0:
            ax.text(bar.get_x() + bar.get_width() / 2,
                    bar.get_height() + label_offset,
                    label_fmt.format(val),
                    ha='center', va='bottom', fontsize=9)


def generate_charts():
    """Render the benchmark comparison charts to a PNG file.

    Reads ``vehicle_person_benchmark/all_results.json`` (a mapping of
    configuration name to its metrics), draws a 2x3 grid of bar charts
    (overall mAP50-95, overall mAP50, inference time, person AP50-95,
    vehicle AP50-95, FPS) and saves it as
    ``vehicle_person_benchmark/benchmark_charts.png``.

    Prints a hint and returns without drawing when the results file is
    missing or contains no configurations.

    Returns:
        None.
    """
    results_dir = Path("vehicle_person_benchmark")
    results_file = results_dir / "all_results.json"
    if not results_file.exists():
        print("请先运行 parse_results.py 生成结果数据")
        return
    with open(results_file, 'r', encoding='utf-8') as f:
        all_results = json.load(f)
    # Without this guard an empty result set would reach max() on an empty
    # list and raise ValueError.
    if not all_results:
        print(f"结果文件中没有任何配置数据: {results_file}")
        return

    # Collect one value per configuration for every plotted metric.
    configs = list(all_results)
    entries = [all_results[name] for name in configs]
    map50_95 = [_metric(e, 'mAP50_95') for e in entries]
    map50 = [_metric(e, 'mAP50') for e in entries]
    inference = [_metric(e, 'inference_ms') for e in entries]
    # FPS is derived from the per-image latency; 0 ms means "unknown".
    fps = [1000 / ms if ms > 0 else 0 for ms in inference]
    person_ap5095 = [_metric(e, 'person', 'ap50_95') for e in entries]
    vehicle_ap5095 = [_metric(e, 'all_vehicles', 'ap50_95') for e in entries]

    colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
    # NOTE(review): the titles below contain Chinese text; presumably a
    # CJK-capable font is configured for matplotlib on the target machine —
    # verify, otherwise the glyphs may render as missing boxes.
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    try:
        fig.suptitle('YOLO11n 人和车辆检测性能对比', fontsize=14, fontweight='bold')
        # (axes, values, title, ylabel, label format, label offset, scale y?)
        panels = [
            (axes[0, 0], map50_95, '整体 mAP50-95', 'mAP50-95', '{:.3f}', 0.01, True),
            (axes[0, 1], map50, '整体 mAP50', 'mAP50', '{:.3f}', 0.01, True),
            (axes[0, 2], inference, '推理速度 (ms)', '推理时间 (ms)', '{:.1f}ms', 0.1, False),
            (axes[1, 0], person_ap5095, 'Person mAP50-95', 'AP50-95', '{:.3f}', 0.01, True),
            (axes[1, 1], vehicle_ap5095, 'Vehicles mAP50-95', 'AP50-95', '{:.3f}', 0.01, True),
            (axes[1, 2], fps, 'FPS 对比', 'FPS', '{:.1f}', 1, False),
        ]
        for ax, values, title, ylabel, label_fmt, offset, scale_ylim in panels:
            _draw_bar_panel(ax, configs, values, colors, title, ylabel,
                            label_fmt, offset, scale_ylim)
        fig.tight_layout()
        chart_file = results_dir / "benchmark_charts.png"
        fig.savefig(chart_file, dpi=150, bbox_inches='tight')
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    print(f"图表已保存: {chart_file}")
# Script entry point: generate the charts only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    generate_charts()