Files
Test_AI/parse_results_v2.py

298 lines
11 KiB
Python
Raw Permalink Normal View History

"""
直接从ultralytics验证结果生成对比报告
"""
import json
import re
from pathlib import Path
from datetime import datetime
import numpy as np
import shutil
class DirectResultsParser:
    """Build a quantization comparison report from saved validation logs.

    Each configuration's validation output is read from a plain-text file
    (see ``RESULT_FILES``); the files are presumably captured console output
    of Ultralytics validation runs. Metrics are pulled out with regular
    expressions, printed as a multi-section report, and the same report text
    plus the raw metrics (JSON) are saved under ``self.results_dir``.
    """

    # Captured validation output per configuration (paths relative to the CWD).
    RESULT_FILES = {
        "FP32_PyTorch": "fp32_results.txt",
        "INT8_640p": "int8_640_results.txt",
        "FP16_640p": "fp16_640_results.txt",
        "FP16_480p": "fp16_480_results.txt",
    }
    # Configuration used as the reference for the mAP drop analysis.
    BASELINE = "FP32_PyTorch"
    # Per-class rows to extract, in the order they are stored in the metrics
    # dict (this order is preserved in the saved JSON).
    CLASS_ROWS = ("person", "car", "bicycle", "motorcycle", "bus", "truck")
    # Classes averaged (unweighted) into the "all_vehicles" entry.
    VEHICLE_KEYS = ("bicycle", "car", "motorcycle", "bus", "truck")

    # Remainder of a metrics table row after its label:
    #   <images> <instances> <P> <R> <AP50> <AP50-95>
    _ROW_SUFFIX = r"\b\s+\d+\s+\d+\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)"
    # "<x>ms preprocess ... <y>ms inference"; DOTALL lets the gap span lines.
    _SPEED_RE = re.compile(r"(\d+\.\d+)ms preprocess.*?(\d+\.\d+)ms inference", re.DOTALL)
    _RULE = "-" * 70
    _DOUBLE_RULE = "=" * 70
    _TITLE = "YOLO11n Person & Vehicle Detection Performance Report"

    def __init__(self):
        # Output directory for final_report.txt and all_results.json.
        self.results_dir = Path("vehicle_person_benchmark")
        self.results_dir.mkdir(exist_ok=True)
        # Validation run directories per configuration.
        # NOTE(review): not read by any method of this class; kept for reference.
        self.val_dirs = {
            "FP32_PyTorch": "runs/detect/val25",  # FP32
            "INT8_640p": "runs/detect/val12",     # INT8 640p
            "FP16_640p": "runs/detect/val13",     # FP16 640p
            "FP16_480p": "runs/detect/val14",     # FP16 480p
        }
        # Parsed metrics per configuration name, filled by generate_report().
        self.results = {}

    @classmethod
    def _parse_row(cls, label, content):
        """Return ``{'P', 'R', 'ap50', 'ap50_95'}`` for the first row named *label*, or None.

        Raises ValueError if a captured number is not a valid float.
        """
        match = re.search(r"\b" + re.escape(label) + cls._ROW_SUFFIX, content)
        if match is None:
            return None
        p, r, ap50, ap50_95 = (float(g) for g in match.groups())
        return {'P': p, 'R': r, 'ap50': ap50, 'ap50_95': ap50_95}

    @classmethod
    def _vehicle_average(cls, metrics):
        """Unweighted mean of P/R/AP over the vehicle classes present in *metrics*.

        Returns all zeros when no vehicle class row was found.
        """
        present = [metrics[k] for k in cls.VEHICLE_KEYS if metrics.get(k)]
        if not present:
            return {'ap50_95': 0, 'ap50': 0, 'P': 0, 'R': 0}
        return {
            key: float(np.mean([row[key] for row in present]))
            for key in ('ap50_95', 'ap50', 'P', 'R')
        }

    def extract_from_txt(self, txt_file):
        """Extract speed, overall and per-class metrics from one result file.

        Args:
            txt_file: path (str or Path) of a text file with validation output.

        Returns:
            A dict containing, when found in the text:
            ``inference_ms`` / ``preprocess_ms``; the overall ("all" row)
            ``P``, ``R``, ``ap50``, ``ap50_95`` plus ``mAP50`` / ``mAP50_95``
            aliases; one ``{'P', 'R', 'ap50', 'ap50_95'}`` dict per class in
            ``CLASS_ROWS``; and ``all_vehicles`` (see ``_vehicle_average``).
            Returns ``{}`` (after printing the error) if the file cannot be
            read or a captured number cannot be parsed.
        """
        try:
            with open(txt_file, 'r', encoding='utf-8', errors='replace') as f:
                content = f.read()
            metrics = {}
            speed = self._SPEED_RE.search(content)
            if speed:
                metrics['inference_ms'] = float(speed.group(2))
                metrics['preprocess_ms'] = float(speed.group(1))
            overall = self._parse_row('all', content)
            if overall:
                metrics.update(overall)
                # Aliases read by the report sections.
                metrics['mAP50'] = overall['ap50']
                metrics['mAP50_95'] = overall['ap50_95']
            for label in self.CLASS_ROWS:
                row = self._parse_row(label, content)
                if row:
                    metrics[label] = row
        except (OSError, ValueError) as e:
            print(f"Error parsing {txt_file}: {e}")
            return {}
        metrics['all_vehicles'] = self._vehicle_average(metrics)
        return metrics

    def _header_lines(self, generated):
        """Report title block; *generated* is the timestamp string to show."""
        return [self._DOUBLE_RULE, self._TITLE, self._DOUBLE_RULE, f"Generated: {generated}", ""]

    def _class_table(self, title, key):
        """P/R/AP table for the per-config metrics sub-dict stored under *key*."""
        lines = [
            self._RULE, title, self._RULE,
            f"{'Config':<15} {'P':<10} {'R':<10} {'AP50':<12} {'AP50-95':<12}",
            self._RULE,
        ]
        for name, data in self.results.items():
            m = data.get(key, {})
            lines.append(
                f"{name:<15} {m.get('P', 0):<10.3f} {m.get('R', 0):<10.3f} "
                f"{m.get('ap50', 0):<12.4f} {m.get('ap50_95', 0):<12.4f}"
            )
        lines.append("")
        return lines

    def _section_lines(self):
        """Report sections 1-6 built from ``self.results``, one string per line."""
        rule = self._RULE
        # 1. Overall performance
        lines = [
            "", rule, "1. Overall Performance Comparison", rule,
            f"{'Config':<15} {'mAP50-95':<12} {'mAP50':<12} {'Inference':<12} {'FPS':<10}",
            rule,
        ]
        for name, data in self.results.items():
            inf_ms = data.get('inference_ms', 0)
            fps = round(1000 / inf_ms, 1) if inf_ms > 0 else 0
            lines.append(
                f"{name:<15} {data.get('mAP50_95', 0):<12.4f} {data.get('mAP50', 0):<12.4f} "
                f"{inf_ms:<12.1f} {fps:<10.1f}"
            )
        lines.append("")
        # 2./3. Person and combined-vehicle tables share one layout.
        lines += self._class_table("2. Person Detection Performance", 'person')
        lines += self._class_table("3. Vehicles Detection Performance (combined)", 'all_vehicles')
        # 4. Speed ranking (fastest first); configs without a speed are skipped.
        lines += [rule, "4. Inference Speed Comparison", rule]
        speeds = sorted(
            ((name, data.get('inference_ms', 0)) for name, data in self.results.items()
             if data.get('inference_ms', 0) > 0),
            key=lambda item: item[1],
        )
        for rank, (name, ms) in enumerate(speeds, 1):
            lines.append(f" {rank}. {name}: {ms:.2f}ms ({1000 / ms:.1f} FPS)")
        lines.append("")
        # 5. Relative mAP50-95 drop against the FP32 baseline.
        lines += [rule, "5. mAP Drop Analysis (vs FP32)", rule]
        base_map = self.results.get(self.BASELINE, {}).get('mAP50_95', 0)
        if base_map > 0:
            for name, data in self.results.items():
                cur_map = data.get('mAP50_95', 0)
                if name == self.BASELINE or cur_map == 0:
                    continue
                drop = (base_map - cur_map) / base_map * 100
                lines.append(f" {name}: mAP50-95 drop {drop:.2f}%")
        else:
            lines.append(f" {self.BASELINE} baseline unavailable; drop analysis skipped")
        lines.append("")
        # 6. Conclusions.
        # NOTE(review): these recommendations are static text, not derived
        # from the parsed metrics above — verify against the tables.
        lines += [
            self._DOUBLE_RULE, "6. Conclusions & Recommendations", self._DOUBLE_RULE, "",
            "1. Best Accuracy: Choose FP16_640p or INT8_640p",
            "2. Fastest Speed: Choose FP16_480p",
            "3. Best Balance: FP16_640p (accuracy vs speed)",
            "4. INT8 quantization shows acceptable drop for person/vehicle detection",
            "",
        ]
        return lines

    def generate_report(self):
        """Parse all result files, print the full report and save it.

        Side effects: prints to stdout, stores parsed metrics in
        ``self.results``, and writes ``final_report.txt`` (identical to the
        printed report) and ``all_results.json`` into ``self.results_dir``.
        """
        # One timestamp so the console and the saved file agree.
        generated = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        header = self._header_lines(generated)
        print("\n".join(header))

        for name, filename in self.RESULT_FILES.items():
            if Path(filename).exists():
                print(f"Parsing {name}...")
                self.results[name] = self.extract_from_txt(filename)
            else:
                print(f"Warning: {filename} not found!")

        body = self._section_lines()
        print("\n".join(body))

        report_file = self.results_dir / "final_report.txt"
        report_file.write_text("\n".join(header + body), encoding='utf-8')
        with open(self.results_dir / "all_results.json", 'w', encoding='utf-8') as f:
            json.dump(self.results, f, indent=2, ensure_ascii=False)

        print("=" * 70)
        print(f"Report saved to: {report_file}")
        print("=" * 70)

    def _get_report_text(self, generated=None):
        """Return the full report (all sections) as a single string.

        Args:
            generated: timestamp string for the header; defaults to now.
        """
        if generated is None:
            generated = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        return "\n".join(self._header_lines(generated) + self._section_lines())
if __name__ == "__main__":
parser = DirectResultsParser()
parser.generate_report()