Files
Test_AI/test_tensorrt_load.py
2026-01-20 11:14:10 +08:00

119 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试 TensorRT 引擎加载和推理
使用随机图像测试
"""
import time
import traceback
from pathlib import Path

import numpy as np
from ultralytics import YOLO
# Machine-specific default kept for backward compatibility; pass ``engine_path``
# to test an engine stored elsewhere.
DEFAULT_ENGINE_PATH = "C:/Users/16337/PycharmProjects/Security/yolo11n.engine"


def _random_frame(imgsz):
    """Return a random ``(imgsz, imgsz, 3)`` uint8 image spanning the full 0-255 range."""
    # ``randint``'s upper bound is exclusive, so 256 is required for 255 to be drawn.
    return np.random.randint(0, 256, (imgsz, imgsz, 3), dtype=np.uint8)


def _timed_predict(model, source, **predict_kwargs):
    """Run one ``model.predict`` call on ``source`` and return its latency in milliseconds.

    ``time.perf_counter`` is used rather than ``time.time``: it is monotonic and
    high-resolution, while ``time.time`` follows the adjustable system clock and
    can be too coarse (notably on Windows) for millisecond-scale measurements.
    """
    start = time.perf_counter()
    model.predict(source, verbose=False, **predict_kwargs)
    return (time.perf_counter() - start) * 1000


def _report_failure(stage, exc):
    """Print a failure line for ``stage`` followed by the active exception's traceback."""
    print(f"❌ {stage}失败: {exc}")
    traceback.print_exc()


def _print_latency_stats(latencies_ms):
    """Print mean, P50/P95/P99, min and max of a non-empty list of latencies (ms)."""
    p50, p95, p99 = np.percentile(latencies_ms, [50, 95, 99])
    rule = "=" * 60
    print(f"\n{rule}")
    print("性能统计")
    print(rule)
    print(f"平均推理延迟: {np.mean(latencies_ms):.1f}ms")
    print(f"P50推理延迟: {p50:.1f}ms")
    print(f"P95推理延迟: {p95:.1f}ms")
    print(f"P99推理延迟: {p99:.1f}ms")
    print(f"最小延迟: {np.min(latencies_ms):.1f}ms")
    print(f"最大延迟: {np.max(latencies_ms):.1f}ms")
    print(rule)


def test_tensorrt_engine(
    engine_path=DEFAULT_ENGINE_PATH,
    *,
    imgsz=640,
    conf=0.45,
    device=0,
    batch_size=4,
    num_frames=100,
):
    """Smoke-test and benchmark a TensorRT YOLO engine using random images.

    Stages, in order:
      1. load the engine with ``YOLO(engine_path, task="detect")``;
      2. run single-frame inference;
      3. run one batched inference call on ``batch_size`` frames;
      4. time ``num_frames`` single-frame calls and print latency statistics.
    A failing stage prints its error (plus a traceback for stages 2-4) and the
    remaining stages are skipped.

    NOTE(review): an engine exported with a fixed batch size smaller than
    ``batch_size`` presumably fails stage 3 -- pass a ``batch_size`` that
    matches the export settings.

    Args:
        engine_path: Path to the serialized ``.engine`` file.
        imgsz: Side length of the square random test images; also passed to
            ``predict`` as ``imgsz``.
        conf: Confidence threshold passed to ``predict``.
        device: Device passed to ``predict`` (e.g. ``0`` for the first CUDA GPU).
        batch_size: Number of frames in the batched-inference stage (>= 1).
        num_frames: Number of timed frames in the benchmark stage (>= 1).

    Returns:
        True if every stage succeeded, False otherwise.

    Raises:
        ValueError: If ``batch_size`` or ``num_frames`` is smaller than 1.
    """
    # Validate first: a zero count would otherwise divide by zero or make
    # np.percentile fail on an empty list after all the GPU work was done.
    if batch_size < 1 or num_frames < 1:
        raise ValueError(
            f"batch_size and num_frames must be >= 1, "
            f"got batch_size={batch_size}, num_frames={num_frames}"
        )

    predict_kwargs = {"imgsz": imgsz, "conf": conf, "device": device}

    print("TensorRT 引擎测试")
    print("=" * 60)

    # 1. Load the engine; report a missing file explicitly instead of relying
    #    on whatever error the loader happens to raise for it.
    print(f"🚀 加载 TensorRT 引擎: {engine_path}")
    if not Path(engine_path).is_file():
        print(f"❌ 引擎加载失败: 文件不存在 {engine_path}")
        return False
    try:
        model = YOLO(engine_path, task="detect")
    except Exception as exc:
        print(f"❌ 引擎加载失败: {exc}")
        return False
    print("✅ 引擎加载成功")

    # 2. Single-frame inference. It runs before the benchmark, so first-call
    #    overhead is not part of the stage-4 statistics.
    print("\n🔄 测试单帧推理...")
    try:
        elapsed = _timed_predict(model, _random_frame(imgsz), **predict_kwargs)
    except Exception as exc:
        _report_failure("单帧推理", exc)
        return False
    print(f"✅ 单帧推理成功,耗时: {elapsed:.1f}ms")

    # 3. Batched inference: a list of frames is passed to a single predict call.
    print(f"\n🔄 测试批量推理batch={batch_size}...")
    frames = [_random_frame(imgsz) for _ in range(batch_size)]
    try:
        elapsed = _timed_predict(model, frames, **predict_kwargs)
    except Exception as exc:
        _report_failure("批量推理", exc)
        return False
    per_frame = elapsed / batch_size
    print(f"✅ 批量推理成功,总耗时: {elapsed:.1f}ms每帧: {per_frame:.1f}ms")

    # 4. Benchmark: each frame is generated before its timer starts, so only
    #    the predict call is measured.
    print(f"\n🔄 性能测试{num_frames}帧...")
    latencies_ms = []
    try:
        for i in range(num_frames):
            frame = _random_frame(imgsz)
            latencies_ms.append(_timed_predict(model, frame, **predict_kwargs))
            if (i + 1) % 20 == 0:
                print(f" 进度: {i + 1}/{num_frames}")
    except Exception as exc:
        _report_failure("性能测试", exc)
        return False

    _print_latency_stats(latencies_ms)
    print("\n✅ 所有测试通过TensorRT 引擎工作正常")
    return True
if __name__ == "__main__":
    # Last-resort guard: anything test_tensorrt_engine() lets escape is
    # reported with its traceback instead of aborting with a raw crash.
    try:
        test_tensorrt_engine()
    except Exception as err:
        import traceback

        print(f"\n❌ 测试异常: {err}")
        traceback.print_exc()