- Engine build scripts (FP16/INT8) - Benchmark validation scripts - Result parsing and analysis tools - COCO dataset configuration
362 lines
13 KiB
Python
362 lines
13 KiB
Python
"""
YOLO11n TensorRT engine performance comparison.

Compares INT8 640p against FP16 480p engines on performance, accuracy,
speed and related metrics.
"""
|
||
|
||
import json
import os
import re
import subprocess
import time
from datetime import datetime
from pathlib import Path

import numpy as np
|
||
|
||
class EngineBenchmark:
    """Build, validate and compare a fixed set of YOLO11n TensorRT engines.

    The workflow (see ``run_benchmark``) validates the ``yolo11n.pt`` model as
    a baseline, builds each engine listed in ``self.engines`` with ``trtexec``,
    validates it with ``yolo val`` and finally writes a JSON dump and a text
    report into ``self.results_dir``.
    """

    # Matches summary rows of the form
    #   " Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.392"
    # Group 1 is the IoU range, group 2 is the score after the LAST '='.
    # Only the "area=all" row is accepted so that the per-size rows
    # (small/medium/large) cannot overwrite the overall score.
    _AP_ROW = re.compile(
        r"Average Precision\s+\(AP\)\s*@\[\s*IoU=([\d.:]+)\s*\|\s*area=\s*all\s*\|"
        r"[^\]]*\]\s*=\s*(-?\d+(?:\.\d+)?)"
    )

    # Matches the number in front of the word "inference", with or without a
    # glued "ms" unit (e.g. "1.7ms inference," or "1.7 inference").
    _INFERENCE_MS = re.compile(r"(\d+(?:\.\d+)?)\s*(?:ms)?\s+inference\b")

    def __init__(self):
        """Create the results directory and register the known engine files."""
        self.results_dir = Path("benchmark_results")
        self.results_dir.mkdir(exist_ok=True)

        # Configuration name -> engine file path (relative to the CWD).
        self.engines = {
            "INT8_640": "yolo11n_int8_b1_8.engine",
            "FP16_640": "yolo11n_fp16_640.engine",
            "INT8_480": "yolo11n_int8_480.engine",
            "FP16_480": "yolo11n_fp16_480.engine",
        }

        # Configuration name -> metrics dict produced by validate_engine().
        self.benchmark_data = {}

    @staticmethod
    def _trtexec_command(onnx_path, engine_path, precision, input_size):
        """Return the ``trtexec`` argument list for one engine build.

        Args:
            onnx_path: Path of the source ONNX model.
            engine_path: Path the serialized engine is written to.
            precision: Precision name; lower-cased and used as a flag
                (``"FP16"`` -> ``--fp16``).
            input_size: Square input resolution in pixels.
        """
        # NOTE(review): --explicitBatch and --workspace are believed to be
        # deprecated in newer trtexec releases -- confirm against the
        # installed TensorRT version.
        # NOTE(review): no --minShapes is passed although the banner in
        # run_benchmark() advertises "Batch: 1-8" -- confirm intent.
        shape = f"3x{input_size}x{input_size}"
        return [
            "trtexec",
            f"--onnx={onnx_path}",
            f"--saveEngine={engine_path}",
            "--explicitBatch",
            f"--{precision.lower()}",
            "--workspace=4096",
            "--builderOptimizationLevel=4",
            "--profilingVerbosity=detailed",
            # These two were plain strings before, so the literal text
            # "{input_size}" was handed to trtexec.
            f"--optShapes=input:4x{shape}",
            f"--maxShapes=input:8x{shape}",
            "--useCudaGraph",
            "--useSpinWait",
            "--noTF32",
        ]

    def build_engine(self, name, onnx_path, precision, input_size):
        """Build the engine registered under *name* unless it already exists.

        Args:
            name: Key into ``self.engines``.
            onnx_path: Path of the source ONNX model.
            precision: Precision name, e.g. ``"INT8"`` or ``"FP16"``.
            input_size: Square input resolution in pixels.

        Returns:
            True if the engine file already exists or ``trtexec`` exited with
            code 0, False otherwise.

        Raises:
            KeyError: If *name* is not a key of ``self.engines``.
        """
        engine_path = self.engines[name]

        if os.path.exists(engine_path):
            print(f"[✓] {name} engine already exists: {engine_path}")
            return True

        print(f"\n{'='*60}")
        print(f"Building {name} engine ({precision}, {input_size}p)")
        print(f"{'='*60}")

        cmd = self._trtexec_command(onnx_path, engine_path, precision, input_size)
        print(f"Command: {' '.join(cmd)}")

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=600,
            )
        except (OSError, ValueError, subprocess.SubprocessError) as e:
            # Missing executable, timeout, or undecodable output.
            print(f"[✗] Error building {name}: {e}")
            return False

        if result.returncode == 0:
            print(f"[✓] {name} engine built successfully!")
            return True

        print(f"[✗] {name} engine build failed!")
        print(f"Error: {result.stderr}")
        return False

    @staticmethod
    def _imgsz_for(name):
        """Return the validation image size for a configuration name.

        Names containing ``"480"`` use 480; everything else (including the
        ``FP32_PyTorch`` baseline) uses 640. Previously any name without
        ``"640"`` fell back to 480, so the baseline was evaluated at 480.
        """
        return 480 if "480" in name else 640

    @classmethod
    def _parse_val_output(cls, output):
        """Extract metrics from the combined validation output text.

        Args:
            output: Text captured from the validation process.

        Returns:
            A dict that may contain ``mAP50_95``, ``mAP50`` and
            ``inference_ms`` (all floats). Keys whose line is absent from
            *output* are simply omitted.
        """
        parsed = {}
        for line in output.splitlines():
            ap_row = cls._AP_ROW.search(line)
            if ap_row:
                iou, score = ap_row.group(1), float(ap_row.group(2))
                if iou == "0.50:0.95":
                    parsed["mAP50_95"] = score
                elif iou == "0.50":
                    parsed["mAP50"] = score
                continue

            if "preprocess" in line and "inference" in line:
                speed = cls._INFERENCE_MS.search(line)
                if speed:
                    parsed["inference_ms"] = float(speed.group(1))
        return parsed

    def validate_engine(self, name, engine_path):
        """Run ``yolo val`` on a model/engine and collect its metrics.

        Args:
            name: Configuration name; also selects the image size
                (see ``_imgsz_for``).
            engine_path: Path of the model or engine passed as ``model=``.

        Returns:
            A metrics dict with ``name``, ``engine``, ``timestamp`` and any of
            ``mAP50_95`` / ``mAP50`` / ``inference_ms`` that could be parsed.
            If the process could not be run, ``{"name": ..., "error": ...}``.
        """
        print(f"\nValidating {name}...")

        cmd = [
            "yolo", "val",
            f"model={engine_path}",
            "data=coco.yaml",
            f"imgsz={self._imgsz_for(name)}",
            "rect=False",
            "batch=1",
        ]

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=3600,
            )
        except (OSError, ValueError, subprocess.SubprocessError) as e:
            print(f"[✗] Error validating {name}: {e}")
            return {"name": name, "error": str(e)}

        if result.returncode != 0:
            # Still parse whatever was printed; missing metrics show as N/A.
            print(f"[!] {name} validation exited with code {result.returncode}")

        metrics = {
            "name": name,
            "engine": engine_path,
            "timestamp": datetime.now().isoformat(),
        }
        metrics.update(self._parse_val_output(result.stdout + result.stderr))

        print(f"[✓] {name} validation complete")
        print(f"    mAP50-95: {metrics.get('mAP50_95', 'N/A')}")
        print(f"    mAP50: {metrics.get('mAP50', 'N/A')}")
        print(f"    Inference: {metrics.get('inference_ms', 'N/A')}ms")

        return metrics

    def run_benchmark(self):
        """Run the full benchmark: baseline, build + validate, save, report."""
        print("="*60)
        print("YOLO11n TensorRT Engine 性能对比分析")
        print("="*60)
        print("\n配置:")
        print("  - INT8 640p: 8位整数量化, 640x640输入")
        print("  - FP16 640p: 半精度浮点, 640x640输入")
        print("  - INT8 480p: 8位整数量化, 480x480输入")
        print("  - FP16 480p: 半精度浮点, 480x480输入")
        print("  - Batch: 1-8, Opt: 4, 优化级别: 4")
        print()

        # Step 1: FP32 baseline from the PyTorch checkpoint.
        print("\n" + "="*60)
        print("Step 1: 获取FP32基线 (PyTorch)")
        print("="*60)

        fp32_metrics = self.validate_engine("FP32_PyTorch", "yolo11n.pt")
        self.benchmark_data['FP32'] = fp32_metrics

        # Step 2: build and validate each engine configuration.
        configs = [
            ("INT8_640", "yolo11n.onnx", "INT8", 640),
            ("FP16_640", "yolo11n.onnx", "FP16", 640),
            ("INT8_480", "yolo11n.onnx", "INT8", 480),
            ("FP16_480", "yolo11n.onnx", "FP16", 480),
        ]

        for name, onnx_path, precision, size in configs:
            self.build_engine(name, onnx_path, precision, size)
            # Validate only when an engine file is actually present.
            if os.path.exists(self.engines[name]):
                metrics = self.validate_engine(name, self.engines[name])
                self.benchmark_data[name] = metrics

        self.save_results()
        self.generate_report()

    def save_results(self):
        """Write ``self.benchmark_data`` to ``benchmark_raw.json``."""
        results_file = self.results_dir / "benchmark_raw.json"
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(self.benchmark_data, f, indent=2, ensure_ascii=False)
        print(f"\n结果已保存到: {results_file}")

    @staticmethod
    def _fps(inf_ms):
        """Convert a per-image latency in ms to FPS, or ``'N/A'``.

        Non-numeric, zero and negative latencies yield ``'N/A'`` instead of
        raising (a latency of 0 previously caused ZeroDivisionError).
        """
        if isinstance(inf_ms, (int, float)) and inf_ms > 0:
            return round(1000 / inf_ms, 1)
        return 'N/A'

    def generate_report(self):
        """Write ``benchmark_report.txt`` and echo it to stdout.

        Entries of ``self.benchmark_data`` that carry an ``error`` key are
        left out of every section.
        """
        report = []
        report.append("="*70)
        report.append("YOLO11n TensorRT Engine 性能对比分析报告")
        report.append("="*70)
        report.append(f"生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report.append("")

        # Section 1: metrics table.
        report.append("-"*70)
        report.append("一、性能指标对比表")
        report.append("-"*70)
        report.append(f"{'配置':<15} {'mAP50-95':<12} {'mAP50':<12} {'推理速度':<12} {'FPS':<10}")
        report.append("-"*70)

        fp32_map = self.benchmark_data.get('FP32', {}).get('mAP50_95', 0)

        for name, data in self.benchmark_data.items():
            if 'error' in data:
                continue
            map50_95 = data.get('mAP50_95', 'N/A')
            map50 = data.get('mAP50', 'N/A')
            inf_ms = data.get('inference_ms', 'N/A')
            fps = self._fps(inf_ms)
            # Only append the unit to real numbers so a missing value reads
            # "N/A" rather than "N/Ams"; pad to the header's column width.
            speed = f"{inf_ms}ms" if isinstance(inf_ms, (int, float)) else 'N/A'
            report.append(f"{name:<15} {map50_95:<12} {map50:<12} {speed:<12} {fps}")

        # Section 2: accuracy drop relative to the FP32 baseline.
        report.append("")
        report.append("-"*70)
        report.append("二、精度掉点分析 (相对于FP32)")
        report.append("-"*70)

        for name, data in self.benchmark_data.items():
            if name == 'FP32' or 'error' in data:
                continue
            map50_95 = data.get('mAP50_95', 0)
            if fp32_map > 0 and map50_95 > 0:
                drop = (fp32_map - map50_95) / fp32_map * 100
                report.append(f"{name:<15}: mAP50-95 掉点 {drop:.2f}%")

        # Section 3: speed ranking, fastest first.
        report.append("")
        report.append("-"*70)
        report.append("三、速度对比")
        report.append("-"*70)

        speeds = [
            (name, data['inference_ms'])
            for name, data in self.benchmark_data.items()
            if 'error' not in data
            and isinstance(data.get('inference_ms'), (int, float))
        ]

        if speeds:
            speeds.sort(key=lambda item: item[1])
            fastest = speeds[0]
            report.append(f"最快配置: {fastest[0]} ({fastest[1]}ms)")
            report.append("")
            report.append("速度排名:")
            for i, (name, ms) in enumerate(speeds, 1):
                report.append(f"  {i}. {name}: {ms}ms")

        # Section 4: static conclusions (not derived from the measured data).
        report.append("")
        report.append("="*70)
        report.append("四、结论与建议")
        report.append("="*70)
        report.append("")
        report.append("1. 如果追求最高精度: 选择 INT8 640p 或 FP16 640p")
        report.append("2. 如果追求最快速度: 选择 FP16 480p")
        report.append("3. 如果平衡精度和速度: 选择 FP16 640p")
        report.append("4. INT8量化会有约10%的mAP掉点,但推理速度显著提升")
        report.append("")

        report_text = '\n'.join(report)
        report_file = self.results_dir / "benchmark_report.txt"
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write(report_text)

        print(f"\n报告已保存到: {report_file}")
        print("\n" + report_text)

    def generate_charts(self):
        """Render a 2x2 bar-chart comparison to ``benchmark_charts.png``.

        Requires matplotlib; prints a hint and returns if it is missing.
        Returns without drawing when no configuration has usable data.
        """
        try:
            import matplotlib
            # Select the non-interactive backend before pyplot is imported.
            matplotlib.use('Agg')
            import matplotlib.pyplot as plt
        except ImportError:
            print("\n[提示] 需要安装matplotlib才能生成图表: pip install matplotlib")
            return

        configs = []
        map50_95_values = []
        map50_values = []
        inference_values = []

        for name, data in self.benchmark_data.items():
            if 'error' in data:
                continue
            configs.append(name)
            map50_95_values.append(data.get('mAP50_95', 0))
            map50_values.append(data.get('mAP50', 0))
            inf_ms = data.get('inference_ms', 0)
            usable = isinstance(inf_ms, (int, float)) and inf_ms > 0
            inference_values.append(inf_ms if usable else 0)

        if not configs:
            # max() on the empty value lists below would raise ValueError.
            print("\n[!] No benchmark data available; skipping charts")
            return

        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
        fig.suptitle('YOLO11n TensorRT Engine 性能对比', fontsize=14, fontweight='bold')

        colors = ['#2ecc71', '#3498db', '#e74c3c', '#9b59b6']
        bar_colors = colors[:len(configs)]

        # Panels 1 and 2: accuracy metrics share the same layout.
        accuracy_panels = (
            (axes[0, 0], map50_95_values, 'mAP50-95 对比', 'mAP50-95'),
            (axes[0, 1], map50_values, 'mAP50 对比', 'mAP50'),
        )
        for ax, values, title, ylabel in accuracy_panels:
            bars = ax.bar(configs, values, color=bar_colors)
            ax.set_title(title, fontsize=12)
            ax.set_ylabel(ylabel)
            # Fall back to 1.0 when every value is 0 to avoid a flat axis.
            ax.set_ylim(0, (max(values) * 1.2) or 1.0)
            for bar, val in zip(bars, values):
                ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                        f'{val:.3f}', ha='center', va='bottom', fontsize=9)

        # Panel 3: inference latency.
        ax3 = axes[1, 0]
        bars3 = ax3.bar(configs, inference_values, color=bar_colors)
        ax3.set_title('推理速度对比 (ms)', fontsize=12)
        ax3.set_ylabel('推理时间 (ms)')
        for bar, val in zip(bars3, inference_values):
            ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                     f'{val:.1f}ms', ha='center', va='bottom', fontsize=9)

        # Panel 4: throughput derived from latency.
        ax4 = axes[1, 1]
        fps_values = [1000/v if v > 0 else 0 for v in inference_values]
        bars4 = ax4.bar(configs, fps_values, color=bar_colors)
        ax4.set_title('FPS 对比', fontsize=12)
        ax4.set_ylabel('FPS')
        for bar, val in zip(bars4, fps_values):
            ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
                     f'{val:.1f}', ha='center', va='bottom', fontsize=9)

        plt.tight_layout()

        chart_file = self.results_dir / "benchmark_charts.png"
        plt.savefig(chart_file, dpi=150, bbox_inches='tight')
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
        print(f"\n图表已保存到: {chart_file}")
|
||
|
||
def main():
    """Run the complete benchmark, then render the comparison charts."""
    runner = EngineBenchmark()
    runner.run_benchmark()
    runner.generate_charts()


# Execute the benchmark only when this file is run as a script.
if __name__ == "__main__":
    main()
|