Files
Test_AI/real_world_quick_test.py
2026-01-20 11:14:10 +08:00

307 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import cv2
import numpy as np
import yaml
import torch
from ultralytics import YOLO
import time
import datetime
import json
import os
def test_real_world_performance(model_path, config_path, framework_name, test_duration=30, max_cameras=5):
    """Benchmark one inference framework against live RTSP camera streams.

    Loads the YAML config, opens up to ``max_cameras`` of the configured
    RTSP streams, then runs person-only (class 0) YOLO detection on every
    decoded frame for roughly ``test_duration`` seconds, round-robin over
    the cameras.

    Args:
        model_path: Path to a YOLO weight file (.pt) or TensorRT engine (.engine).
        config_path: YAML file with a ``model`` section (``imgsz``,
            ``conf_threshold``) and a ``cameras`` list (each entry needs
            ``id`` and ``rtsp_url``).
        framework_name: Label used in printed output and the returned stats.
        test_duration: Measurement window in seconds.
        max_cameras: Upper bound on the number of streams to open.

    Returns:
        dict with frame counts, average FPS and latency percentiles,
        or None if no camera could be opened.
    """
    print(f"\n{'='*60}")
    print(f"测试框架: {framework_name}")
    print(f"{'='*60}")

    # Load configuration.
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)

    # Load the model. TensorRT engines are used as-is (no .to(device) call).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"🚀 加载模型: {model_path}")
    print(f" 设备: {device}")
    model = YOLO(model_path, task='detect')

    model_cfg = cfg['model']
    imgsz = model_cfg['imgsz']
    conf_thresh = model_cfg['conf_threshold']

    # Restrict the test to the first N configured cameras.
    cameras = cfg['cameras'][:max_cameras]
    print(f"✅ 测试 {len(cameras)} 个摄像头")

    # Open the RTSP streams; skip any camera that fails to connect.
    caps = []
    for cam_cfg in cameras:
        cam_id = cam_cfg['id']
        rtsp_url = cam_cfg['rtsp_url']
        print(f"📹 连接摄像头 {cam_id}...")
        cap = cv2.VideoCapture(rtsp_url, cv2.CAP_FFMPEG)
        if cap.isOpened():
            # Keep the driver-side buffer small so frames stay fresh.
            cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
            caps.append(cap)
            print(f"{cam_id} 连接成功")
        else:
            print(f" ⚠️ {cam_id} 连接失败")

    if not caps:
        print("❌ 没有可用的摄像头")
        return None

    print(f"\n✅ 成功连接 {len(caps)} 个摄像头")
    print(f"⏳ 开始测试,时长: {test_duration}\n")

    # Performance counters.
    frame_count = 0
    inference_times = []  # per-frame inference latency in milliseconds
    start_time = time.time()
    last_print_time = start_time

    try:
        while (time.time() - start_time) < test_duration:
            # Round-robin: one frame from each connected camera per pass.
            for cap in caps:
                ret, frame = cap.read()
                if not ret:
                    continue

                # Time a single person-class inference on this frame.
                infer_start = time.time()
                model.predict(
                    frame,
                    imgsz=imgsz,
                    conf=conf_thresh,
                    verbose=False,
                    device=0 if device == 'cuda' else 'cpu',
                    half=(device == 'cuda'),
                    classes=[0]
                )
                infer_end = time.time()
                inference_times.append((infer_end - infer_start) * 1000)
                frame_count += 1

                # Progress report every 5 seconds.
                current_time = time.time()
                if current_time - last_print_time >= 5.0:
                    elapsed = current_time - start_time
                    avg_fps = frame_count / elapsed
                    avg_latency = np.mean(inference_times) if inference_times else 0
                    print(f"⏱️ {elapsed:.0f}s | 帧数: {frame_count} | FPS: {avg_fps:.1f} | 延迟: {avg_latency:.1f}ms")
                    last_print_time = current_time
    except KeyboardInterrupt:
        print("\n⏹️ 测试被用户中断")
    finally:
        # Always release capture handles, even on interrupt or error.
        for cap in caps:
            cap.release()

    # Aggregate statistics. Convert the latency list to an array once
    # instead of re-scanning the Python list for mean + each percentile;
    # plain float() keeps the dict cleanly JSON-serializable.
    elapsed = time.time() - start_time
    avg_fps = frame_count / elapsed if elapsed > 0 else 0
    times = np.asarray(inference_times, dtype=np.float64)
    has_times = times.size > 0
    stats = {
        'framework': framework_name,
        'total_frames': frame_count,
        'elapsed_time': elapsed,
        'avg_fps': avg_fps,
        'avg_inference_time_ms': float(times.mean()) if has_times else 0,
        'p50_inference_time_ms': float(np.percentile(times, 50)) if has_times else 0,
        'p95_inference_time_ms': float(np.percentile(times, 95)) if has_times else 0,
        'p99_inference_time_ms': float(np.percentile(times, 99)) if has_times else 0,
        'num_cameras': len(caps)
    }

    print(f"\n{'='*60}")
    print(f"{framework_name} 测试完成")
    print(f"{'='*60}")
    print(f"总帧数: {stats['total_frames']}")
    print(f"测试时长: {stats['elapsed_time']:.1f}")
    print(f"平均FPS: {stats['avg_fps']:.1f}")
    print(f"平均推理延迟: {stats['avg_inference_time_ms']:.1f}ms")
    print(f"P95推理延迟: {stats['p95_inference_time_ms']:.1f}ms")
    print(f"P99推理延迟: {stats['p99_inference_time_ms']:.1f}ms")
    print(f"{'='*60}\n")
    return stats
