#!/usr/bin/env python3
"""Benchmark PyTorch (ultralytics YOLO) inference at large batch sizes (16, 32).

Collects FPS / latency data to complete the PyTorch baseline used by the
PyTorch-vs-TensorRT comparison (see batch_comparison_test.py).
"""
import os
import time
import json
import numpy as np
import torch
from datetime import datetime
from ultralytics import YOLO


def test_pytorch_batch_performance(model_path, batch_sizes, test_duration=20):
    """Measure PyTorch inference throughput/latency for each batch size.

    Args:
        model_path: Path to a YOLO ``.pt`` checkpoint loadable by ultralytics.
        batch_sizes: Iterable of batch sizes to benchmark.
        test_duration: Seconds of timed inference per batch size.

    Returns:
        dict mapping batch size -> result dict. On success the dict holds
        ``avg_fps``, ``avg_latency_ms``, ``total_frames``, ``test_duration``
        and ``success=True``; on failure it holds zeroed metrics,
        ``success=False`` and the ``error`` string.
    """
    print("🚀 开始测试 PyTorch 批次性能")
    print("=" * 60)

    # Load the PyTorch model once and reuse it for every batch size.
    print(f"📦 加载 PyTorch 模型: {model_path}")
    model = YOLO(model_path)
    print("✅ 模型加载成功")

    results = {}

    for batch_size in batch_sizes:
        print(f"\n🔄 测试批次大小: {batch_size} (测试时长: {test_duration}秒)")

        try:
            # Warm-up passes so CUDA kernels / allocator are primed before timing.
            print("🔥 预热中...")
            for _ in range(5):
                # NOTE: randint's `high` is exclusive — use 256 so the full
                # 0..255 uint8 pixel range is actually generated (the original
                # `(0, 255)` never produced the value 255).
                test_images = [
                    np.random.randint(0, 256, (640, 640, 3), dtype=np.uint8)
                    for _ in range(batch_size)
                ]
                model(test_images, verbose=False)

            # Timed run: accumulate per-batch latency and per-second FPS samples.
            fps_list = []
            latency_list = []
            batch_count = 0
            start_time = time.time()
            last_fps_time = start_time
            fps_batch_count = 0

            while time.time() - start_time < test_duration:
                # Fresh random frames each iteration; image generation is
                # deliberately kept outside the timed inference window.
                test_images = [
                    np.random.randint(0, 256, (640, 640, 3), dtype=np.uint8)
                    for _ in range(batch_size)
                ]

                infer_start = time.time()
                model(test_images, verbose=False)
                infer_end = time.time()

                latency_ms = (infer_end - infer_start) * 1000
                latency_list.append(latency_ms)
                batch_count += 1
                fps_batch_count += 1

                # Once per second, convert the batch counter into a frames/sec
                # sample and report progress.
                current_time = time.time()
                if current_time - last_fps_time >= 1.0:
                    fps = (fps_batch_count * batch_size) / (current_time - last_fps_time)
                    fps_list.append(fps)
                    fps_batch_count = 0
                    last_fps_time = current_time

                    elapsed = current_time - start_time
                    print(f" 进度: {elapsed:.1f}s/{test_duration}s, "
                          f"当前FPS: {fps:.1f}, 延迟: {latency_ms:.1f}ms")

            # Aggregate results. Guard both means against empty sample lists
            # (e.g. a run shorter than 1 s) — an unguarded np.mean([]) yields
            # NaN, which would corrupt the JSON report.
            total_time = time.time() - start_time
            total_frames = batch_count * batch_size
            avg_fps = np.mean(fps_list) if fps_list else 0
            avg_latency_ms = np.mean(latency_list) if latency_list else 0

            results[batch_size] = {
                'avg_fps': avg_fps,
                'avg_latency_ms': avg_latency_ms,
                'total_frames': total_frames,
                'test_duration': total_time,
                'success': True
            }

            print(f"✅ 批次 {batch_size} 测试完成:")
            print(f" 平均FPS: {avg_fps:.1f}")
            print(f" 平均延迟: {avg_latency_ms:.1f}ms")

        except Exception as e:
            # Best-effort benchmark: record the failure and continue with the
            # remaining batch sizes instead of aborting the whole run.
            print(f"❌ 批次 {batch_size} 测试失败: {e}")
            results[batch_size] = {
                'avg_fps': 0,
                'avg_latency_ms': 0,
                'success': False,
                'error': str(e)
            }

    return results


def main():
    """Run the large-batch PyTorch benchmark and save a JSON report."""
    print("PyTorch 大批次性能测试")
    print("=" * 60)

    # Hard-coded local checkpoint path (machine-specific).
    model_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"

    # Preconditions: model file exists and a CUDA device is available.
    if not os.path.exists(model_path):
        print(f"❌ PyTorch 模型不存在: {model_path}")
        return

    if not torch.cuda.is_available():
        print("❌ CUDA 不可用")
        return

    print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}")
    print(f"✅ PyTorch 模型: {model_path}")

    # Only the large batch sizes are benchmarked here; smaller ones are
    # assumed to be covered elsewhere.
    batch_sizes = [16, 32]
    test_duration = 20  # seconds of timed inference per batch size

    print(f"\n📊 测试配置:")
    print(f" 批次大小: {batch_sizes}")
    print(f" 每批次测试时长: {test_duration}秒")

    try:
        pytorch_results = test_pytorch_batch_performance(model_path, batch_sizes, test_duration)

        # Persist the raw results as timestamped JSON.
        output_dir = "pytorch_results"
        os.makedirs(output_dir, exist_ok=True)

        results_data = {
            'framework': 'PyTorch',
            'model': model_path,
            'batch_sizes': batch_sizes,
            'results': pytorch_results,
            'timestamp': datetime.now().isoformat()
        }

        json_file = os.path.join(
            output_dir,
            f"pytorch_batch_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        )
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2, ensure_ascii=False)

        print(f"\n✅ 测试数据已保存: {json_file}")

        # Human-readable summary.
        print("\n" + "=" * 60)
        print("PyTorch 大批次性能测试结果")
        print("=" * 60)

        for bs in batch_sizes:
            result = pytorch_results[bs]
            if result['success']:
                print(f"\n批次大小: {bs}")
                print(f" 平均 FPS: {result['avg_fps']:.1f}")
                print(f" 平均延迟: {result['avg_latency_ms']:.1f}ms")
                print(f" 总帧数: {result['total_frames']}")
            else:
                print(f"\n批次大小: {bs}")
                print(f" 状态: 测试失败 - {result.get('error', '未知错误')}")

        print(f"\n🎉 测试完成!")
        print(f"📁 结果已保存到: {output_dir}/")

        # Next-step hints for the overall comparison workflow.
        print("\n" + "=" * 60)
        print("📌 下一步操作:")
        print(" 1. 使用这些数据更新 batch_comparison_test.py 中的 PYTORCH_DATA")
        print(" 2. 运行完整的 PyTorch vs TensorRT 对比测试")
        print("=" * 60)

    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()