Add YOLO11 TensorRT quantization benchmark scripts
- Engine build scripts (FP16/INT8) - Benchmark validation scripts - Result parsing and analysis tools - COCO dataset configuration
This commit is contained in:
108
generate_charts.py
Normal file
108
generate_charts.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
生成性能对比图表
|
||||
"""
|
||||
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
def generate_charts():
|
||||
results_dir = Path("vehicle_person_benchmark")
|
||||
results_file = results_dir / "all_results.json"
|
||||
|
||||
if not results_file.exists():
|
||||
print("请先运行 parse_results.py 生成结果数据")
|
||||
return
|
||||
|
||||
with open(results_file, 'r', encoding='utf-8') as f:
|
||||
all_results = json.load(f)
|
||||
|
||||
# 准备数据
|
||||
configs = list(all_results.keys())
|
||||
|
||||
map50_95 = [all_results[c].get('mAP50_95', 0) for c in configs]
|
||||
map50 = [all_results[c].get('mAP50', 0) for c in configs]
|
||||
inference = [all_results[c].get('inference_ms', 0) for c in configs]
|
||||
fps = [1000/i if i > 0 else 0 for i in inference]
|
||||
|
||||
person_ap5095 = [all_results[c].get('person', {}).get('ap50_95', 0) for c in configs]
|
||||
vehicle_ap5095 = [all_results[c].get('all_vehicles', {}).get('ap50_95', 0) for c in configs]
|
||||
|
||||
colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
|
||||
|
||||
fig, axes = plt.subplots(2, 3, figsize=(16, 10))
|
||||
fig.suptitle('YOLO11n 人和车辆检测性能对比', fontsize=14, fontweight='bold')
|
||||
|
||||
# 1. mAP50-95
|
||||
ax1 = axes[0, 0]
|
||||
bars1 = ax1.bar(configs, map50_95, color=colors)
|
||||
ax1.set_title('整体 mAP50-95', fontsize=12)
|
||||
ax1.set_ylabel('mAP50-95')
|
||||
ax1.set_ylim(0, max(map50_95) * 1.3 if max(map50_95) > 0 else 1)
|
||||
for bar, val in zip(bars1, map50_95):
|
||||
if val > 0:
|
||||
ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
|
||||
f'{val:.3f}', ha='center', va='bottom', fontsize=9)
|
||||
|
||||
# 2. mAP50
|
||||
ax2 = axes[0, 1]
|
||||
bars2 = ax2.bar(configs, map50, color=colors)
|
||||
ax2.set_title('整体 mAP50', fontsize=12)
|
||||
ax2.set_ylabel('mAP50')
|
||||
ax2.set_ylim(0, max(map50) * 1.3 if max(map50) > 0 else 1)
|
||||
for bar, val in zip(bars2, map50):
|
||||
if val > 0:
|
||||
ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
|
||||
f'{val:.3f}', ha='center', va='bottom', fontsize=9)
|
||||
|
||||
# 3. 推理速度
|
||||
ax3 = axes[0, 2]
|
||||
bars3 = ax3.bar(configs, inference, color=colors)
|
||||
ax3.set_title('推理速度 (ms)', fontsize=12)
|
||||
ax3.set_ylabel('推理时间 (ms)')
|
||||
for bar, val in zip(bars3, inference):
|
||||
if val > 0:
|
||||
ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
|
||||
f'{val:.1f}ms', ha='center', va='bottom', fontsize=9)
|
||||
|
||||
# 4. Person mAP50-95
|
||||
ax4 = axes[1, 0]
|
||||
bars4 = ax4.bar(configs, person_ap5095, color=colors)
|
||||
ax4.set_title('Person mAP50-95', fontsize=12)
|
||||
ax4.set_ylabel('AP50-95')
|
||||
ax4.set_ylim(0, max(person_ap5095) * 1.3 if max(person_ap5095) > 0 else 1)
|
||||
for bar, val in zip(bars4, person_ap5095):
|
||||
if val > 0:
|
||||
ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
|
||||
f'{val:.3f}', ha='center', va='bottom', fontsize=9)
|
||||
|
||||
# 5. Vehicle mAP50-95
|
||||
ax5 = axes[1, 1]
|
||||
bars5 = ax5.bar(configs, vehicle_ap5095, color=colors)
|
||||
ax5.set_title('Vehicles mAP50-95', fontsize=12)
|
||||
ax5.set_ylabel('AP50-95')
|
||||
ax5.set_ylim(0, max(vehicle_ap5095) * 1.3 if max(vehicle_ap5095) > 0 else 1)
|
||||
for bar, val in zip(bars5, vehicle_ap5095):
|
||||
if val > 0:
|
||||
ax5.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
|
||||
f'{val:.3f}', ha='center', va='bottom', fontsize=9)
|
||||
|
||||
# 6. FPS
|
||||
ax6 = axes[1, 2]
|
||||
bars6 = ax6.bar(configs, fps, color=colors)
|
||||
ax6.set_title('FPS 对比', fontsize=12)
|
||||
ax6.set_ylabel('FPS')
|
||||
for bar, val in zip(bars6, fps):
|
||||
if val > 0:
|
||||
ax6.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
|
||||
f'{val:.1f}', ha='center', va='bottom', fontsize=9)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
chart_file = results_dir / "benchmark_charts.png"
|
||||
plt.savefig(chart_file, dpi=150, bbox_inches='tight')
|
||||
print(f"图表已保存: {chart_file}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_charts()
|
||||
Reference in New Issue
Block a user