def main():
    """Run the PyTorch-vs-TensorRT real-world benchmark and report results.

    Validates that the config file, both model files and CUDA are present,
    benchmarks each framework in turn via test_real_world_performance(),
    prints a comparison table, and writes a JSON file plus a text report
    under ``real_world_results/``.
    """
    print("真实场景快速性能测试")
    print("=" * 60)

    config_path = "config.yaml"
    pytorch_model = "C:/Users/16337/PycharmProjects/Security/yolo11n.pt"
    tensorrt_model = "C:/Users/16337/PycharmProjects/Security/yolo11n.engine"

    # Preconditions: all input files must exist.
    if not os.path.exists(config_path):
        print(f"❌ 配置文件不存在: {config_path}")
        return
    if not os.path.exists(pytorch_model):
        print(f"❌ PyTorch 模型不存在: {pytorch_model}")
        return
    if not os.path.exists(tensorrt_model):
        print(f"❌ TensorRT 引擎不存在: {tensorrt_model}")
        return

    # CUDA is mandatory — the TensorRT engine only runs on GPU.
    if not torch.cuda.is_available():
        print("❌ CUDA 不可用")
        return
    print(f"✅ CUDA 可用,设备: {torch.cuda.get_device_name(0)}")

    # Test configuration.
    test_duration = 30  # seconds per framework
    max_cameras = 5     # only the first N cameras from the config

    results = {}

    # --- Test 1/2: PyTorch ---
    print(f"\n{'='*60}")
    print("测试 1/2: PyTorch 框架")
    print(f"{'='*60}")
    try:
        pytorch_stats = test_real_world_performance(
            pytorch_model, config_path, "PyTorch",
            test_duration=test_duration, max_cameras=max_cameras
        )
        results['pytorch'] = pytorch_stats
    except Exception as e:
        print(f"❌ PyTorch 测试失败: {e}")
        import traceback
        traceback.print_exc()
        results['pytorch'] = None

    # Let the GPU and streams settle between the two runs.
    print("\n⏳ 等待系统稳定...")
    time.sleep(3)

    # --- Test 2/2: TensorRT ---
    print(f"\n{'='*60}")
    print("测试 2/2: TensorRT 框架")
    print(f"{'='*60}")
    try:
        tensorrt_stats = test_real_world_performance(
            tensorrt_model, config_path, "TensorRT",
            test_duration=test_duration, max_cameras=max_cameras
        )
        results['tensorrt'] = tensorrt_stats
    except Exception as e:
        print(f"❌ TensorRT 测试失败: {e}")
        import traceback
        traceback.print_exc()
        results['tensorrt'] = None

    # --- Comparison report ---
    print(f"\n{'='*60}")
    print("性能对比报告")
    print(f"{'='*60}\n")

    if results['pytorch'] and results['tensorrt']:
        pt_stats = results['pytorch']
        trt_stats = results['tensorrt']

        def _gain(new, base):
            # Percent improvement for higher-is-better metrics. Guards a
            # zero baseline (e.g. the PyTorch run captured no frames),
            # which previously raised ZeroDivisionError.
            return (new / base - 1) * 100 if base else 0.0

        def _drop(new, base):
            # Percent reduction for lower-is-better metrics (latency).
            return (1 - new / base) * 100 if base else 0.0

        print(f"指标 | PyTorch | TensorRT | 提升")
        print(f"{'-'*60}")
        print(f"平均FPS | {pt_stats['avg_fps']:12.1f} | {trt_stats['avg_fps']:12.1f} | {_gain(trt_stats['avg_fps'], pt_stats['avg_fps']):+.1f}%")
        print(f"平均推理延迟(ms) | {pt_stats['avg_inference_time_ms']:12.1f} | {trt_stats['avg_inference_time_ms']:12.1f} | {_drop(trt_stats['avg_inference_time_ms'], pt_stats['avg_inference_time_ms']):+.1f}%")
        print(f"P95推理延迟(ms) | {pt_stats['p95_inference_time_ms']:12.1f} | {trt_stats['p95_inference_time_ms']:12.1f} | {_drop(trt_stats['p95_inference_time_ms'], pt_stats['p95_inference_time_ms']):+.1f}%")
        print(f"P99推理延迟(ms) | {pt_stats['p99_inference_time_ms']:12.1f} | {trt_stats['p99_inference_time_ms']:12.1f} | {_drop(trt_stats['p99_inference_time_ms'], pt_stats['p99_inference_time_ms']):+.1f}%")
        print(f"总帧数 | {pt_stats['total_frames']:12d} | {trt_stats['total_frames']:12d} | {_gain(trt_stats['total_frames'], pt_stats['total_frames']):+.1f}%")
        print(f"摄像头数量 | {pt_stats['num_cameras']:12d} | {trt_stats['num_cameras']:12d} |")

        # Persist results. A single timestamp keeps the JSON filename,
        # the report filename and the recorded timestamps consistent for
        # the same run (previously each call to now() could differ).
        output_dir = "real_world_results"
        os.makedirs(output_dir, exist_ok=True)
        now = datetime.datetime.now()
        stamp = now.strftime('%Y%m%d_%H%M%S')

        results_data = {
            'pytorch': pt_stats,
            'tensorrt': trt_stats,
            'timestamp': now.isoformat(),
            'test_duration': test_duration,
            'max_cameras': max_cameras
        }
        json_file = os.path.join(output_dir, f"real_world_quick_test_{stamp}.json")
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(results_data, f, indent=2, ensure_ascii=False)
        print(f"\n✅ 结果已保存: {json_file}")

        # Plain-text report mirroring the console table.
        report = f"""
真实场景性能测试报告
{'='*60}
测试时间: {now.strftime('%Y-%m-%d %H:%M:%S')}
测试时长: {test_duration}
摄像头数量: {max_cameras}
详细对比数据:
{'='*60}
指标 | PyTorch | TensorRT | 提升
{'-'*60}
平均FPS | {pt_stats['avg_fps']:12.1f} | {trt_stats['avg_fps']:12.1f} | {_gain(trt_stats['avg_fps'], pt_stats['avg_fps']):+.1f}%
平均推理延迟(ms) | {pt_stats['avg_inference_time_ms']:12.1f} | {trt_stats['avg_inference_time_ms']:12.1f} | {_drop(trt_stats['avg_inference_time_ms'], pt_stats['avg_inference_time_ms']):+.1f}%
P50推理延迟(ms) | {pt_stats['p50_inference_time_ms']:12.1f} | {trt_stats['p50_inference_time_ms']:12.1f} | {_drop(trt_stats['p50_inference_time_ms'], pt_stats['p50_inference_time_ms']):+.1f}%
P95推理延迟(ms) | {pt_stats['p95_inference_time_ms']:12.1f} | {trt_stats['p95_inference_time_ms']:12.1f} | {_drop(trt_stats['p95_inference_time_ms'], pt_stats['p95_inference_time_ms']):+.1f}%
P99推理延迟(ms) | {pt_stats['p99_inference_time_ms']:12.1f} | {trt_stats['p99_inference_time_ms']:12.1f} | {_drop(trt_stats['p99_inference_time_ms'], pt_stats['p99_inference_time_ms']):+.1f}%
总帧数 | {pt_stats['total_frames']:12d} | {trt_stats['total_frames']:12d} | {_gain(trt_stats['total_frames'], pt_stats['total_frames']):+.1f}%
关键发现:
{'='*60}
✅ TensorRT 在真实场景下平均FPS提升: {_gain(trt_stats['avg_fps'], pt_stats['avg_fps']):+.1f}%
✅ TensorRT 推理延迟降低: {_drop(trt_stats['avg_inference_time_ms'], pt_stats['avg_inference_time_ms']):+.1f}%
✅ TensorRT 在相同时间内处理更多帧: {_gain(trt_stats['total_frames'], pt_stats['total_frames']):+.1f}%
说明:
{'='*60}
本测试接入真实RTSP视频流包含完整的业务逻辑
- 视频流解码
- YOLO目标检测person类
- ROI区域判断
- 离岗检测算法
- 周界入侵检测算法
测试结果反映了实际生产环境的性能表现。
"""
        report_file = os.path.join(output_dir, f"real_world_report_{stamp}.txt")
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write(report)
        print(f"✅ 报告已保存: {report_file}")
    else:
        print("❌ 测试未完成,无法生成对比报告")

    print(f"\n🎉 测试完成!")
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\n⏹️ 测试被用户中断")
    except Exception as exc:
        # Top-level boundary: report the failure with a full traceback
        # instead of letting an unhandled exception kill the process.
        import traceback
        print(f"\n❌ 测试过程中发生错误: {exc}")
        traceback.print_exc()