234 lines
7.6 KiB
Python
234 lines
7.6 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
TensorRT 环境测试脚本
|
||
测试 TensorRT 是否可以在当前环境中正常运行
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import traceback
|
||
|
||
def test_basic_imports():
|
||
"""测试基础库导入"""
|
||
print("=" * 50)
|
||
print("1. 测试基础库导入...")
|
||
|
||
try:
|
||
import torch
|
||
print(f"✅ PyTorch 版本: {torch.__version__}")
|
||
print(f"✅ CUDA 可用: {torch.cuda.is_available()}")
|
||
if torch.cuda.is_available():
|
||
print(f"✅ CUDA 版本: {torch.version.cuda}")
|
||
print(f"✅ GPU 数量: {torch.cuda.device_count()}")
|
||
for i in range(torch.cuda.device_count()):
|
||
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
|
||
except ImportError as e:
|
||
print(f"❌ PyTorch 导入失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
import tensorrt as trt
|
||
print(f"✅ TensorRT 版本: {trt.__version__}")
|
||
except ImportError as e:
|
||
print(f"❌ TensorRT 导入失败: {e}")
|
||
print("提示: 请确保已安装 TensorRT")
|
||
print("安装命令: pip install tensorrt")
|
||
return False
|
||
|
||
try:
|
||
from ultralytics import YOLO
|
||
print(f"✅ Ultralytics YOLO 可用")
|
||
except ImportError as e:
|
||
print(f"❌ Ultralytics 导入失败: {e}")
|
||
return False
|
||
|
||
return True
|
||
|
||
def test_tensorrt_basic():
|
||
"""测试 TensorRT 基础功能"""
|
||
print("\n" + "=" * 50)
|
||
print("2. 测试 TensorRT 基础功能...")
|
||
|
||
try:
|
||
import tensorrt as trt
|
||
|
||
# 创建 TensorRT Logger
|
||
logger = trt.Logger(trt.Logger.WARNING)
|
||
print("✅ TensorRT Logger 创建成功")
|
||
|
||
# 创建 Builder
|
||
builder = trt.Builder(logger)
|
||
print("✅ TensorRT Builder 创建成功")
|
||
|
||
# 创建 Network
|
||
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
||
print("✅ TensorRT Network 创建成功")
|
||
|
||
# 创建 Config
|
||
config = builder.create_builder_config()
|
||
print("✅ TensorRT Config 创建成功")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f"❌ TensorRT 基础功能测试失败: {e}")
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
def test_yolo_tensorrt_export():
|
||
"""测试 YOLO 模型导出为 TensorRT"""
|
||
print("\n" + "=" * 50)
|
||
print("3. 测试 YOLO 模型 TensorRT 导出...")
|
||
|
||
try:
|
||
from ultralytics import YOLO
|
||
import torch
|
||
|
||
# 检查模型文件是否存在
|
||
model_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
|
||
if not os.path.exists(model_path):
|
||
print(f"❌ 模型文件不存在: {model_path}")
|
||
return False
|
||
|
||
print(f"✅ 找到模型文件: {model_path}")
|
||
|
||
# 加载模型
|
||
model = YOLO(model_path)
|
||
print("✅ YOLO 模型加载成功")
|
||
|
||
# 尝试导出为 TensorRT(仅测试,不实际导出)
|
||
print("📝 准备测试 TensorRT 导出功能...")
|
||
print(" 注意: 实际导出需要较长时间,这里仅测试导出接口")
|
||
|
||
# 检查导出方法是否可用
|
||
if hasattr(model, 'export'):
|
||
print("✅ YOLO 模型支持导出功能")
|
||
|
||
# 测试导出参数(不实际执行)
|
||
export_params = {
|
||
'format': 'engine', # TensorRT engine format
|
||
'imgsz': 640,
|
||
'device': 0 if torch.cuda.is_available() else 'cpu',
|
||
'half': True, # FP16
|
||
'dynamic': False,
|
||
'simplify': True,
|
||
'workspace': 4, # GB
|
||
}
|
||
print(f"✅ 导出参数配置完成: {export_params}")
|
||
|
||
return True
|
||
else:
|
||
print("❌ YOLO 模型不支持导出功能")
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"❌ YOLO TensorRT 导出测试失败: {e}")
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
def test_gpu_memory():
|
||
"""测试 GPU 内存"""
|
||
print("\n" + "=" * 50)
|
||
print("4. 测试 GPU 内存...")
|
||
|
||
try:
|
||
import torch
|
||
|
||
if not torch.cuda.is_available():
|
||
print("❌ CUDA 不可用,跳过 GPU 内存测试")
|
||
return False
|
||
|
||
device = torch.device('cuda:0')
|
||
|
||
# 获取 GPU 内存信息
|
||
total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3 # GB
|
||
allocated_memory = torch.cuda.memory_allocated(0) / 1024**3 # GB
|
||
cached_memory = torch.cuda.memory_reserved(0) / 1024**3 # GB
|
||
|
||
print(f"✅ GPU 总内存: {total_memory:.2f} GB")
|
||
print(f"✅ 已分配内存: {allocated_memory:.2f} GB")
|
||
print(f"✅ 缓存内存: {cached_memory:.2f} GB")
|
||
print(f"✅ 可用内存: {total_memory - cached_memory:.2f} GB")
|
||
|
||
# 建议的最小内存要求
|
||
min_required_memory = 4.0 # GB
|
||
if total_memory >= min_required_memory:
|
||
print(f"✅ GPU 内存充足 (>= {min_required_memory} GB)")
|
||
return True
|
||
else:
|
||
print(f"⚠️ GPU 内存可能不足 (< {min_required_memory} GB)")
|
||
print(" 建议: 使用较小的批次大小或降低输入分辨率")
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f"❌ GPU 内存测试失败: {e}")
|
||
return False
|
||
|
||
def test_environment_summary():
|
||
"""环境测试总结"""
|
||
print("\n" + "=" * 50)
|
||
print("5. 环境测试总结")
|
||
|
||
# 运行所有测试
|
||
results = []
|
||
results.append(("基础库导入", test_basic_imports()))
|
||
results.append(("TensorRT 基础功能", test_tensorrt_basic()))
|
||
results.append(("YOLO TensorRT 导出", test_yolo_tensorrt_export()))
|
||
results.append(("GPU 内存", test_gpu_memory()))
|
||
|
||
print("\n测试结果:")
|
||
print("-" * 30)
|
||
all_passed = True
|
||
for test_name, passed in results:
|
||
status = "✅ 通过" if passed else "❌ 失败"
|
||
print(f"{test_name:<20}: {status}")
|
||
if not passed:
|
||
all_passed = False
|
||
|
||
print("-" * 30)
|
||
if all_passed:
|
||
print("🎉 所有测试通过!TensorRT 环境配置正确")
|
||
print("✅ 可以开始进行性能对比测试")
|
||
else:
|
||
print("⚠️ 部分测试失败,请检查环境配置")
|
||
print("💡 建议:")
|
||
print(" 1. 确保已激活 conda yolov11 环境")
|
||
print(" 2. 安装 TensorRT: pip install tensorrt")
|
||
print(" 3. 检查 CUDA 和 GPU 驱动")
|
||
|
||
return all_passed
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("TensorRT 环境测试")
|
||
print("=" * 50)
|
||
print(f"Python 版本: {sys.version}")
|
||
print(f"当前工作目录: {os.getcwd()}")
|
||
|
||
# 检查是否在 conda 环境中
|
||
conda_env = os.environ.get('CONDA_DEFAULT_ENV', 'None')
|
||
print(f"Conda 环境: {conda_env}")
|
||
|
||
if conda_env != 'yolov11':
|
||
print("⚠️ 警告: 当前不在 yolov11 conda 环境中")
|
||
print(" 建议运行: conda activate yolov11")
|
||
|
||
# 运行环境测试
|
||
success = test_environment_summary()
|
||
|
||
if success:
|
||
print("\n🚀 下一步:")
|
||
print(" 1. 运行完整的性能对比测试")
|
||
print(" 2. 生成 TensorRT 引擎文件")
|
||
print(" 3. 对比 PyTorch vs TensorRT 性能")
|
||
|
||
return success
|
||
|
||
if __name__ == "__main__":
|
||
try:
|
||
main()
|
||
except KeyboardInterrupt:
|
||
print("\n\n⏹️ 测试被用户中断")
|
||
except Exception as e:
|
||
print(f"\n❌ 测试过程中发生未知错误: {e}")
|
||
traceback.print_exc() |