Files
Test_AI/test_pytorch_large_batch.py
2026-01-20 11:14:10 +08:00

193 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Benchmark PyTorch (Ultralytics YOLO) inference performance at large batch sizes (16 and 32).

Fills in the missing PyTorch baseline numbers; the printed next steps suggest
using the saved results to update PYTORCH_DATA in batch_comparison_test.py.
"""
import os
import time
import json
import numpy as np
import torch
from datetime import datetime
from ultralytics import YOLO
def _make_random_batch(batch_size, image_size=640, seed=0):
    """Return ``batch_size`` random uint8 images, each shaped (image_size, image_size, 3).

    ``Generator.integers`` excludes its upper bound, so 256 is passed to cover the
    full 0-255 pixel range. A fixed seed makes runs reproducible.
    """
    rng = np.random.default_rng(seed)
    return [rng.integers(0, 256, size=(image_size, image_size, 3), dtype=np.uint8)
            for _ in range(batch_size)]


def _benchmark_batch_size(model, batch_size, test_duration, warmup_iters, image_size):
    """Warm up ``model`` and time it on one batch size; return that size's metrics dict."""
    use_cuda = torch.cuda.is_available()
    # Build the input once. Regenerating random images on every iteration (as the
    # loop used to) counted data-generation time in the measured FPS.
    test_images = _make_random_batch(batch_size, image_size)

    # Warm-up
    print("🔥 预热中...")
    for _ in range(warmup_iters):
        model(test_images, verbose=False)
    if use_cuda:
        torch.cuda.synchronize()

    # Timed run
    fps_list = []
    latency_list = []
    batch_count = 0
    fps_batch_count = 0
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    last_fps_time = start_time
    while time.perf_counter() - start_time < test_duration:
        infer_start = time.perf_counter()
        model(test_images, verbose=False)
        if use_cuda:
            # Wait for any queued GPU work so the latency covers the whole inference.
            torch.cuda.synchronize()
        current_time = time.perf_counter()
        latency_ms = (current_time - infer_start) * 1000
        latency_list.append(latency_ms)
        batch_count += 1
        fps_batch_count += 1

        # Record throughput once per window of at least one second.
        window = current_time - last_fps_time
        if window >= 1.0:
            fps = fps_batch_count * batch_size / window
            fps_list.append(fps)
            fps_batch_count = 0
            last_fps_time = current_time
            elapsed = current_time - start_time
            print(f" 进度: {elapsed:.1f}s/{test_duration}s, "
                  f"当前FPS: {fps:.1f}, 延迟: {latency_ms:.1f}ms")

    total_time = time.perf_counter() - start_time
    total_frames = batch_count * batch_size
    overall_fps = total_frames / total_time if total_time > 0 else 0.0
    # Mean of the per-second windows; when the run was too short to complete a
    # window, fall back to overall throughput instead of reporting 0.
    avg_fps = float(np.mean(fps_list)) if fps_list else overall_fps
    # np.mean([]) is NaN (which json.dump would emit as invalid JSON), so guard it.
    avg_latency_ms = float(np.mean(latency_list)) if latency_list else 0.0
    return {
        'avg_fps': avg_fps,
        'avg_latency_ms': avg_latency_ms,
        'total_frames': total_frames,
        'test_duration': total_time,
        'success': True
    }


def test_pytorch_batch_performance(model_path, batch_sizes, test_duration=20,
                                   model=None, warmup_iters=5, image_size=640):
    """Benchmark YOLO inference throughput and latency for each batch size.

    Args:
        model_path: Path to the YOLO weights; loaded with ``YOLO(model_path)``
            unless ``model`` is given.
        batch_sizes: Iterable of batch sizes to test, in order.
        test_duration: Seconds of timed inference per batch size.
        model: Optional already-loaded model (any callable accepting
            ``(list_of_images, verbose=False)``); lets callers reuse one model.
        warmup_iters: Untimed warm-up calls per batch size.
        image_size: Height and width of the random square test images.

    Returns:
        dict mapping each batch size to a result dict. On success it holds
        ``avg_fps``, ``avg_latency_ms`` (mean per-batch latency), ``total_frames``,
        ``test_duration`` (measured seconds) and ``success=True``. On failure it
        holds zeroed metrics, ``success=False`` and ``error`` (the exception text).
        A failure in one batch size does not stop the remaining ones.
    """
    print("🚀 开始测试 PyTorch 批次性能")
    print("=" * 60)
    if model is None:
        # Load the PyTorch model
        print(f"📦 加载 PyTorch 模型: {model_path}")
        model = YOLO(model_path)
        print("✅ 模型加载成功")

    results = {}
    for batch_size in batch_sizes:
        print(f"\n🔄 测试批次大小: {batch_size} (测试时长: {test_duration}秒)")
        try:
            result = _benchmark_batch_size(model, batch_size, test_duration,
                                           warmup_iters, image_size)
        except Exception as e:  # per-batch boundary: record the failure, keep going
            print(f"❌ 批次 {batch_size} 测试失败: {e}")
            if torch.cuda.is_available():
                try:
                    # Release cached blocks left by a failed run (e.g. OOM) so the
                    # next batch size starts from a clean allocator state.
                    torch.cuda.empty_cache()
                except RuntimeError:
                    pass  # best effort: the CUDA context itself may be broken
            results[batch_size] = {
                'avg_fps': 0,
                'avg_latency_ms': 0,
                'total_frames': 0,
                'test_duration': 0.0,
                'success': False,
                'error': str(e)
            }
            continue
        results[batch_size] = result
        print(f"✅ 批次 {batch_size} 测试完成:")
        print(f" 平均FPS: {result['avg_fps']:.1f}")
        print(f" 平均延迟: {result['avg_latency_ms']:.1f}ms")
    return results


