Files
Test_AI/benchmark_vehicle_person.py
16337 942244bd88 Add YOLO11 TensorRT quantization benchmark scripts
- Engine build scripts (FP16/INT8)
- Benchmark validation scripts
- Result parsing and analysis tools
- COCO dataset configuration
2026-01-29 13:59:42 +08:00

394 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YOLO11n TensorRT Engine 对比分析 - 人和车辆检测
对比 FP32, INT8 640p, FP16 640p, FP16 480p
"""
import subprocess
import json
import re
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
class VehiclePersonBenchmark:
    """Compare YOLO11n model variants on person and vehicle detection.

    Each configured model that exists on disk is validated with the Ultralytics
    ``yolo val`` CLI against ``coco_person_vehicle.yaml``. The captured console
    output is parsed for per-image timings, overall mAP and per-class metrics,
    and the results are written to ``vehicle_person_benchmark/`` as a JSON
    dump, a text report and a PNG chart.
    """

    # (name, model file, validation image size) for every variant under test.
    CONFIGS = (
        ("FP32_PyTorch", "yolo11n.pt", 640),
        ("INT8_640p", "yolo11n_int8_b1_8.engine", 640),
        ("FP16_640p", "yolo11n_fp16_640.engine", 640),
        ("FP16_480p", "yolo11n_fp16_480.engine", 480),
    )

    # Class names whose metrics are averaged into the "all_vehicles" summary.
    VEHICLE_CLASSES = ('bicycle', 'car', 'motorcycle', 'bus', 'truck', 'train')

    # "<value>ms <stage>" pairs on the Ultralytics "Speed: ..." line, e.g.
    # "Speed: 0.2ms preprocess, 1.5ms inference, 0.0ms loss, 1.2ms postprocess per image".
    # Tokens carry trailing commas / "ms" suffixes, so plain split() cannot parse them.
    _SPEED_RE = re.compile(r'(\d+(?:\.\d+)?)\s*ms\s+(preprocess|inference|loss|postprocess)')
    _STAGE_KEYS = {
        'preprocess': 'preprocess_ms',
        'inference': 'inference_ms',
        'postprocess': 'postprocess_ms',
    }

    # pycocotools summary line, e.g.
    # " Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.392"
    # The value is after the LAST '='; IoU and area are captured so that only
    # the area=all rows are used.
    _COCO_AP_RE = re.compile(
        r'Average Precision\s+\(AP\)\s+@\[\s*IoU=(\S+)\s*\|\s*area=\s*(\w+)\s*\|[^\]]*\]'
        r'\s*=\s*(-?\d+(?:\.\d+)?)'
    )

    def __init__(self):
        self.results_dir = Path("vehicle_person_benchmark")
        self.results_dir.mkdir(exist_ok=True)
        # name -> model file; derived from CONFIGS so the two cannot drift apart.
        self.engines = {name: model for name, model, _ in self.CONFIGS}
        self.data = {
            "timestamp": datetime.now().isoformat(),
            "results": {},
            "summary": {}
        }

    # ------------------------------------------------------------------ #
    # Small formatting helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _is_number(value):
        """Return True for real numbers (int/float, incl. numpy floats), excluding bool."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    @staticmethod
    def _fmt_num(value, width, precision):
        """Left-align *value* in *width* columns with *precision* decimals, or 'N/A' if not numeric."""
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return f"{value:<{width}.{precision}f}"
        return f"{'N/A':<{width}}"

    def _format_class_row(self, name, cls_metrics):
        """One row of the per-class tables: P and R with 3 decimals, AP50 and AP50-95 with 4."""
        fmt = self._fmt_num
        return (
            f"{name:<15} {fmt(cls_metrics.get('P'), 10, 3)} {fmt(cls_metrics.get('R'), 10, 3)} "
            f"{fmt(cls_metrics.get('ap50'), 12, 4)} {fmt(cls_metrics.get('ap50_95'), 12, 4)}"
        )

    # ------------------------------------------------------------------ #
    # Validation and output parsing
    # ------------------------------------------------------------------ #
    def run_validation(self, name, model, imgsz):
        """Run ``yolo val`` for one model and extract overall, person and vehicle metrics.

        Args:
            name: Label used in the report.
            model: Path to the ``.pt`` / ``.engine`` file.
            imgsz: Validation image size passed to ``yolo val``.

        Returns:
            A metrics dict (``name``, ``model``, ``imgsz``, any parsed timing/mAP
            keys, ``person``, ``vehicle`` (car only) and ``all_vehicles``), or
            ``{"name": ..., "error": ...}`` if the command could not be run or
            exited with a non-zero status.
        """
        print(f"\n{'='*60}")
        print(f"验证: {name}")
        print(f"{'='*60}")
        cmd = [
            "yolo", "val",
            f"model={model}",
            "data=coco_person_vehicle.yaml",
            f"imgsz={imgsz}",
            "rect=False",
            "batch=1"
        ]
        print(f"命令: {' '.join(cmd)}")
        # Only the subprocess call can legitimately fail here (missing CLI,
        # timeout, OS error); parsing below does not raise.
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=3600,
                encoding='utf-8',
                errors='replace'
            )
        except (subprocess.SubprocessError, OSError) as e:
            print(f"错误: {e}")
            return {"name": name, "error": str(e)}

        output = result.stdout + result.stderr
        if result.returncode != 0:
            error = f"yolo val exited with code {result.returncode}"
            print(f"错误: {error}")
            # Show the tail of the log so the cause of the failure is visible.
            print('\n'.join(output.splitlines()[-20:]))
            return {"name": name, "error": error}

        metrics = {
            "name": name,
            "model": model,
            "imgsz": imgsz
        }
        metrics.update(self._parse_summary_metrics(output))

        person_data = self._extract_category_metrics(output, 'person')
        # The "vehicle" key holds the 'car' class only; all vehicle classes are
        # averaged under "all_vehicles".
        vehicle_data = self._extract_category_metrics(output, 'car')
        metrics['person'] = person_data
        metrics['vehicle'] = vehicle_data
        metrics['all_vehicles'] = self._combine_vehicle_metrics(output, self.VEHICLE_CLASSES)

        print(f"\n结果:")
        print(f" mAP50-95: {metrics.get('mAP50_95', 'N/A')}")
        print(f" mAP50: {metrics.get('mAP50', 'N/A')}")
        print(f" 推理速度: {metrics.get('inference_ms', 'N/A')}ms")
        print(f" Person AP: {person_data.get('ap50_95', 'N/A')}")
        print(f" Car AP: {vehicle_data.get('ap50_95', 'N/A')}")
        return metrics

    def _parse_summary_metrics(self, output):
        """Parse per-image timings and overall mAP from ``yolo val`` console output.

        Returns a dict containing whichever of ``preprocess_ms``, ``inference_ms``,
        ``postprocess_ms``, ``mAP50`` and ``mAP50_95`` could be found. Overall mAP
        is taken from the pycocotools summary (area=all rows) when present,
        otherwise from the ``all`` row of the per-class table.
        """
        metrics = {}
        for line in output.splitlines():
            if 'preprocess' in line and 'inference' in line:
                for value, stage in self._SPEED_RE.findall(line):
                    key = self._STAGE_KEYS.get(stage)
                    if key:
                        metrics[key] = float(value)
            match = self._COCO_AP_RE.search(line)
            if match and match.group(2) == 'all':
                iou, value = match.group(1), float(match.group(3))
                if iou == '0.50:0.95':
                    metrics['mAP50_95'] = value
                elif iou == '0.50':
                    metrics['mAP50'] = value

        all_row = self._parse_class_rows(output).get('all')
        if all_row:
            metrics.setdefault('mAP50', all_row['ap50'])
            metrics.setdefault('mAP50_95', all_row['ap50_95'])
        return metrics

    @staticmethod
    def _parse_class_rows(output):
        """Parse the per-class results table of ``yolo val`` into ``{class_name: metrics}``.

        A row is any line whose last six tokens are numbers
        (``images instances P R mAP50 mAP50-95``); everything before them is the
        class name, so multi-word names work. Each metrics dict has keys
        ``P``, ``R``, ``ap50`` and ``ap50_95``. The first row seen for a name wins.
        """
        rows = {}
        # splitlines() also splits on '\r', which progress-bar output uses.
        for line in output.splitlines():
            parts = line.split()
            if len(parts) < 7:
                continue
            try:
                values = [float(tok) for tok in parts[-6:]]
            except ValueError:
                continue  # header, progress bar or other non-table line
            class_name = ' '.join(parts[:-6])
            rows.setdefault(class_name, {
                'P': values[2],
                'R': values[3],
                'ap50': values[4],
                'ap50_95': values[5],
            })
        return rows

    def _extract_category_metrics(self, output, category_name):
        """Return ``{'P', 'R', 'ap50', 'ap50_95'}`` for *category_name* (exact match), or ``{}``."""
        return dict(self._parse_class_rows(output).get(category_name, {}))

    def _combine_vehicle_metrics(self, output, vehicle_classes):
        """Average P, R, AP50 and AP50-95 over the *vehicle_classes* present in the output.

        Classes that do not appear in the table are skipped; a class that does
        appear with a metric of 0.0 is included in the average. A key with
        no contributing class is reported as 0.0.
        """
        rows = self._parse_class_rows(output)
        result = {}
        for key in ('ap50_95', 'ap50', 'P', 'R'):
            values = [rows[vc][key] for vc in vehicle_classes if vc in rows]
            result[key] = float(np.mean(values)) if values else 0.0
        return result

    # ------------------------------------------------------------------ #
    # Orchestration and outputs
    # ------------------------------------------------------------------ #
    def run_all(self):
        """Validate every configured model found on disk, then save JSON, report and charts."""
        print("="*60)
        print("YOLO11n 人和车辆检测性能对比分析")
        print("="*60)
        for name, model, imgsz in self.CONFIGS:
            if Path(model).exists():
                self.data['results'][name] = self.run_validation(name, model, imgsz)
            else:
                print(f"\n跳过 {name}: 模型文件不存在 - {model}")
        self.save_results()
        self.generate_report()
        self.generate_charts()

    def save_results(self):
        """Dump ``self.data`` to ``results.json`` in the results directory."""
        results_file = self.results_dir / "results.json"
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(self.data, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存: {results_file}")

    def generate_report(self):
        """Build the text report, save it as ``benchmark_report.txt`` and print it.

        Metrics that could not be parsed are shown as ``N/A`` instead of
        aborting the report; configurations with an ``error`` are skipped.
        """
        fmt = self._fmt_num
        valid = [(name, m) for name, m in self.data['results'].items() if 'error' not in m]

        report = []
        report.append("="*70)
        report.append("YOLO11n 人和车辆检测性能对比分析报告")
        report.append("="*70)
        report.append(f"生成时间: {self.data['timestamp']}")
        report.append("")

        # 1. Overall performance table
        report.append("-"*70)
        report.append("一、整体性能对比")
        report.append("-"*70)
        report.append(f"{'配置':<15} {'mAP50-95':<12} {'mAP50':<12} {'推理(ms)':<12} {'FPS':<10}")
        report.append("-"*70)
        for name, metrics in valid:
            inf_ms = metrics.get('inference_ms')
            # Guard against missing / zero timings (avoids ZeroDivisionError).
            fps = 1000 / inf_ms if self._is_number(inf_ms) and inf_ms > 0 else None
            report.append(
                f"{name:<15} {fmt(metrics.get('mAP50_95'), 12, 4)} {fmt(metrics.get('mAP50'), 12, 4)} "
                f"{fmt(inf_ms, 12, 1)} {fmt(fps, 10, 1)}"
            )
        report.append("")

        # 2. Person class
        report.append("-"*70)
        report.append("二、Person (人) 类别检测性能")
        report.append("-"*70)
        report.append(f"{'配置':<15} {'P':<10} {'R':<10} {'AP50':<12} {'AP50-95':<12}")
        report.append("-"*70)
        for name, metrics in valid:
            report.append(self._format_class_row(name, metrics.get('person', {})))
        report.append("")

        # 3. Vehicle classes (averaged)
        report.append("-"*70)
        report.append("三、Vehicles (车辆) 类别检测性能")
        report.append("-"*70)
        report.append(f"{'配置':<15} {'P':<10} {'R':<10} {'AP50':<12} {'AP50-95':<12}")
        report.append("-"*70)
        for name, metrics in valid:
            report.append(self._format_class_row(name, metrics.get('all_vehicles', {})))
        report.append("")

        # 4. Speed ranking
        report.append("-"*70)
        report.append("四、推理速度对比")
        report.append("-"*70)
        speeds = sorted(
            (
                (name, metrics['inference_ms'])
                for name, metrics in valid
                if self._is_number(metrics.get('inference_ms')) and metrics['inference_ms'] > 0
            ),
            key=lambda item: item[1],
        )
        if speeds:
            report.append(f"最快: {speeds[0][0]} ({speeds[0][1]:.2f}ms)")
            report.append("")
            report.append("速度排名:")
            for i, (name, ms) in enumerate(speeds, 1):
                fps = 1000/ms
                report.append(f" {i}. {name}: {ms:.2f}ms ({fps:.1f} FPS)")
        report.append("")

        # 5. Conclusions
        # NOTE: these recommendations are static text; they are not derived
        # from the measured numbers above.
        report.append("="*70)
        report.append("五、结论与建议")
        report.append("="*70)
        report.append("")
        report.append("1. 精度最优: 选择 FP16_640p 或 INT8_640p")
        report.append("2. 速度最快: 选择 FP16_480p")
        report.append("3. 性价比: 推荐 FP16_640p (平衡精度和速度)")
        report.append("4. INT8量化在人和车辆检测上表现良好掉点在可接受范围")
        report.append("")

        report_text = '\n'.join(report)
        report_file = self.results_dir / "benchmark_report.txt"
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write(report_text)
        print(f"\n报告已保存: {report_file}")
        print("\n" + report_text)

    @staticmethod
    def _draw_bar_panel(ax, labels, values, colors, title, ylabel, label_fmt,
                        label_offset, ylim_scale=None):
        """Draw one bar chart on *ax*, annotating each positive bar with ``label_fmt.format(val)``.

        If *ylim_scale* is given, the y axis spans ``0 .. max(values) * ylim_scale``
        (falling back to 1 when every value is zero, to avoid a degenerate range).
        """
        bars = ax.bar(labels, values, color=colors[:len(labels)])
        ax.set_title(title, fontsize=12)
        ax.set_ylabel(ylabel)
        if ylim_scale is not None:
            ax.set_ylim(0, (max(values, default=0) * ylim_scale) or 1)
        for bar, val in zip(bars, values):
            if val > 0:
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                        label_fmt.format(val), ha='center', va='bottom', fontsize=9)

    def generate_charts(self):
        """Render a 2x2 bar-chart overview and save it as ``benchmark_charts.png``.

        Panels: overall mAP50-95, overall mAP50, inference time (ms) and FPS.
        Configurations with an ``error`` are omitted; missing values are drawn
        as zero-height, unlabeled bars.
        """
        configs = []
        map50_95_values = []
        map50_values = []
        inference_values = []
        for name, metrics in self.data['results'].items():
            if 'error' in metrics:
                continue
            configs.append(name)
            map50_95_values.append(metrics.get('mAP50_95', 0))
            map50_values.append(metrics.get('mAP50', 0))
            inf_ms = metrics.get('inference_ms', 0)
            inference_values.append(inf_ms if self._is_number(inf_ms) and inf_ms > 0 else 0)
        fps_values = [1000/v if v > 0 else 0 for v in inference_values]
        colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        try:
            # NOTE(review): the CJK titles need a matplotlib font with Chinese
            # glyphs configured; otherwise they may render as empty boxes.
            fig.suptitle('YOLO11n 人和车辆检测性能对比', fontsize=14, fontweight='bold')
            self._draw_bar_panel(axes[0, 0], configs, map50_95_values, colors,
                                 'mAP50-95 (整体)', 'mAP50-95', '{:.3f}', 0.01, ylim_scale=1.3)
            self._draw_bar_panel(axes[0, 1], configs, map50_values, colors,
                                 'mAP50 (整体)', 'mAP50', '{:.3f}', 0.01, ylim_scale=1.3)
            self._draw_bar_panel(axes[1, 0], configs, inference_values, colors,
                                 '推理速度 (ms)', '推理时间 (ms)', '{:.1f}ms', 0.1)
            self._draw_bar_panel(axes[1, 1], configs, fps_values, colors,
                                 'FPS 对比', 'FPS', '{:.1f}', 1)
            fig.tight_layout()
            chart_file = self.results_dir / "benchmark_charts.png"
            fig.savefig(chart_file, dpi=150, bbox_inches='tight')
            print(f"图表已保存: {chart_file}")
        finally:
            # pyplot keeps every figure alive until it is explicitly closed.
            plt.close(fig)
def main():
    """Script entry point: build the benchmark runner and execute every configuration."""
    VehiclePersonBenchmark().run_all()
if __name__ == "__main__":
    # Start the benchmark only when run as a script, never on import.
    main()