fix: 修复 TensorRT bindings 问题
- tensorrt_engine.py 添加 pycuda 支持 - CUDA 上下文和流管理 - _is_in_working_hours 支持字符串格式
This commit is contained in:
@@ -49,14 +49,39 @@ class LeavePostAlgorithm:
|
||||
def _is_in_working_hours(self, dt: Optional[datetime] = None) -> bool:
|
||||
if not self.working_hours:
|
||||
return True
|
||||
|
||||
import json
|
||||
|
||||
working_hours = self.working_hours
|
||||
if isinstance(working_hours, str):
|
||||
try:
|
||||
working_hours = json.loads(working_hours)
|
||||
except:
|
||||
return True
|
||||
|
||||
if not working_hours:
|
||||
return True
|
||||
|
||||
dt = dt or datetime.now()
|
||||
current_minutes = dt.hour * 60 + dt.minute
|
||||
for period in self.working_hours:
|
||||
start_minutes = period["start"][0] * 60 + period["start"][1]
|
||||
end_minutes = period["end"][0] * 60 + period["end"][1]
|
||||
for period in working_hours:
|
||||
start_str = period["start"] if isinstance(period, dict) else period
|
||||
end_str = period["end"] if isinstance(period, dict) else period
|
||||
start_minutes = self._parse_time_to_minutes(start_str)
|
||||
end_minutes = self._parse_time_to_minutes(end_str)
|
||||
if start_minutes <= current_minutes < end_minutes:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _parse_time_to_minutes(self, time_str: str) -> int:
|
||||
"""将时间字符串转换为分钟数"""
|
||||
if isinstance(time_str, int):
|
||||
return time_str
|
||||
try:
|
||||
parts = time_str.split(":")
|
||||
return int(parts[0]) * 60 + int(parts[1])
|
||||
except:
|
||||
return 0
|
||||
|
||||
def _check_detection_in_roi(self, detection: Dict, roi_id: str) -> bool:
|
||||
matched_rois = detection.get("matched_rois", [])
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -3,6 +3,7 @@ TensorRT推理引擎模块
|
||||
实现引擎加载、显存优化、异步推理、性能监控
|
||||
"""
|
||||
|
||||
import ctypes
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
@@ -12,10 +13,13 @@ import numpy as np
|
||||
|
||||
try:
|
||||
import tensorrt as trt
|
||||
import pycuda.driver as cuda
|
||||
import pycuda.autoinit
|
||||
TRT_AVAILABLE = True
|
||||
except ImportError:
|
||||
TRT_AVAILABLE = False
|
||||
trt = None
|
||||
cuda = None
|
||||
|
||||
from config.settings import get_settings, InferenceConfig
|
||||
from utils.logger import get_logger
|
||||
@@ -50,6 +54,7 @@ class TensorRTEngine:
|
||||
self._output_bindings = []
|
||||
self._stream = None
|
||||
self._released = False
|
||||
self._cuda_context = None
|
||||
|
||||
self._logger = get_logger("tensorrt")
|
||||
self._lock = threading.Lock()
|
||||
@@ -90,6 +95,10 @@ class TensorRTEngine:
|
||||
if self._context is not None:
|
||||
self._release_resources()
|
||||
|
||||
if cuda is not None:
|
||||
self._cuda_context = cuda.Device(0).make_context()
|
||||
self._stream = cuda.Stream()
|
||||
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
with open(engine_path, "rb") as f:
|
||||
@@ -241,24 +250,60 @@ class TensorRTEngine:
|
||||
input_data.shape
|
||||
)
|
||||
|
||||
input_tensor = input_data
|
||||
output_tensors = []
|
||||
|
||||
for output in self._output_bindings:
|
||||
output_shape = list(output["shape"])
|
||||
output_shape[0] = batch_size
|
||||
output_tensor = np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
bindings = [input_tensor] + output_tensors
|
||||
|
||||
self._context.execute_v2(bindings=bindings)
|
||||
|
||||
inference_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
self._update_performance_stats(inference_time_ms, batch_size)
|
||||
|
||||
return output_tensors, inference_time_ms
|
||||
if cuda is not None and self._cuda_context is not None:
|
||||
self._cuda_context.push()
|
||||
|
||||
try:
|
||||
input_data = np.ascontiguousarray(input_data)
|
||||
|
||||
input_ptr = cuda.mem_alloc(input_data.nbytes)
|
||||
cuda.memcpy_htod(input_ptr, input_data)
|
||||
|
||||
bindings = [int(input_ptr)]
|
||||
output_tensors = []
|
||||
|
||||
for output in self._output_bindings:
|
||||
output_shape = list(output["shape"])
|
||||
output_shape[0] = batch_size
|
||||
output_tensor = np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
|
||||
output_tensor = np.ascontiguousarray(output_tensor)
|
||||
output_ptr = cuda.mem_alloc(output_tensor.nbytes)
|
||||
cuda.memcpy_htod(output_ptr, output_tensor)
|
||||
bindings.append(int(output_ptr))
|
||||
output_tensors.append((output_tensor, output_ptr))
|
||||
|
||||
self._context.execute_v2(bindings=bindings)
|
||||
|
||||
for output_tensor, output_ptr in output_tensors:
|
||||
cuda.memcpy_dtoh(output_tensor, output_ptr)
|
||||
|
||||
inference_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
self._update_performance_stats(inference_time_ms, batch_size)
|
||||
|
||||
return [t[0] for t in output_tensors], inference_time_ms
|
||||
|
||||
finally:
|
||||
self._cuda_context.pop()
|
||||
else:
|
||||
input_tensor = input_data
|
||||
output_tensors = []
|
||||
|
||||
for output in self._output_bindings:
|
||||
output_shape = list(output["shape"])
|
||||
output_shape[0] = batch_size
|
||||
output_tensor = np.zeros(output_shape, dtype=self._get_numpy_dtype(output["dtype"]))
|
||||
output_tensors.append(output_tensor)
|
||||
|
||||
bindings = [int(input_tensor.ctypes.data)] + [int(t.ctypes.data) for t in output_tensors]
|
||||
|
||||
self._context.execute_v2(bindings=bindings)
|
||||
|
||||
inference_time_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
self._update_performance_stats(inference_time_ms, batch_size)
|
||||
|
||||
return output_tensors, inference_time_ms
|
||||
|
||||
def infer_async(self, input_data: np.ndarray) -> Tuple[List[np.ndarray], float]:
|
||||
"""
|
||||
@@ -333,6 +378,14 @@ class TensorRTEngine:
|
||||
|
||||
def _release_resources(self):
|
||||
"""释放资源(Python TensorRT 由 GC 管理,无需 destroy)"""
|
||||
if self._cuda_context:
|
||||
try:
|
||||
self._cuda_context.pop()
|
||||
self._cuda_context.detach()
|
||||
except Exception:
|
||||
pass
|
||||
self._cuda_context = None
|
||||
|
||||
if self._stream:
|
||||
try:
|
||||
self._stream.synchronize()
|
||||
|
||||
36005
logs/main.log
36005
logs/main.log
File diff suppressed because it is too large
Load Diff
34445
logs/main_error.log
34445
logs/main_error.log
File diff suppressed because it is too large
Load Diff
130
test_edge_run.py
Normal file
130
test_edge_run.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
边缘端运行测试脚本
|
||||
添加测试摄像头和ROI配置,验证系统正常运行
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from config.database import get_sqlite_manager
|
||||
from datetime import datetime
|
||||
import random
|
||||
|
||||
|
||||
def setup_test_data():
    """Seed the SQLite config DB with one test camera and two ROI configs.

    Returns:
        (camera_id, roi_configs): the camera id and the list of ROI dicts
        that were written, for use by the caller.
    """
    import json

    db = get_sqlite_manager()

    print("=" * 60)
    print("边缘端运行测试 - 数据准备")
    print("=" * 60)

    camera_id = "test_camera_01"
    rtsp_url = "rtsp://admin:admin@172.16.8.35/cam/realmonitor?channel=6&subtype=1"

    print(f"\n1. 添加摄像头配置")
    print(f" camera_id: {camera_id}")
    print(f" rtsp_url: {rtsp_url}")

    result = db.save_camera_config(
        camera_id=camera_id,
        rtsp_url=rtsp_url,
        camera_name="测试摄像头-车间入口",
        location="车间入口通道",
        enabled=True,
        status=True
    )
    print(f" 结果: {'成功' if result else '失败'}")

    print(f"\n2. 添加ROI配置(随机划分区域)")

    roi_configs = [
        {
            "roi_id": f"{camera_id}_roi_01",
            "name": "离岗检测区域",
            "roi_type": "polygon",
            "coordinates": [[100, 50], [300, 50], [300, 200], [100, 200]],
            "algorithm_type": "leave_post",
            "target_class": "person",
            "confirm_on_duty_sec": 10,
            "confirm_leave_sec": 30,
            "cooldown_sec": 60,
            "working_hours": [{"start": "08:00", "end": "18:00"}],
        },
        {
            "roi_id": f"{camera_id}_roi_02",
            "name": "入侵检测区域",
            "roi_type": "polygon",
            "coordinates": [[350, 50], [550, 50], [550, 200], [350, 200]],
            "algorithm_type": "intrusion",
            "target_class": "person",
            "alert_threshold": 3,
            "alert_cooldown": 60,
            "confirm_on_duty_sec": 10,
            "confirm_leave_sec": 10,
            "cooldown_sec": 60,
            "working_hours": None,
        },
    ]

    for roi in roi_configs:
        print(f"\n ROI: {roi['name']}")
        print(f" - roi_id: {roi['roi_id']}")
        print(f" - algorithm_type: {roi['algorithm_type']}")
        print(f" - coordinates: {roi['coordinates']}")

        result = db.save_roi_config(
            roi_id=roi["roi_id"],
            camera_id=camera_id,
            roi_type=roi["roi_type"],
            coordinates=roi["coordinates"],
            algorithm_type=roi["algorithm_type"],
            target_class=roi["target_class"],
            confirm_on_duty_sec=roi["confirm_on_duty_sec"],
            confirm_leave_sec=roi["confirm_leave_sec"],
            cooldown_sec=roi["cooldown_sec"],
            # BUG FIX: was str(...), which stores a Python repr with single
            # quotes ("[{'start': ...}]"); the consumer json.loads() that
            # string, fails, and silently disables working-hours checks.
            # json.dumps(None) -> "null", which decodes back to None as before.
            working_hours=json.dumps(roi["working_hours"]),
        )
        print(f" 结果: {'成功' if result else '失败'}")

    print("\n" + "=" * 60)
    print("测试数据准备完成")
    print("=" * 60)

    return camera_id, roi_configs
