234 lines
7.6 KiB
Python
234 lines
7.6 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
TensorRT 环境测试脚本
|
|||
|
|
测试 TensorRT 是否可以在当前环境中正常运行
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
import traceback
|
|||
|
|
|
|||
|
|
def test_basic_imports():
|
|||
|
|
"""测试基础库导入"""
|
|||
|
|
print("=" * 50)
|
|||
|
|
print("1. 测试基础库导入...")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import torch
|
|||
|
|
print(f"✅ PyTorch 版本: {torch.__version__}")
|
|||
|
|
print(f"✅ CUDA 可用: {torch.cuda.is_available()}")
|
|||
|
|
if torch.cuda.is_available():
|
|||
|
|
print(f"✅ CUDA 版本: {torch.version.cuda}")
|
|||
|
|
print(f"✅ GPU 数量: {torch.cuda.device_count()}")
|
|||
|
|
for i in range(torch.cuda.device_count()):
|
|||
|
|
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
|||
|
|
except ImportError as e:
|
|||
|
|
print(f"❌ PyTorch 导入失败: {e}")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import tensorrt as trt
|
|||
|
|
print(f"✅ TensorRT 版本: {trt.__version__}")
|
|||
|
|
except ImportError as e:
|
|||
|
|
print(f"❌ TensorRT 导入失败: {e}")
|
|||
|
|
print("提示: 请确保已安装 TensorRT")
|
|||
|
|
print("安装命令: pip install tensorrt")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
print(f"✅ Ultralytics YOLO 可用")
|
|||
|
|
except ImportError as e:
|
|||
|
|
print(f"❌ Ultralytics 导入失败: {e}")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
def test_tensorrt_basic():
|
|||
|
|
"""测试 TensorRT 基础功能"""
|
|||
|
|
print("\n" + "=" * 50)
|
|||
|
|
print("2. 测试 TensorRT 基础功能...")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import tensorrt as trt
|
|||
|
|
|
|||
|
|
# 创建 TensorRT Logger
|
|||
|
|
logger = trt.Logger(trt.Logger.WARNING)
|
|||
|
|
print("✅ TensorRT Logger 创建成功")
|
|||
|
|
|
|||
|
|
# 创建 Builder
|
|||
|
|
builder = trt.Builder(logger)
|
|||
|
|
print("✅ TensorRT Builder 创建成功")
|
|||
|
|
|
|||
|
|
# 创建 Network
|
|||
|
|
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
|||
|
|
print("✅ TensorRT Network 创建成功")
|
|||
|
|
|
|||
|
|
# 创建 Config
|
|||
|
|
config = builder.create_builder_config()
|
|||
|
|
print("✅ TensorRT Config 创建成功")
|
|||
|
|
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ TensorRT 基础功能测试失败: {e}")
|
|||
|
|
traceback.print_exc()
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
def test_yolo_tensorrt_export():
|
|||
|
|
"""测试 YOLO 模型导出为 TensorRT"""
|
|||
|
|
print("\n" + "=" * 50)
|
|||
|
|
print("3. 测试 YOLO 模型 TensorRT 导出...")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
from ultralytics import YOLO
|
|||
|
|
import torch
|
|||
|
|
|
|||
|
|
# 检查模型文件是否存在
|
|||
|
|
model_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
|
|||
|
|
if not os.path.exists(model_path):
|
|||
|
|
print(f"❌ 模型文件不存在: {model_path}")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
print(f"✅ 找到模型文件: {model_path}")
|
|||
|
|
|
|||
|
|
# 加载模型
|
|||
|
|
model = YOLO(model_path)
|
|||
|
|
print("✅ YOLO 模型加载成功")
|
|||
|
|
|
|||
|
|
# 尝试导出为 TensorRT(仅测试,不实际导出)
|
|||
|
|
print("📝 准备测试 TensorRT 导出功能...")
|
|||
|
|
print(" 注意: 实际导出需要较长时间,这里仅测试导出接口")
|
|||
|
|
|
|||
|
|
# 检查导出方法是否可用
|
|||
|
|
if hasattr(model, 'export'):
|
|||
|
|
print("✅ YOLO 模型支持导出功能")
|
|||
|
|
|
|||
|
|
# 测试导出参数(不实际执行)
|
|||
|
|
export_params = {
|
|||
|
|
'format': 'engine', # TensorRT engine format
|
|||
|
|
'imgsz': 640,
|
|||
|
|
'device': 0 if torch.cuda.is_available() else 'cpu',
|
|||
|
|
'half': True, # FP16
|
|||
|
|
'dynamic': False,
|
|||
|
|
'simplify': True,
|
|||
|
|
'workspace': 4, # GB
|
|||
|
|
}
|
|||
|
|
print(f"✅ 导出参数配置完成: {export_params}")
|
|||
|
|
|
|||
|
|
return True
|
|||
|
|
else:
|
|||
|
|
print("❌ YOLO 模型不支持导出功能")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ YOLO TensorRT 导出测试失败: {e}")
|
|||
|
|
traceback.print_exc()
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
def test_gpu_memory():
|
|||
|
|
"""测试 GPU 内存"""
|
|||
|
|
print("\n" + "=" * 50)
|
|||
|
|
print("4. 测试 GPU 内存...")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
import torch
|
|||
|
|
|
|||
|
|
if not torch.cuda.is_available():
|
|||
|
|
print("❌ CUDA 不可用,跳过 GPU 内存测试")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
device = torch.device('cuda:0')
|
|||
|
|
|
|||
|
|
# 获取 GPU 内存信息
|
|||
|
|
total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3 # GB
|
|||
|
|
allocated_memory = torch.cuda.memory_allocated(0) / 1024**3 # GB
|
|||
|
|
cached_memory = torch.cuda.memory_reserved(0) / 1024**3 # GB
|
|||
|
|
|
|||
|
|
print(f"✅ GPU 总内存: {total_memory:.2f} GB")
|
|||
|
|
print(f"✅ 已分配内存: {allocated_memory:.2f} GB")
|
|||
|
|
print(f"✅ 缓存内存: {cached_memory:.2f} GB")
|
|||
|
|
print(f"✅ 可用内存: {total_memory - cached_memory:.2f} GB")
|
|||
|
|
|
|||
|
|
# 建议的最小内存要求
|
|||
|
|
min_required_memory = 4.0 # GB
|
|||
|
|
if total_memory >= min_required_memory:
|
|||
|
|
print(f"✅ GPU 内存充足 (>= {min_required_memory} GB)")
|
|||
|
|
return True
|
|||
|
|
else:
|
|||
|
|
print(f"⚠️ GPU 内存可能不足 (< {min_required_memory} GB)")
|
|||
|
|
print(" 建议: 使用较小的批次大小或降低输入分辨率")
|
|||
|
|
return True
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ GPU 内存测试失败: {e}")
|
|||
|
|
return False
|
|||
|
|
|
|||
|
|
def test_environment_summary():
|
|||
|
|
"""环境测试总结"""
|
|||
|
|
print("\n" + "=" * 50)
|
|||
|
|
print("5. 环境测试总结")
|
|||
|
|
|
|||
|
|
# 运行所有测试
|
|||
|
|
results = []
|
|||
|
|
results.append(("基础库导入", test_basic_imports()))
|
|||
|
|
results.append(("TensorRT 基础功能", test_tensorrt_basic()))
|
|||
|
|
results.append(("YOLO TensorRT 导出", test_yolo_tensorrt_export()))
|
|||
|
|
results.append(("GPU 内存", test_gpu_memory()))
|
|||
|
|
|
|||
|
|
print("\n测试结果:")
|
|||
|
|
print("-" * 30)
|
|||
|
|
all_passed = True
|
|||
|
|
for test_name, passed in results:
|
|||
|
|
status = "✅ 通过" if passed else "❌ 失败"
|
|||
|
|
print(f"{test_name:<20}: {status}")
|
|||
|
|
if not passed:
|
|||
|
|
all_passed = False
|
|||
|
|
|
|||
|
|
print("-" * 30)
|
|||
|
|
if all_passed:
|
|||
|
|
print("🎉 所有测试通过!TensorRT 环境配置正确")
|
|||
|
|
print("✅ 可以开始进行性能对比测试")
|
|||
|
|
else:
|
|||
|
|
print("⚠️ 部分测试失败,请检查环境配置")
|
|||
|
|
print("💡 建议:")
|
|||
|
|
print(" 1. 确保已激活 conda yolov11 环境")
|
|||
|
|
print(" 2. 安装 TensorRT: pip install tensorrt")
|
|||
|
|
print(" 3. 检查 CUDA 和 GPU 驱动")
|
|||
|
|
|
|||
|
|
return all_passed
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
"""主函数"""
|
|||
|
|
print("TensorRT 环境测试")
|
|||
|
|
print("=" * 50)
|
|||
|
|
print(f"Python 版本: {sys.version}")
|
|||
|
|
print(f"当前工作目录: {os.getcwd()}")
|
|||
|
|
|
|||
|
|
# 检查是否在 conda 环境中
|
|||
|
|
conda_env = os.environ.get('CONDA_DEFAULT_ENV', 'None')
|
|||
|
|
print(f"Conda 环境: {conda_env}")
|
|||
|
|
|
|||
|
|
if conda_env != 'yolov11':
|
|||
|
|
print("⚠️ 警告: 当前不在 yolov11 conda 环境中")
|
|||
|
|
print(" 建议运行: conda activate yolov11")
|
|||
|
|
|
|||
|
|
# 运行环境测试
|
|||
|
|
success = test_environment_summary()
|
|||
|
|
|
|||
|
|
if success:
|
|||
|
|
print("\n🚀 下一步:")
|
|||
|
|
print(" 1. 运行完整的性能对比测试")
|
|||
|
|
print(" 2. 生成 TensorRT 引擎文件")
|
|||
|
|
print(" 3. 对比 PyTorch vs TensorRT 性能")
|
|||
|
|
|
|||
|
|
return success
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
try:
|
|||
|
|
main()
|
|||
|
|
except KeyboardInterrupt:
|
|||
|
|
print("\n\n⏹️ 测试被用户中断")
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"\n❌ 测试过程中发生未知错误: {e}")
|
|||
|
|
traceback.print_exc()
|