Files
Test_AI/simple_tensorrt_test.py
2026-01-20 11:14:10 +08:00

118 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
简单的 TensorRT 推理测试
验证 TensorRT 引擎是否能正常工作
"""
import cv2
import numpy as np
import yaml
import time
from ultralytics import YOLO
def _print_report(frame_count, elapsed, inference_times):
    """Print the end-of-run throughput/latency summary.

    Args:
        frame_count: Number of frames that went through ``model.predict``.
        elapsed: Duration of the inference loop, in seconds.
        inference_times: Per-frame ``model.predict`` latencies, in milliseconds.
    """
    print(f"\n{'='*60}")
    print("测试结果")
    print(f"{'='*60}")
    print(f"总帧数: {frame_count}")
    print(f"测试时长: {elapsed:.1f}秒")
    if not inference_times:
        # With no samples, np.mean gives NaN and np.percentile gives NaN or
        # raises; claiming the engine "works" would also be false.
        print("❌ 测试期间没有成功推理任何帧")
        print(f"{'='*60}")
        return
    avg_fps = frame_count / elapsed if elapsed > 0 else 0.0
    print(f"平均FPS: {avg_fps:.1f}")
    print(f"平均推理延迟: {np.mean(inference_times):.1f}ms")
    print(f"P95推理延迟: {np.percentile(inference_times, 95):.1f}ms")
    print(f"P99推理延迟: {np.percentile(inference_times, 99):.1f}ms")
    print(f"{'='*60}")
    print("\n✅ 测试完成TensorRT 引擎工作正常")


def test_tensorrt_inference(
    config_path="config.yaml",
    engine_path="C:/Users/16337/PycharmProjects/Security/yolo11n.engine",
    duration=10,
):
    """Run a short TensorRT inference smoke test against the first camera.

    Loads the YAML config at ``config_path`` and takes the first entry of its
    ``cameras`` list. It opens that entry's ``rtsp_url`` through OpenCV's
    FFmpeg backend, then runs the YOLO TensorRT engine at ``engine_path`` on
    every frame read during ``duration`` seconds. Only class 0 is kept and the
    confidence threshold is 0.45. Progress is printed every 10 frames,
    followed by a throughput/latency summary.

    Args:
        config_path: Path to the YAML config. It must contain a non-empty
            ``cameras`` list whose entries have ``id`` and ``rtsp_url`` keys.
        engine_path: Path to the exported TensorRT engine file.
        duration: Length of the inference loop, in seconds.

    Returns:
        None. Results are reported on stdout only. If the video stream cannot
        be opened, an error is printed and the function returns early.

    Raises:
        Errors from opening/parsing the config, loading the model, or running
        ``model.predict`` propagate unchanged. Ctrl-C during the loop is
        caught and ends the loop early.
    """
    print("TensorRT 推理测试")
    print("=" * 60)

    # Load configuration.
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    # Load the TensorRT engine.
    print(f"🚀 加载 TensorRT 引擎: {engine_path}")
    model = YOLO(engine_path, task='detect')
    print("✅ 引擎加载成功")

    # Use the first configured camera.
    cam_cfg = cfg['cameras'][0]
    cam_id = cam_cfg['id']
    rtsp_url = cam_cfg['rtsp_url']
    print(f"\n📹 连接摄像头: {cam_id}")
    print(f" RTSP: {rtsp_url}")

    # Open the video stream.
    cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
    if not cap.isOpened():
        print("❌ 无法打开视频流")
        return
    print("✅ 视频流已连接")

    print(f"\n🔄 开始推理测试{duration:g}秒...")
    frame_count = 0
    inference_times = []  # per-frame predict() latency in milliseconds
    # perf_counter is monotonic and high-resolution. time.time() can jump with
    # wall-clock changes, and on Windows before Python 3.13 it ticks at about
    # 15.6 ms, which would quantize millisecond-scale latencies.
    start_time = time.perf_counter()
    try:
        while (time.perf_counter() - start_time) < duration:
            ret, frame = cap.read()
            if not ret:
                # The read failed (e.g. the stream dropped). Back off briefly
                # instead of busy-spinning the CPU until the deadline.
                time.sleep(0.01)
                continue

            # Resize to 640x640; this ignores the aspect ratio.
            frame = cv2.resize(frame, (640, 640))

            infer_start = time.perf_counter()
            model.predict(
                frame,
                imgsz=640,
                conf=0.45,
                verbose=False,
                device=0,  # GPU 0
                classes=[0],  # class 0 only ("person only" per original note)
            )
            inference_times.append((time.perf_counter() - infer_start) * 1000)
            frame_count += 1

            # Report progress every 10 inferred frames.
            if frame_count % 10 == 0:
                elapsed = time.perf_counter() - start_time
                fps = frame_count / elapsed
                avg_latency = np.mean(inference_times)
                print(f" 帧数: {frame_count} | FPS: {fps:.1f} | 延迟: {avg_latency:.1f}ms")
    except KeyboardInterrupt:
        print("\n⏹️ 测试被中断")
    finally:
        cap.release()

    _print_report(frame_count, time.perf_counter() - start_time, inference_times)
if __name__ == "__main__":
    try:
        test_tensorrt_inference()
    except Exception as e:
        # Top-level boundary: report the failure with a full traceback.
        print(f"\n❌ 测试失败: {e}")
        import traceback
        traceback.print_exc()
        # Exit non-zero so shells/CI can detect the failure. Without this the
        # script exits with status 0 even after printing the traceback.
        raise SystemExit(1)