- Moved all project files and directories (config, core, models, etc.) from edge_inference_service/ to the repository root ai_edge/ - Updated model path in config/settings.py to reflect new structure - Revised usage paths in __init__.py documentation
341 lines
9.5 KiB
Python
341 lines
9.5 KiB
Python
"""
|
|
公共工具函数模块
|
|
提供项目中常用的工具函数
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import time
|
|
import hashlib
|
|
import logging
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
|
|
def generate_unique_id(prefix: str = "") -> str:
    """
    Build a unique identifier.

    The identifier is ``<YYYYmmddHHMMSS>_<8 random hex chars>``. When *prefix*
    is truthy it is prepended as ``<prefix>_``.
    """
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    token = uuid.uuid4().hex[:8]
    body = f"{stamp}_{token}"
    if not prefix:
        return body
    return f"{prefix}_{body}"
|
|
|
|
|
|
def calculate_md5(data: Union[str, bytes]) -> str:
    """Return the hex MD5 digest of *data*; ``str`` input is UTF-8 encoded first."""
    payload = data.encode('utf-8') if isinstance(data, str) else data
    hasher = hashlib.md5()
    hasher.update(payload)
    return hasher.hexdigest()
|
|
|
|
|
|
def get_current_timestamp(format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
    """Return the current local time rendered with the strftime pattern *format_str*."""
    now = datetime.now()
    return now.strftime(format_str)
|
|
|
|
|
|
def get_current_timestamp_ms() -> int:
    """Return the current Unix time in whole milliseconds (fraction truncated)."""
    seconds = time.time()
    return int(seconds * 1000)
|
|
|
|
|
|
def parse_rtsp_url(rtsp_url: str) -> Dict[str, Any]:
    """
    Split an RTSP URL into its components.

    Example input: ``rtsp://user:pass@192.168.1.10:554/Streaming/Channels/101?transport=tcp``

    Args:
        rtsp_url: The URL to parse. The ``rtsp`` scheme is matched
            case-insensitively.

    Returns:
        A dict with the following keys:

        - ``protocol``: ``"rtsp"``.
        - ``host``: the host name, lower-cased by ``urlsplit``. IPv6 literals
          are returned without brackets.
        - ``port``: the port as a string, ``"554"`` when omitted.
        - ``path``: the path component.
        - ``params``: a dict of raw (not percent-decoded) query parameters.
          Parameters without ``=`` are skipped.

        Credentials are never returned. If *rtsp_url* is not a parseable
        ``rtsp://`` URL, every field stays empty and ``params`` is ``{}``.
    """
    from urllib.parse import urlsplit

    result = {
        "protocol": "",
        "host": "",
        "port": "",
        "path": "",
        "params": {}
    }

    if not isinstance(rtsp_url, str):
        logging.error(f"RTSP URL解析失败: {rtsp_url!r}")
        return result

    try:
        parts = urlsplit(rtsp_url.strip())
        # urlsplit lower-cases the scheme, so "RTSP://" is accepted too.
        if parts.scheme != "rtsp" or not parts.hostname:
            return result
        # .port raises ValueError for non-numeric or out-of-range ports.
        port = parts.port
    except ValueError as e:
        logging.error(f"RTSP URL解析失败: {e}")
        return result

    result["protocol"] = "rtsp"
    result["host"] = parts.hostname
    result["port"] = str(port) if port is not None else "554"
    result["path"] = parts.path

    for param in parts.query.split("&"):
        key, sep, value = param.partition("=")
        if sep:
            result["params"][key] = value

    return result
|
|
|
|
|
|
def retry_operation(func=None, max_retries: int = 3, delay: float = 1.0,
                    backoff: float = 2.0, exceptions: Tuple = (Exception,)):
    """
    Wrap a function so that it is retried with exponential back-off on failure.

    Supported forms:
        * ``retry_operation(func, ...)`` returns a retrying wrapper of ``func``.
        * ``@retry_operation`` decorates using the default settings.
        * ``@retry_operation(max_retries=5, ...)`` decorates with custom settings.

    Args:
        func: Function to wrap. When omitted (None), a decorator is returned.
        max_retries: Number of retries after the first call. Negative values
            are treated as 0, so ``func`` is always called at least once.
        delay: Delay in seconds before the first retry.
        backoff: Factor the delay is multiplied by after each retry.
        exceptions: Exception type(s) that trigger a retry. Any other
            exception propagates immediately.

    Returns:
        The wrapper, or a decorator when ``func`` is None. The wrapper returns
        ``func``'s result, or re-raises the last caught exception once all
        retries are exhausted.
    """
    import functools

    # Clamp so range() below always yields at least one attempt. With a
    # negative value the old loop never ran and ended in ``raise None``.
    retries = max(0, max_retries)

    def decorate(target):
        @functools.wraps(target)
        def wrapper(*args, **kwargs):
            current_delay = delay
            # The last iteration either returns or re-raises, so the loop
            # never falls through.
            for attempt in range(retries + 1):
                try:
                    return target(*args, **kwargs)
                except exceptions as e:
                    if attempt >= retries:
                        logging.error(f"操作最终失败,已重试{retries}次: {e}")
                        raise
                    logging.warning(f"操作失败,{current_delay:.1f}秒后重试 ({attempt + 1}/{retries}): {e}")
                    time.sleep(current_delay)
                    current_delay *= backoff

        return wrapper

    if func is None:
        return decorate
    return decorate(func)
|
|
|
|
|
|
class ExponentialBackoff:
    """
    Exponential back-off manager.

    The delay for the current attempt counter ``n`` is ``base_delay * 2 ** n``,
    capped at ``max_delay``. ``next_attempt()`` increments the counter until
    ``max_attempts`` is reached.

    Note: ``next_attempt()`` increments the counter before ``sleep()`` reads it.
    A loop that calls ``next_attempt()`` and then ``sleep()`` therefore waits
    ``2 * base_delay`` on its first pause.
    """

    # Exponent cap used by get_delay(). 2 ** 1000 already exceeds any
    # realistic max_delay / base_delay ratio. Multiplying a float by a much
    # larger int raises OverflowError (int -> float conversion), which a
    # long-running loop with a large max_attempts would hit after ~1024 attempts.
    _MAX_EXPONENT = 1000

    def __init__(self, base_delay: float = 1.0, max_delay: float = 60.0,
                 max_attempts: int = 5):
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.max_attempts = max_attempts
        self.current_attempt = 0
        # Delay for the current attempt. Kept in sync with get_delay() by
        # reset() and next_attempt().
        self.current_delay = self.get_delay()

    def reset(self):
        """Reset the attempt counter and current_delay to their initial state."""
        self.current_attempt = 0
        self.current_delay = self.get_delay()

    def get_delay(self) -> float:
        """Return the delay in seconds for the current attempt, capped at max_delay."""
        exponent = min(self.current_attempt, self._MAX_EXPONENT)
        return min(self.base_delay * (2 ** exponent), self.max_delay)

    def next_attempt(self) -> bool:
        """
        Advance to the next attempt.

        Returns:
            True if another attempt is allowed (the counter was incremented).
            False once max_attempts attempts have been made.
        """
        if self.current_attempt >= self.max_attempts:
            return False

        self.current_attempt += 1
        self.current_delay = self.get_delay()
        return True

    def sleep(self):
        """Block for get_delay() seconds."""
        time.sleep(self.get_delay())
|
|
|
|
|
|
def ensure_directory_exists(path: str) -> bool:
    """
    Create *path*, including missing parents, if it does not already exist.

    Returns True on success (including when the directory already existed).
    Returns False when creation fails; the error is logged.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except Exception as exc:
        logging.error(f"创建目录失败 {path}: {exc}")
        return False
    return True
