#!/usr/bin/env python3
"""
测试 PyTorch 在大批次(16, 32)下的性能
补充完整的 PyTorch 基准数据
"""

import os
import time
import json
import numpy as np
import torch
from datetime import datetime
from ultralytics import YOLO
def test_pytorch_batch_performance(model_path, batch_sizes, test_duration=20):
    """Benchmark PyTorch YOLO inference throughput at several batch sizes.

    For each batch size: warm the model up, then run batched inference on
    synthetic images for roughly ``test_duration`` seconds, sampling FPS
    about once per second and recording per-batch latency.

    Args:
        model_path: Path to a ``.pt`` model loadable by ``ultralytics.YOLO``.
        batch_sizes: Iterable of batch sizes to benchmark.
        test_duration: Seconds of timed inference per batch size.

    Returns:
        dict mapping batch size -> result dict. On success the dict holds
        ``avg_fps``, ``avg_latency_ms``, ``total_frames``, ``test_duration``
        and ``success=True``; on failure ``avg_fps``/``avg_latency_ms`` are 0,
        ``success`` is False and ``error`` carries the exception message.
    """
    print("🚀 开始测试 PyTorch 批次性能")
    print("=" * 60)

    # Load the PyTorch model once; it is reused for every batch size.
    print(f"📦 加载 PyTorch 模型: {model_path}")
    model = YOLO(model_path)
    print("✅ 模型加载成功")

    def _random_batch(size):
        # Synthetic 640x640 RGB frames; pixel content is irrelevant for a
        # throughput benchmark, only the tensor shapes matter.
        return [np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)
                for _ in range(size)]

    results = {}

    for batch_size in batch_sizes:
        print(f"\n🔄 测试批次大小: {batch_size} (测试时长: {test_duration}秒)")

        try:
            # Warm-up: lets CUDA kernels/caches settle so the timed loop
            # measures steady-state performance.
            print("🔥 预热中...")
            for _ in range(5):
                model(_random_batch(batch_size), verbose=False)

            fps_list = []        # one FPS sample per ~1 s window
            latency_list = []    # per-batch latency in ms
            batch_count = 0

            start_time = time.time()
            last_fps_time = start_time
            fps_batch_count = 0  # batches completed in the current FPS window

            while time.time() - start_time < test_duration:
                test_images = _random_batch(batch_size)

                infer_start = time.time()
                model(test_images, verbose=False)
                infer_end = time.time()

                latency_ms = (infer_end - infer_start) * 1000
                latency_list.append(latency_ms)

                batch_count += 1
                fps_batch_count += 1

                # Emit an FPS sample roughly once per second.
                current_time = time.time()
                if current_time - last_fps_time >= 1.0:
                    fps = (fps_batch_count * batch_size) / (current_time - last_fps_time)
                    fps_list.append(fps)
                    fps_batch_count = 0
                    last_fps_time = current_time

                    elapsed = current_time - start_time
                    print(f"  进度: {elapsed:.1f}s/{test_duration}s, "
                          f"当前FPS: {fps:.1f}, 延迟: {latency_ms:.1f}ms")

            total_time = time.time() - start_time
            total_frames = batch_count * batch_size

            # Cast to plain float: np.mean returns np.float64, which the json
            # module cannot serialize (this previously broke the JSON dump in
            # main). Also guard latency_list against being empty (a run
            # shorter than one inference), mirroring the fps_list guard.
            avg_fps = float(np.mean(fps_list)) if fps_list else 0.0
            avg_latency_ms = float(np.mean(latency_list)) if latency_list else 0.0

            results[batch_size] = {
                'avg_fps': avg_fps,
                'avg_latency_ms': avg_latency_ms,
                'total_frames': total_frames,
                'test_duration': total_time,
                'success': True
            }

            print(f"✅ 批次 {batch_size} 测试完成:")
            print(f"  平均FPS: {avg_fps:.1f}")
            print(f"  平均延迟: {avg_latency_ms:.1f}ms")

        except Exception as e:
            # Record the failure but keep benchmarking the remaining sizes
            # (e.g. a larger batch may simply run out of GPU memory).
            print(f"❌ 批次 {batch_size} 测试失败: {e}")
            results[batch_size] = {
                'avg_fps': 0,
                'avg_latency_ms': 0,
                'success': False,
                'error': str(e)
            }

    return results
|
||
|
||
def main():
    """Run the PyTorch large-batch benchmark and persist/print a report.

    Verifies the model file exists and CUDA is available, benchmarks batch
    sizes 16 and 32 via ``test_pytorch_batch_performance``, writes the
    results to a timestamped JSON file under ``pytorch_results/`` and prints
    a human-readable summary.
    """
    print("PyTorch 大批次性能测试")
    print("=" * 60)

    # NOTE(review): machine-specific hard-coded path — consider making this
    # a CLI argument.
    model_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"

    # Fail fast if prerequisites are missing.
    if not os.path.exists(model_path):
        print(f"❌ PyTorch 模型不存在: {model_path}")
        return

    if not torch.cuda.is_available():
        print("❌ CUDA 不可用")
        return

    print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}")
    print(f"✅ PyTorch 模型: {model_path}")

    # Only the large batch sizes (16, 32) are benchmarked here.
    batch_sizes = [16, 32]
    test_duration = 20  # seconds of timed inference per batch size

    print(f"\n📊 测试配置:")
    print(f"  批次大小: {batch_sizes}")
    print(f"  每批次测试时长: {test_duration}秒")

    try:
        pytorch_results = test_pytorch_batch_performance(model_path, batch_sizes, test_duration)

        output_dir = "pytorch_results"
        os.makedirs(output_dir, exist_ok=True)

        # Capture one timestamp so the filename and the JSON payload agree
        # (previously datetime.now() was called twice and could straddle a
        # second boundary).
        now = datetime.now()
        results_data = {
            'framework': 'PyTorch',
            'model': model_path,
            'batch_sizes': batch_sizes,
            'results': pytorch_results,
            'timestamp': now.isoformat()
        }

        json_file = os.path.join(output_dir, f"pytorch_batch_results_{now.strftime('%Y%m%d_%H%M%S')}.json")
        with open(json_file, 'w', encoding='utf-8') as f:
            # default=float: the results may contain np.float64 values
            # (np.mean output), which json cannot serialize natively.
            json.dump(results_data, f, indent=2, ensure_ascii=False, default=float)

        print(f"\n✅ 测试数据已保存: {json_file}")

        # Human-readable summary per batch size.
        print("\n" + "=" * 60)
        print("PyTorch 大批次性能测试结果")
        print("=" * 60)

        for bs in batch_sizes:
            result = pytorch_results[bs]
            if result['success']:
                print(f"\n批次大小: {bs}")
                print(f"  平均 FPS: {result['avg_fps']:.1f}")
                print(f"  平均延迟: {result['avg_latency_ms']:.1f}ms")
                print(f"  总帧数: {result['total_frames']}")
            else:
                print(f"\n批次大小: {bs}")
                print(f"  状态: 测试失败 - {result.get('error', '未知错误')}")

        print(f"\n🎉 测试完成!")
        print(f"📁 结果已保存到: {output_dir}/")

        print("\n" + "=" * 60)
        print("📌 下一步操作:")
        print("  1. 使用这些数据更新 batch_comparison_test.py 中的 PYTORCH_DATA")
        print("  2. 运行完整的 PyTorch vs TensorRT 对比测试")
        print("=" * 60)

    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
|
||
|
||
# Script entry point: run the benchmark only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|