- Engine build scripts (FP16/INT8) - Benchmark validation scripts - Result parsing and analysis tools - COCO dataset configuration
394 lines
15 KiB
Python
394 lines
15 KiB
Python
"""
|
||
YOLO11n TensorRT Engine 对比分析 - 人和车辆检测
|
||
对比 FP32, INT8 640p, FP16 640p, FP16 480p
|
||
"""
|
||
|
||
import subprocess
|
||
import json
|
||
import re
|
||
from pathlib import Path
|
||
from datetime import datetime
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
|
||
class VehiclePersonBenchmark:
    """Compare YOLO11n variants (PyTorch FP32, TensorRT INT8/FP16) on person and vehicle detection.

    Each configuration is validated by running the ``yolo val`` CLI against
    ``coco_person_vehicle.yaml``. The CLI's text output is parsed for speed,
    overall mAP and per-class metrics. Results go to
    ``vehicle_person_benchmark/`` as ``results.json``, ``benchmark_report.txt``
    and ``benchmark_charts.png``.
    """

    # Class names averaged together into the "all_vehicles" metrics.
    VEHICLE_CLASSES = ('bicycle', 'car', 'motorcycle', 'bus', 'truck', 'train')

    # Matches e.g. "Speed: 0.2ms preprocess, 1.5ms inference, 0.0ms loss, 1.1ms postprocess per image".
    # The optional hyphen also accepts the "pre-process"/"post-process" spelling.
    _SPEED_RE = re.compile(r'(\d+(?:\.\d+)?)\s*ms\s+(pre-?process|inference|post-?process)\b')

    # pycocotools summary line for all areas / maxDets=100, e.g.
    # " Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.372"
    _COCO_AP_RE = re.compile(
        r'Average Precision\s+\(AP\)\s+@\[\s*IoU=(\S+)\s*\|\s*area=\s*all\s*\|'
        r'\s*maxDets=\s*100\s*\]\s*=\s*(-?\d+(?:\.\d+)?)'
    )

    def __init__(self):
        self.results_dir = Path("vehicle_person_benchmark")
        self.results_dir.mkdir(exist_ok=True)

        # Configuration name -> model file. Not read by run_all(), which keeps
        # its own list that also carries the image size for each entry.
        self.engines = {
            "FP32_PyTorch": "yolo11n.pt",
            "INT8_640p": "yolo11n_int8_b1_8.engine",
            "FP16_640p": "yolo11n_fp16_640.engine",
            "FP16_480p": "yolo11n_fp16_480.engine",
        }

        self.data = {
            "timestamp": datetime.now().isoformat(),
            "results": {},
            "summary": {},
        }

    def run_validation(self, name, model, imgsz):
        """Run ``yolo val`` for one model and extract overall, person and vehicle metrics.

        Args:
            name: Label for this configuration (used in logs and the report).
            model: Path to a ``.pt`` checkpoint or a TensorRT ``.engine`` file.
            imgsz: Image size passed to the CLI.

        Returns:
            A dict with ``name``/``model``/``imgsz``. It also holds whichever of
            ``preprocess_ms``/``inference_ms``/``postprocess_ms``/``mAP50``/
            ``mAP50_95`` could be parsed, plus ``person``, ``vehicle`` (the
            ``car`` class only) and ``all_vehicles`` sub-dicts. If the command
            cannot be run, or exits non-zero, returns
            ``{"name": name, "error": <message>}`` instead.
        """
        print(f"\n{'='*60}")
        print(f"验证: {name}")
        print(f"{'='*60}")

        cmd = [
            "yolo", "val",
            f"model={model}",
            "data=coco_person_vehicle.yaml",
            f"imgsz={imgsz}",
            "rect=False",
            "batch=1",
        ]
        print(f"命令: {' '.join(cmd)}")

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=3600,
                encoding='utf-8',
                errors='replace',
            )
        except (OSError, subprocess.SubprocessError) as e:
            # e.g. `yolo` is not on PATH, or the one-hour timeout expired.
            print(f"错误: {e}")
            return {"name": name, "error": str(e)}

        output = result.stdout + result.stderr
        if result.returncode != 0:
            # Without this check a crashed run would be reported as all-zero metrics.
            tail = '\n'.join(output.splitlines()[-20:])
            error = f"yolo val exited with code {result.returncode}\n{tail}"
            print(f"错误: {error}")
            return {"name": name, "error": error}

        metrics = {
            "name": name,
            "model": model,
            "imgsz": imgsz,
        }
        metrics.update(self._parse_speed(output))
        metrics.update(self._parse_overall_map(output))

        # Per-class metrics: person, car, and the mean over all vehicle classes.
        person_data = self._extract_category_metrics(output, 'person')
        vehicle_data = self._extract_category_metrics(output, 'car')
        metrics['person'] = person_data
        metrics['vehicle'] = vehicle_data
        metrics['all_vehicles'] = self._combine_vehicle_metrics(output, list(self.VEHICLE_CLASSES))

        print(f"\n结果:")
        print(f" mAP50-95: {metrics.get('mAP50_95', 'N/A')}")
        print(f" mAP50: {metrics.get('mAP50', 'N/A')}")
        print(f" 推理速度: {metrics.get('inference_ms', 'N/A')}ms")
        print(f" Person AP: {person_data.get('ap50_95', 'N/A')}")
        print(f" Car AP: {vehicle_data.get('ap50_95', 'N/A')}")

        return metrics

    def _parse_speed(self, output):
        """Return ``preprocess_ms``/``inference_ms``/``postprocess_ms`` parsed from the CLI's Speed line.

        Stages that are not found are simply absent from the returned dict.
        If several Speed lines appear, the last value for each stage wins.
        """
        speeds = {}
        for value, stage in self._SPEED_RE.findall(output):
            speeds[stage.replace('-', '') + '_ms'] = float(value)
        return speeds

    def _parse_overall_map(self, output):
        """Return overall ``mAP50``/``mAP50_95``.

        The base values come from the ``all`` row of the per-class table.
        If pycocotools summary lines (``area=   all | maxDets=100``) are
        present, they take precedence.
        """
        overall = {}
        all_row = self._parse_class_rows(output).get('all')
        if all_row:
            overall['mAP50'] = all_row['ap50']
            overall['mAP50_95'] = all_row['ap50_95']
        for iou, value in self._COCO_AP_RE.findall(output):
            if iou == '0.50:0.95':
                overall['mAP50_95'] = float(value)
            elif iou == '0.50':
                overall['mAP50'] = float(value)
        return overall

    def _parse_class_rows(self, output):
        """Parse the validator's per-class table into ``{class_name: {'P', 'R', 'ap50', 'ap50_95'}}``.

        Rows have the form ``<name> <images> <instances> <P> <R> <mAP50> <mAP50-95>``.
        A name may contain spaces (e.g. ``traffic light``), so fields are taken
        from the right. Lines that do not fit this shape are ignored; this
        covers headers, progress bars and log messages. If a name appears
        twice, the first row wins.
        """
        rows = {}
        for line in output.splitlines():
            parts = line.split()
            if len(parts) < 7:
                continue
            try:
                int(parts[-6])  # images column must be an integer
                int(parts[-5])  # instances column must be an integer
                p, r, ap50, ap50_95 = (float(x) for x in parts[-4:])
            except ValueError:
                continue
            rows.setdefault(' '.join(parts[:-6]),
                            {'P': p, 'R': r, 'ap50': ap50, 'ap50_95': ap50_95})
        return rows

    def _extract_category_metrics(self, output, category_name):
        """Return ``{'P', 'R', 'ap50', 'ap50_95'}`` for one class, or ``{}`` if its row is absent.

        The name must equal the table's class column exactly, so ``'car'``
        does not pick up ``carrot`` or unrelated log lines.
        """
        return dict(self._parse_class_rows(output).get(category_name, {}))

    def _combine_vehicle_metrics(self, output, vehicle_classes):
        """Average P/R/AP50/AP50-95 over the given vehicle classes.

        This is an unweighted mean over classes, not weighted by instance count.
        Classes with no row in the output are skipped. Classes that scored 0.0
        are included, so they pull the average down. A metric with no data at
        all is reported as 0.0.
        """
        rows = self._parse_class_rows(output)
        combined = {'ap50_95': [], 'ap50': [], 'P': [], 'R': []}
        for vc in vehicle_classes:
            vc_metrics = rows.get(vc)
            if not vc_metrics:
                continue
            for key, values in combined.items():
                values.append(vc_metrics[key])

        return {key: float(np.mean(values)) if values else 0.0
                for key, values in combined.items()}

    def run_all(self):
        """Validate every configuration whose model file exists, then write JSON, report and charts."""
        print("="*60)
        print("YOLO11n 人和车辆检测性能对比分析")
        print("="*60)

        # (name, model file, imgsz). Presumably imgsz must match the size each
        # TensorRT engine was built for; confirm against the build scripts.
        configs = [
            ("FP32_PyTorch", "yolo11n.pt", 640),
            ("INT8_640p", "yolo11n_int8_b1_8.engine", 640),
            ("FP16_640p", "yolo11n_fp16_640.engine", 640),
            ("FP16_480p", "yolo11n_fp16_480.engine", 480),
        ]

        for name, model, imgsz in configs:
            if not Path(model).exists():
                print(f"\n跳过 {name}: 模型文件不存在 - {model}")
                continue
            self.data['results'][name] = self.run_validation(name, model, imgsz)

        self.save_results()
        self.generate_report()
        self.generate_charts()

    def save_results(self):
        """Write ``self.data`` as UTF-8 JSON to ``<results_dir>/results.json``."""
        results_file = self.results_dir / "results.json"
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(self.data, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存: {results_file}")

    @staticmethod
    def _fmt(value, width, precision):
        """Left-align ``value`` in ``width`` columns as fixed-point, or as ``'N/A'`` if it is not a number."""
        if isinstance(value, (int, float)):
            return f"{value:<{width}.{precision}f}"
        return f"{'N/A':<{width}}"

    def generate_report(self):
        """Build the text report, write it to ``benchmark_report.txt`` and echo it to stdout.

        Configurations whose result contains ``'error'`` are left out of every
        section. Missing metrics are shown as ``N/A`` in the overall table.
        """
        report = []
        report.append("="*70)
        report.append("YOLO11n 人和车辆检测性能对比分析报告")
        report.append("="*70)
        report.append(f"生成时间: {self.data['timestamp']}")
        report.append("")

        # 1. Overall performance table
        report.append("-"*70)
        report.append("一、整体性能对比")
        report.append("-"*70)
        report.append(f"{'配置':<15} {'mAP50-95':<12} {'mAP50':<12} {'推理(ms)':<12} {'FPS':<10}")
        report.append("-"*70)

        for name, metrics in self.data['results'].items():
            if 'error' in metrics:
                continue
            inf_ms = metrics.get('inference_ms')
            # A missing or zero latency has no meaningful FPS; show N/A instead of dividing by zero.
            fps = 1000 / inf_ms if inf_ms else None
            report.append(
                f"{name:<15} {self._fmt(metrics.get('mAP50_95'), 12, 4)} "
                f"{self._fmt(metrics.get('mAP50'), 12, 4)} "
                f"{self._fmt(inf_ms, 12, 1)} {self._fmt(fps, 10, 1)}"
            )

        report.append("")

        # 2. Person class
        report.append("-"*70)
        report.append("二、Person (人) 类别检测性能")
        report.append("-"*70)
        report.append(f"{'配置':<15} {'P':<10} {'R':<10} {'AP50':<12} {'AP50-95':<12}")
        report.append("-"*70)

        for name, metrics in self.data['results'].items():
            if 'error' in metrics:
                continue
            person = metrics.get('person', {})
            p = person.get('P', 0)
            r = person.get('R', 0)
            ap50 = person.get('ap50', 0)
            ap50_95 = person.get('ap50_95', 0)
            report.append(f"{name:<15} {p:<10.3f} {r:<10.3f} {ap50:<12.4f} {ap50_95:<12.4f}")

        report.append("")

        # 3. Vehicle classes (mean over VEHICLE_CLASSES)
        report.append("-"*70)
        report.append("三、Vehicles (车辆) 类别检测性能")
        report.append("-"*70)
        report.append(f"{'配置':<15} {'P':<10} {'R':<10} {'AP50':<12} {'AP50-95':<12}")
        report.append("-"*70)

        for name, metrics in self.data['results'].items():
            if 'error' in metrics:
                continue
            vehicles = metrics.get('all_vehicles', {})
            p = vehicles.get('P', 0)
            r = vehicles.get('R', 0)
            ap50 = vehicles.get('ap50', 0)
            ap50_95 = vehicles.get('ap50_95', 0)
            report.append(f"{name:<15} {p:<10.3f} {r:<10.3f} {ap50:<12.4f} {ap50_95:<12.4f}")

        report.append("")

        # 4. Speed ranking (fastest first)
        report.append("-"*70)
        report.append("四、推理速度对比")
        report.append("-"*70)

        speeds = [
            (name, metrics['inference_ms'])
            for name, metrics in self.data['results'].items()
            if 'error' not in metrics and metrics.get('inference_ms')
        ]

        if speeds:
            speeds.sort(key=lambda x: x[1])
            report.append(f"最快: {speeds[0][0]} ({speeds[0][1]:.2f}ms)")
            report.append("")
            report.append("速度排名:")
            for i, (name, ms) in enumerate(speeds, 1):
                fps = 1000/ms if ms > 0 else 0
                report.append(f" {i}. {name}: {ms:.2f}ms ({fps:.1f} FPS)")

        report.append("")

        # 5. Conclusions
        # NOTE(review): these recommendations are static text and are not derived
        # from the measured results above; verify they still match the numbers.
        report.append("="*70)
        report.append("五、结论与建议")
        report.append("="*70)
        report.append("")
        report.append("1. 精度最优: 选择 FP16_640p 或 INT8_640p")
        report.append("2. 速度最快: 选择 FP16_480p")
        report.append("3. 性价比: 推荐 FP16_640p (平衡精度和速度)")
        report.append("4. INT8量化在人和车辆检测上表现良好,掉点在可接受范围")
        report.append("")

        report_text = '\n'.join(report)
        report_file = self.results_dir / "benchmark_report.txt"
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write(report_text)

        print(f"\n报告已保存: {report_file}")
        print("\n" + report_text)

    @staticmethod
    def _plot_bars(ax, labels, values, colors, title, ylabel, label_fmt, label_offset, pad_ylim=False):
        """Draw one bar panel on ``ax`` and annotate every positive bar with ``label_fmt.format(value)``."""
        bars = ax.bar(labels, values, color=colors)
        ax.set_title(title, fontsize=12)
        ax.set_ylabel(ylabel)
        if pad_ylim:
            # Leave 30% headroom for the value labels. Fall back to [0, 1] when
            # every value is 0; a (0, 0) range is singular.
            peak = max(values)
            ax.set_ylim(0, peak * 1.3 if peak > 0 else 1)
        for bar, val in zip(bars, values):
            if val > 0:
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                        label_fmt.format(val), ha='center', va='bottom', fontsize=9)

    def generate_charts(self):
        """Render a 2x2 bar chart (mAP50-95, mAP50, latency, FPS) to ``benchmark_charts.png``.

        Errored configurations are skipped. Missing metrics are plotted as 0
        and left unlabelled.
        """
        configs, map50_95_values, map50_values, inference_values = [], [], [], []
        for name, metrics in self.data['results'].items():
            if 'error' in metrics:
                continue
            configs.append(name)
            map50_95_values.append(metrics.get('mAP50_95') or 0)
            map50_values.append(metrics.get('mAP50') or 0)
            inference_values.append(metrics.get('inference_ms') or 0)

        if not configs:
            print("\n没有可用的验证结果,跳过图表生成")
            return

        palette = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
        colors = [palette[i % len(palette)] for i in range(len(configs))]
        fps_values = [1000/v if v > 0 else 0 for v in inference_values]

        # NOTE: the Chinese titles need a CJK-capable font in matplotlib's
        # font list; otherwise glyphs render as boxes (with warnings).
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        try:
            fig.suptitle('YOLO11n 人和车辆检测性能对比', fontsize=14, fontweight='bold')
            self._plot_bars(axes[0, 0], configs, map50_95_values, colors,
                            'mAP50-95 (整体)', 'mAP50-95', '{:.3f}', 0.01, pad_ylim=True)
            self._plot_bars(axes[0, 1], configs, map50_values, colors,
                            'mAP50 (整体)', 'mAP50', '{:.3f}', 0.01, pad_ylim=True)
            self._plot_bars(axes[1, 0], configs, inference_values, colors,
                            '推理速度 (ms)', '推理时间 (ms)', '{:.1f}ms', 0.1)
            self._plot_bars(axes[1, 1], configs, fps_values, colors,
                            'FPS 对比', 'FPS', '{:.1f}', 1)

            fig.tight_layout()
            chart_file = self.results_dir / "benchmark_charts.png"
            fig.savefig(chart_file, dpi=150, bbox_inches='tight')
            print(f"图表已保存: {chart_file}")
        finally:
            # Release the figure so repeated runs don't accumulate open figures.
            plt.close(fig)
|
||
def main():
    """Script entry point: build the benchmark runner and execute every configuration."""
    VehiclePersonBenchmark().run_all()


if __name__ == "__main__":
    main()
|