# This file's name matches pytest's ``test_*.py`` pattern; without this flag pytest
# would collect the benchmark as a test and error on the missing ``model_path`` fixture.
test_pytorch_batch_performance.__test__ = False
def main(model_path="C:/Users/16337/PycharmProjects/Security/yolo11n.pt",
         batch_sizes=(16, 32), test_duration=20, output_dir="pytorch_results"):
    """Run the large-batch PyTorch benchmark, save results as JSON and print a summary.

    Args:
        model_path: Path to the YOLO ``.pt`` weights. The default is the original
            author's local path; pass your own when running elsewhere.
        batch_sizes: Batch sizes to benchmark (default 16 and 32).
        test_duration: Seconds of timed inference per batch size.
        output_dir: Directory that receives the timestamped JSON result file.

    Returns early (after printing a message) if the model file does not exist or
    CUDA is unavailable. ``KeyboardInterrupt`` and other exceptions raised while
    benchmarking are reported on stdout instead of being propagated.
    """
    print("PyTorch 大批次性能测试")
    print("=" * 60)

    # Check the model file (isfile: a directory path is not a usable model)
    if not os.path.isfile(model_path):
        print(f"❌ PyTorch 模型不存在: {model_path}")
        return

    # Check CUDA
    if not torch.cuda.is_available():
        print("❌ CUDA 不可用")
        return
    print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}")
    print(f"✅ PyTorch 模型: {model_path}")

    # Materialize as a list so it prints and serializes like a plain list.
    batch_sizes = list(batch_sizes)
    print("\n📊 测试配置:")
    print(f" 批次大小: {batch_sizes}")
    print(f" 每批次测试时长: {test_duration}秒")

    try:
        pytorch_results = test_pytorch_batch_performance(model_path, batch_sizes, test_duration)

        # Save the results as JSON
        os.makedirs(output_dir, exist_ok=True)
        # One timestamp so the JSON field and the filename always agree.
        now = datetime.now()
        results_data = {
            'framework': 'PyTorch',
            'model': model_path,
            'batch_sizes': batch_sizes,
            'results': pytorch_results,
            'timestamp': now.isoformat()
        }
        json_file = os.path.join(
            output_dir, f"pytorch_batch_results_{now.strftime('%Y%m%d_%H%M%S')}.json")
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2, ensure_ascii=False)
        print(f"\n✅ 测试数据已保存: {json_file}")

        # Short text report
        print("\n" + "=" * 60)
        print("PyTorch 大批次性能测试结果")
        print("=" * 60)
        for bs in batch_sizes:
            result = pytorch_results[bs]
            print(f"\n批次大小: {bs}")
            if result['success']:
                print(f" 平均 FPS: {result['avg_fps']:.1f}")
                print(f" 平均延迟: {result['avg_latency_ms']:.1f}ms")
                print(f" 总帧数: {result['total_frames']}")
            else:
                print(f" 状态: 测试失败 - {result.get('error', '未知错误')}")

        print("\n🎉 测试完成!")
        print(f"📁 结果已保存到: {output_dir}/")

        # Suggested next steps
        print("\n" + "=" * 60)
        print("📌 下一步操作:")
        print(" 1. 使用这些数据更新 batch_comparison_test.py 中的 PYTORCH_DATA")
        print(" 2. 运行完整的 PyTorch vs TensorRT 对比测试")
        print("=" * 60)
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
    except Exception as e:  # top-level boundary: report with traceback, don't crash
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()