Files
Test_AI/test_tensorrt_env.py
2026-01-20 11:14:10 +08:00

234 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
TensorRT 环境测试脚本
测试 TensorRT 是否可以在当前环境中正常运行
"""
import sys
import os
import traceback
def test_basic_imports():
"""测试基础库导入"""
print("=" * 50)
print("1. 测试基础库导入...")
try:
import torch
print(f"✅ PyTorch 版本: {torch.__version__}")
print(f"✅ CUDA 可用: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"✅ CUDA 版本: {torch.version.cuda}")
print(f"✅ GPU 数量: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
except ImportError as e:
print(f"❌ PyTorch 导入失败: {e}")
return False
try:
import tensorrt as trt
print(f"✅ TensorRT 版本: {trt.__version__}")
except ImportError as e:
print(f"❌ TensorRT 导入失败: {e}")
print("提示: 请确保已安装 TensorRT")
print("安装命令: pip install tensorrt")
return False
try:
from ultralytics import YOLO
print(f"✅ Ultralytics YOLO 可用")
except ImportError as e:
print(f"❌ Ultralytics 导入失败: {e}")
return False
return True
def test_tensorrt_basic():
"""测试 TensorRT 基础功能"""
print("\n" + "=" * 50)
print("2. 测试 TensorRT 基础功能...")
try:
import tensorrt as trt
# 创建 TensorRT Logger
logger = trt.Logger(trt.Logger.WARNING)
print("✅ TensorRT Logger 创建成功")
# 创建 Builder
builder = trt.Builder(logger)
print("✅ TensorRT Builder 创建成功")
# 创建 Network
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
print("✅ TensorRT Network 创建成功")
# 创建 Config
config = builder.create_builder_config()
print("✅ TensorRT Config 创建成功")
return True
except Exception as e:
print(f"❌ TensorRT 基础功能测试失败: {e}")
traceback.print_exc()
return False
def test_yolo_tensorrt_export():
"""测试 YOLO 模型导出为 TensorRT"""
print("\n" + "=" * 50)
print("3. 测试 YOLO 模型 TensorRT 导出...")
try:
from ultralytics import YOLO
import torch
# 检查模型文件是否存在
model_path = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
if not os.path.exists(model_path):
print(f"❌ 模型文件不存在: {model_path}")
return False
print(f"✅ 找到模型文件: {model_path}")
# 加载模型
model = YOLO(model_path)
print("✅ YOLO 模型加载成功")
# 尝试导出为 TensorRT仅测试不实际导出
print("📝 准备测试 TensorRT 导出功能...")
print(" 注意: 实际导出需要较长时间,这里仅测试导出接口")
# 检查导出方法是否可用
if hasattr(model, 'export'):
print("✅ YOLO 模型支持导出功能")
# 测试导出参数(不实际执行)
export_params = {
'format': 'engine', # TensorRT engine format
'imgsz': 640,
'device': 0 if torch.cuda.is_available() else 'cpu',
'half': True, # FP16
'dynamic': False,
'simplify': True,
'workspace': 4, # GB
}
print(f"✅ 导出参数配置完成: {export_params}")
return True
else:
print("❌ YOLO 模型不支持导出功能")
return False
except Exception as e:
print(f"❌ YOLO TensorRT 导出测试失败: {e}")
traceback.print_exc()
return False
def test_gpu_memory():
"""测试 GPU 内存"""
print("\n" + "=" * 50)
print("4. 测试 GPU 内存...")
try:
import torch
if not torch.cuda.is_available():
print("❌ CUDA 不可用,跳过 GPU 内存测试")
return False
device = torch.device('cuda:0')
# 获取 GPU 内存信息
total_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3 # GB
allocated_memory = torch.cuda.memory_allocated(0) / 1024**3 # GB
cached_memory = torch.cuda.memory_reserved(0) / 1024**3 # GB
print(f"✅ GPU 总内存: {total_memory:.2f} GB")
print(f"✅ 已分配内存: {allocated_memory:.2f} GB")
print(f"✅ 缓存内存: {cached_memory:.2f} GB")
print(f"✅ 可用内存: {total_memory - cached_memory:.2f} GB")
# 建议的最小内存要求
min_required_memory = 4.0 # GB
if total_memory >= min_required_memory:
print(f"✅ GPU 内存充足 (>= {min_required_memory} GB)")
return True
else:
print(f"⚠️ GPU 内存可能不足 (< {min_required_memory} GB)")
print(" 建议: 使用较小的批次大小或降低输入分辨率")
return True
except Exception as e:
print(f"❌ GPU 内存测试失败: {e}")
return False
def test_environment_summary():
"""环境测试总结"""
print("\n" + "=" * 50)
print("5. 环境测试总结")
# 运行所有测试
results = []
results.append(("基础库导入", test_basic_imports()))
results.append(("TensorRT 基础功能", test_tensorrt_basic()))
results.append(("YOLO TensorRT 导出", test_yolo_tensorrt_export()))
results.append(("GPU 内存", test_gpu_memory()))
print("\n测试结果:")
print("-" * 30)
all_passed = True
for test_name, passed in results:
status = "✅ 通过" if passed else "❌ 失败"
print(f"{test_name:<20}: {status}")
if not passed:
all_passed = False
print("-" * 30)
if all_passed:
print("🎉 所有测试通过TensorRT 环境配置正确")
print("✅ 可以开始进行性能对比测试")
else:
print("⚠️ 部分测试失败,请检查环境配置")
print("💡 建议:")
print(" 1. 确保已激活 conda yolov11 环境")
print(" 2. 安装 TensorRT: pip install tensorrt")
print(" 3. 检查 CUDA 和 GPU 驱动")
return all_passed
def main():
"""主函数"""
print("TensorRT 环境测试")
print("=" * 50)
print(f"Python 版本: {sys.version}")
print(f"当前工作目录: {os.getcwd()}")
# 检查是否在 conda 环境中
conda_env = os.environ.get('CONDA_DEFAULT_ENV', 'None')
print(f"Conda 环境: {conda_env}")
if conda_env != 'yolov11':
print("⚠️ 警告: 当前不在 yolov11 conda 环境中")
print(" 建议运行: conda activate yolov11")
# 运行环境测试
success = test_environment_summary()
if success:
print("\n🚀 下一步:")
print(" 1. 运行完整的性能对比测试")
print(" 2. 生成 TensorRT 引擎文件")
print(" 3. 对比 PyTorch vs TensorRT 性能")
return success
if __name__ == "__main__":
try:
main()
except KeyboardInterrupt:
print("\n\n⏹️ 测试被用户中断")
except Exception as e:
print(f"\n❌ 测试过程中发生未知错误: {e}")
traceback.print_exc()