|
||||
|
||||
|
||||
def verify_data():
    """Print the camera and ROI configs currently stored in the DB.

    Returns:
        True when at least one camera AND one ROI config exist.
    """
    manager = get_sqlite_manager()

    print("\n" + "=" * 60)
    print("验证数据库中的配置")
    print("=" * 60)

    camera_rows = manager.get_all_camera_configs()
    print(f"\n摄像头数量: {len(camera_rows)}")
    for row in camera_rows:
        print(f" - {row['camera_id']}: {row['camera_name']} ({row['rtsp_url'][:50]}...)")

    roi_rows = manager.get_all_roi_configs()
    print(f"\nROI数量: {len(roi_rows)}")
    for row in roi_rows:
        print(f" - {row['roi_id']}: {row['name']} ({row['algorithm_type']})")

    return len(camera_rows) > 0 and len(roi_rows) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: seed the test data, then verify it round-trips.
    banner = "#" * 60
    print("\n" + banner)
    print("# 边缘端运行测试 - 数据准备")
    print(banner)

    setup_test_data()
    verify_data()

    print("\n" + banner)
    print("# 测试数据准备完成,请运行 main.py 进行推理测试")
    print(banner)
|
||||
69
test_inference.py
Normal file
69
test_inference.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
边缘端运行测试脚本 - 推理测试
|
||||
运行 main.py 并测试 30 秒
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
|
||||
def run_test():
    """Launch main.py as a subprocess, echo its output, stop after 30 s.

    Returns:
        The captured (stripped) output lines.
    """
    sep = "=" * 60
    print(sep)
    print("边缘端运行测试 - 推理测试")
    print(sep)
    print(f"测试时长: 30 秒")
    print(f"测试时间: {time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(sep)

    # Put the target conda env first on PATH so main.py resolves its DLLs.
    env = os.environ.copy()
    env['PATH'] = r"C:\Users\16337\miniconda3\envs\yolo;" + env.get('PATH', '')

    proc = subprocess.Popen(
        [sys.executable, "main.py"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same stream
        text=True,
        bufsize=1,                 # line-buffered for live echo
        cwd=os.path.dirname(os.path.abspath(__file__)),
        env=env,
    )

    captured = []
    started = time.time()

    try:
        while True:
            line = proc.stdout.readline()
            # Child exited and pipe drained -> done.
            if not line and proc.poll() is not None:
                break

            if line:
                text_line = line.strip()
                captured.append(text_line)
                print(text_line)

            if time.time() - started >= 30:
                print(f"\n[INFO] 测试达到 30 秒,停止进程...")
                proc.terminate()
                try:
                    proc.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    proc.kill()  # escalate if terminate was ignored
                break

    except KeyboardInterrupt:
        print("\n[INFO] 用户中断测试")
        proc.terminate()

    return captured
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the 30-second inference smoke test.
    run_test()
    footer = "=" * 60
    print("\n" + footer)
    print("测试完成")
    print(footer)
|
||||
Reference in New Issue
Block a user