|
|
|
|
|
|
def list_files(directory: str, extension: Optional[str] = None) -> List[str]:
    """
    List the regular files directly inside *directory* (non-recursive).

    Args:
        directory: Directory to scan.
        extension: If given, keep only files whose name ends with this suffix.
            The comparison is case-sensitive.

    Returns:
        Paths joined with *directory*, in ``os.listdir`` order. Returns an
        empty list if listing fails; the error is logged.
    """
    found: List[str] = []
    try:
        candidates = ((name, os.path.join(directory, name)) for name in os.listdir(directory))
        found = [
            full_path
            for name, full_path in candidates
            if os.path.isfile(full_path)
            and (extension is None or name.endswith(extension))
        ]
    except Exception as exc:
        logging.error(f"列出文件失败 {directory}: {exc}")
    return found
|
|
|
|
|
|
def format_file_size(size_bytes: int) -> str:
    """
    Render a byte count with two decimals in the largest unit below 1024.

    Units run from B up to TB; anything at or above 1024 GB is shown in TB.
    """
    value = size_bytes
    for unit in ("B", "KB", "MB", "GB", "TB"):
        # TB is the last unit and absorbs everything that remains.
        if value < 1024 or unit == "TB":
            return f"{value:.2f} {unit}"
        value /= 1024
|
|
|
|
|
|
def draw_detection_results(
    image: np.ndarray,
    detections: List[Dict[str, Any]],
    class_names: Dict[int, str],
    colors: Optional[Dict[int, Tuple]] = None
) -> np.ndarray:
    """
    Draw detection boxes and "<class>: <confidence>" labels on a copy of an image.

    Args:
        image: Source image (BGR). It is not modified.
        detections: Detection dicts read via the keys "bbox" (x1, y1, x2, y2),
            "confidence" (default 0) and "class_id" (default 0). Entries whose
            bbox is missing or does not have exactly four values are skipped.
        class_names: Mapping of class id to display name. Missing ids are
            labelled 'unknown'.
        colors: Mapping of class id to BGR color. Ids not present are drawn in
            green (0, 255, 0).

    Returns:
        A new image with boxes and labels drawn.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2
    result_image = image.copy()

    if colors is None:
        colors = {
            0: (0, 255, 0),  # person - green
            1: (0, 0, 255),  # vehicle - red
            2: (255, 0, 0),  # other - blue
        }

    for det in detections:
        bbox = det.get("bbox")
        # Explicit None check: bbox may be a numpy array, whose truth value is ambiguous.
        if bbox is None or len(bbox) != 4:
            continue

        conf = det.get("confidence")
        if conf is None:
            conf = 0.0
        class_id = det.get("class_id", 0)

        x1, y1, x2, y2 = map(int, bbox)
        color = colors.get(class_id, (0, 255, 0))
        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, thickness)

        label = f"{class_names.get(class_id, 'unknown')}: {conf:.2f}"
        (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # Draw the label just above the box. If that would push the text above
        # the image's top edge (where it would be clipped), draw it just inside
        # the box instead.
        text_y = y1 - 10
        if text_y - text_height < 0:
            text_y = y1 + text_height + 10
        cv2.putText(result_image, label, (x1, text_y),
                    font, font_scale, color, thickness)

    return result_image
|
|
|
|
|
|
def image_to_base64(image: np.ndarray, format: str = ".jpg", quality: int = 85) -> str:
    """
    Encode an image to a Base64 string.

    Args:
        image: OpenCV image (BGR).
        format: ".jpg"/".jpeg" selects JPEG. The comparison is
            case-insensitive and the leading dot is optional. Any other value
            selects PNG.
        quality: JPEG quality passed to cv2.IMWRITE_JPEG_QUALITY. Ignored for PNG.

    Returns:
        The Base64 (ASCII) string, or "" if encoding fails (the error is logged).
    """
    import base64

    try:
        # "format" shadows the builtin but is part of the public signature.
        use_jpeg = str(format).lower().lstrip(".") in ("jpg", "jpeg")
        if use_jpeg:
            encode_param = [cv2.IMWRITE_JPEG_QUALITY, int(quality)]
            ok, buffer = cv2.imencode(".jpg", image, encode_param)
        else:
            ok, buffer = cv2.imencode(".png", image)

        # imencode can report failure through its flag instead of raising.
        if not ok:
            logging.error("图像Base64编码失败: cv2.imencode returned False")
            return ""

        return base64.b64encode(buffer).decode('utf-8')
    except Exception as e:
        logging.error(f"图像Base64编码失败: {e}")
        return ""
|
|
|
|
|
|
def base64_to_image(base64_str: str) -> Optional[np.ndarray]:
    """
    Decode a Base64-encoded image into an OpenCV image.

    Args:
        base64_str: Base64 data as str or bytes. A str may also be a data URL
            such as ``data:image/jpeg;base64,<data>``; the header is stripped.

    Returns:
        The decoded image (BGR), or None if the data cannot be decoded (the
        failure is logged).
    """
    import base64

    try:
        payload = base64_str
        if isinstance(payload, str):
            payload = payload.strip()
            # Without this, b64decode would silently decode the letters of the
            # "data:image/...;base64," header and corrupt the image bytes.
            if payload.startswith("data:") and "," in payload:
                payload = payload.partition(",")[2]

        buffer = base64.b64decode(payload)
        nparr = np.frombuffer(buffer, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if image is None:
            # imdecode signals undecodable data by returning None, not by raising.
            logging.error("Base64解码为图像失败: cv2.imdecode returned None")
        return image
    except Exception as e:
        logging.error(f"Base64解码为图像失败: {e}")
        return None
|
|
|
|
|
|
def get_gpu_memory_info(device_id: int = 0) -> Dict[str, float]:
    """
    Query GPU memory usage through NVML (pynvml).

    Args:
        device_id: NVML device index.

    Returns:
        A dict with "total_mb", "used_mb" and "free_mb" (MiB) and
        "used_percent". If pynvml is unavailable or any NVML call fails, all
        values are 0 and a warning is logged.
    """
    def _empty() -> Dict[str, float]:
        return {"total_mb": 0, "used_mb": 0, "free_mb": 0, "used_percent": 0}

    try:
        import pynvml
        pynvml.nvmlInit()
    except Exception as e:
        logging.warning(f"获取GPU显存信息失败: {e}")
        return _empty()

    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
        mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    except Exception as e:
        logging.warning(f"获取GPU显存信息失败: {e}")
        return _empty()
    finally:
        # Always balance nvmlInit(), even when the device query above failed.
        try:
            pynvml.nvmlShutdown()
        except Exception as e:
            logging.debug(f"nvmlShutdown failed: {e}")

    bytes_per_mb = 1024 ** 2
    total = mem_info.total / bytes_per_mb
    used = mem_info.used / bytes_per_mb
    free = mem_info.free / bytes_per_mb

    return {
        "total_mb": total,
        "used_mb": used,
        "free_mb": free,
        "used_percent": (used / total) * 100 if total > 0 else 0
    }
|
|
|
|
|
|
class PerformanceTimer:
    """
    Named wall-clock timers based on ``time.perf_counter()``.

    ``start(name)`` records a start time and ``stop(name)`` stores and returns
    the elapsed seconds. Only the most recent measurement per name is kept.
    """

    def __init__(self):
        # name -> perf_counter() reading for timers that are currently running
        self.start_times = {}
        # name -> seconds measured by the most recent stop() for that name
        self.elapsed_times = {}

    def start(self, name: str = "default"):
        """Start (or restart) the timer called *name*."""
        self.start_times[name] = time.perf_counter()

    def stop(self, name: str = "default") -> float:
        """Stop timer *name* and return its elapsed seconds; 0.0 if it is not running."""
        now = time.perf_counter()
        if name not in self.start_times:
            return 0.0
        elapsed = now - self.start_times.pop(name)
        self.elapsed_times[name] = elapsed
        return elapsed

    def get_elapsed(self, name: str = "default") -> float:
        """Return the last measured duration for *name*, or 0.0 if none."""
        try:
            return self.elapsed_times[name]
        except KeyError:
            return 0.0

    def reset(self):
        """Discard all running timers and recorded durations."""
        for store in (self.start_times, self.elapsed_times):
            store.clear()

    def get_average(self, name: str, count: int) -> float:
        """
        Return the last measured duration for *name* divided by *count*.

        Returns 0.0 when nothing has been recorded for *name* or *count* is not
        positive.
        """
        total = self.elapsed_times.get(name)
        if total is None or not count > 0:
            return 0.0
        return total / count
|