Files
security-ai-edge/utils/common.py
16337 b0ddb6ee1a feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from
  edge_inference_service/ to the repository root ai_edge/
- Updated model path in config/settings.py to reflect new structure
- Revised usage paths in __init__.py documentation
2026-01-29 18:43:19 +08:00

341 lines
9.5 KiB
Python

"""
公共工具函数模块
提供项目中常用的工具函数
"""
import os
import re
import time
import hashlib
import logging
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
def generate_unique_id(prefix: str = "") -> str:
    """Build an identifier of the form ``[prefix_]YYYYmmddHHMMSS_xxxxxxxx``.

    The tail is the first 8 hex characters of a random UUID4, so two ids
    generated within the same second still differ with high probability.

    Args:
        prefix: optional label prepended with an underscore; omitted when empty.
    Returns:
        The identifier string.
    """
    stamp = datetime.now().strftime("%Y%m%d%H%M%S")
    token = uuid.uuid4().hex[:8]
    core = f"{stamp}_{token}"
    if not prefix:
        return core
    return f"{prefix}_{core}"
def calculate_md5(data: Union[str, bytes]) -> str:
    """Return the hexadecimal MD5 digest of *data*.

    Text input is UTF-8 encoded before hashing. MD5 is suitable for
    fingerprinting/deduplication only, not for security purposes.

    Args:
        data: text or raw bytes to hash.
    Returns:
        32-character lowercase hex digest.
    """
    payload = data.encode('utf-8') if isinstance(data, str) else data
    hasher = hashlib.md5()
    hasher.update(payload)
    return hasher.hexdigest()
def get_current_timestamp(format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
    """Return the current local (naive) time formatted with *format_str*.

    Args:
        format_str: ``strftime`` pattern; defaults to ``YYYY-mm-dd HH:MM:SS``.
    Returns:
        The formatted time string.
    """
    now = datetime.now()
    stamp = now.strftime(format_str)
    return stamp
def get_current_timestamp_ms() -> int:
    """Return the current Unix time in whole milliseconds (fraction truncated)."""
    seconds = time.time()
    millis = seconds * 1000
    return int(millis)
def parse_rtsp_url(rtsp_url: str) -> Dict[str, Any]:
    """
    Split an RTSP URL into its components.

    Uses ``urllib.parse.urlsplit`` so that URLs without an explicit port,
    with ``user:pass@`` credentials, or with bracketed IPv6 hosts are parsed
    correctly. The credentials are accepted but not included in the result.

    Args:
        rtsp_url: URL such as ``rtsp://user:pw@192.168.1.10:554/live?ch=1``.
    Returns:
        Dict with keys:
            "protocol": "rtsp" on success, "" otherwise
            "host":     host name / IP (no brackets, lower-cased by urlsplit)
            "port":     port as a string, "554" (RTSP default) when absent
            "path":     path part, e.g. "/live" ("" when absent)
            "params":   query parameters as raw (not percent-decoded) strings;
                        items without "=" are ignored, later duplicates win
        For non-RTSP or unparsable input, the empty-valued dict is returned;
        parse errors are logged.
    """
    result: Dict[str, Any] = {
        "protocol": "",
        "host": "",
        "port": "",
        "path": "",
        "params": {},
    }
    try:
        from urllib.parse import urlsplit

        parts = urlsplit(rtsp_url)
        if parts.scheme != "rtsp" or not parts.hostname:
            return result
        # Read the port before filling anything in: an invalid port raises
        # ValueError and we want to return the untouched empty result then.
        port = parts.port
        result["protocol"] = "rtsp"
        result["host"] = parts.hostname
        result["port"] = str(port) if port is not None else "554"
        result["path"] = parts.path
        for param in parts.query.split("&"):
            if "=" in param:
                key, value = param.split("=", 1)
                result["params"][key] = value
    except Exception as e:
        logging.error(f"RTSP URL解析失败: {e}")
    return result
def retry_operation(func=None, max_retries: int = 3, delay: float = 1.0,
                    backoff: float = 2.0, exceptions: Tuple = (Exception,)):
    """
    Wrap *func* so that failing calls are retried with exponential backoff.

    Usable either directly::

        safe_fetch = retry_operation(fetch, max_retries=5)

    or, when *func* is omitted, as a decorator factory::

        @retry_operation(max_retries=5, delay=0.5)
        def fetch(): ...

    Args:
        func: callable to wrap; if None, a decorator is returned instead.
        max_retries: retries after the first attempt (negative values are
            treated as 0, i.e. a single attempt).
        delay: seconds to wait before the first retry.
        backoff: multiplier applied to the delay after each retry.
        exceptions: exception types that trigger a retry; any other
            exception propagates immediately.
    Returns:
        The wrapped callable (metadata preserved via functools.wraps), or a
        decorator when *func* is None. When all attempts fail, the last
        caught exception is re-raised with its original traceback.
    """
    import functools

    if func is None:
        return lambda f: retry_operation(f, max_retries=max_retries, delay=delay,
                                         backoff=backoff, exceptions=exceptions)

    # A negative count previously skipped the call entirely and raised None.
    retries = max(0, max_retries)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_delay = delay
        # range(retries + 1) always runs at least once and every iteration
        # either returns or raises, so the loop never falls through.
        for attempt in range(retries + 1):
            try:
                return func(*args, **kwargs)
            except exceptions as e:
                if attempt >= retries:
                    logging.error(f"操作最终失败,已重试{retries}次: {e}")
                    raise
                logging.warning(f"操作失败,{current_delay:.1f}秒后重试 ({attempt + 1}/{retries}): {e}")
                time.sleep(current_delay)
                current_delay *= backoff

    return wrapper
class ExponentialBackoff:
    """
    Exponential backoff helper.

    The delay for attempt ``k`` is ``base_delay * 2**k``, capped at
    ``max_delay``. ``next_attempt()`` advances ``k`` until ``max_attempts``
    attempts have been consumed. The public ``current_delay`` attribute is
    kept in sync with ``get_delay()`` (it previously stayed at its initial
    value forever).
    """

    def __init__(self, base_delay: float = 1.0, max_delay: float = 60.0,
                 max_attempts: int = 5):
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.max_attempts = max_attempts
        self.current_attempt = 0
        self.current_delay = self.get_delay()

    def reset(self):
        """Go back to attempt 0 and the base delay."""
        self.current_attempt = 0
        self.current_delay = self.get_delay()

    def get_delay(self) -> float:
        """Return the delay in seconds for the current attempt."""
        delay = self.base_delay * (2 ** self.current_attempt)
        return min(delay, self.max_delay)

    def next_attempt(self) -> bool:
        """
        Advance to the next attempt.

        Returns:
            True if another attempt is allowed (the counter was advanced and
            ``current_delay`` updated), False once ``max_attempts`` is reached.
        """
        if self.current_attempt >= self.max_attempts:
            return False
        self.current_attempt += 1
        self.current_delay = self.get_delay()
        return True

    def sleep(self):
        """Block for the current attempt's delay."""
        time.sleep(self.get_delay())
def ensure_directory_exists(path: str) -> bool:
    """Create *path* (including parents) if it is missing.

    Args:
        path: directory to create; an existing directory is not an error.
    Returns:
        True when the directory exists afterwards, False if creation failed
        (the failure is logged rather than raised).
    """
    try:
        os.makedirs(path, exist_ok=True)
    except Exception as e:
        logging.error(f"创建目录失败 {path}: {e}")
        return False
    else:
        return True
def list_files(directory: str, extension: Optional[str] = None) -> List[str]:
    """List regular files directly inside *directory* (non-recursive).

    Args:
        directory: directory to scan.
        extension: if given, keep only names ending with this suffix
            (e.g. ".jpg"); the comparison is case-sensitive.
    Returns:
        Joined paths of matching files in ``os.listdir`` order; an empty list
        if the directory cannot be read (the error is logged).
    """
    matched: List[str] = []
    try:
        for entry in os.listdir(directory):
            full_path = os.path.join(directory, entry)
            if not os.path.isfile(full_path):
                continue
            if extension is not None and not entry.endswith(extension):
                continue
            matched.append(full_path)
    except Exception as e:
        logging.error(f"列出文件失败 {directory}: {e}")
    return matched
def format_file_size(size_bytes: int) -> str:
    """Render a byte count with two decimals using 1024-based units.

    Units go B, KB, MB, GB; anything of 1024 GB or more is expressed in TB.

    Args:
        size_bytes: size in bytes.
    Returns:
        Text such as "1.50 KB".
    """
    amount = size_bytes
    for suffix in ('B', 'KB', 'MB', 'GB'):
        if amount < 1024:
            break
        amount /= 1024
    else:
        # Never fell below 1024 in the smaller units: report in terabytes.
        return f"{amount:.2f} TB"
    return f"{amount:.2f} {suffix}"
def draw_detection_results(
    image: np.ndarray,
    detections: List[Dict[str, Any]],
    class_names: Dict[int, str],
    colors: Optional[Dict[int, Tuple]] = None
) -> np.ndarray:
    """
    Draw detection boxes and "<class>: <confidence>" labels on a copy of *image*.

    Args:
        image: source image (BGR); it is not modified.
        detections: list of dicts read via the keys "bbox" ([x1, y1, x2, y2]),
            "confidence" (default 0) and "class_id" (default 0); entries whose
            bbox does not have exactly 4 values are skipped.
        class_names: class_id -> display name ("unknown" when missing).
        colors: class_id -> BGR colour; defaults to green/red/blue for
            ids 0/1/2, and green for any id not in the mapping.
    Returns:
        The annotated copy of the image.
    """
    result_image = image.copy()
    if colors is None:
        colors = {
            0: (0, 255, 0),  # person - green (BGR)
            1: (0, 0, 255),  # vehicle - red (BGR)
            2: (255, 0, 0),  # other - blue (BGR)
        }
    font_scale = 0.5
    thickness = 2
    label_gap = 10  # pixels between the label baseline and the box edge
    for det in detections:
        bbox = det.get("bbox", [])
        conf = det.get("confidence", 0)
        class_id = det.get("class_id", 0)
        if len(bbox) != 4:
            continue
        x1, y1, x2, y2 = map(int, bbox)
        color = colors.get(class_id, (0, 255, 0))
        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, thickness)

        label = f"{class_names.get(class_id, 'unknown')}: {conf:.2f}"
        font = cv2.FONT_HERSHEY_SIMPLEX
        (_, text_height), _ = cv2.getTextSize(label, font, font_scale, thickness)
        # putText's origin is the text's bottom-left corner. Above the box is
        # preferred, but for boxes touching the top of the frame that would put
        # the text above row 0 and clip it, so draw it just inside the box.
        text_y = y1 - label_gap
        if text_y - text_height < 0:
            text_y = y1 + text_height + label_gap
        cv2.putText(result_image, label, (x1, text_y),
                    font, font_scale, color, thickness)
    return result_image
def image_to_base64(image: np.ndarray, format: str = ".jpg", quality: int = 85) -> str:
    """
    Encode an OpenCV image as a Base64 string.

    Args:
        image: OpenCV image (BGR).
        format: target format. Any JPEG spelling (".jpg", "jpg", ".jpeg",
            case-insensitive) encodes as JPEG; everything else encodes as PNG.
        quality: JPEG quality passed as cv2.IMWRITE_JPEG_QUALITY (default 85);
            ignored for PNG.
    Returns:
        Base64 text of the encoded image bytes, or "" when encoding fails
        (the failure is logged).
    """
    import base64

    try:
        # Normalise the requested format so ".jpeg"/"JPG" are no longer
        # silently encoded as PNG.
        ext = str(format or "").strip().lower()
        if not ext.startswith("."):
            ext = "." + ext
        target = ".jpg" if ext in (".jpg", ".jpeg") else ".png"

        if target == ".jpg":
            ok, buffer = cv2.imencode(target, image, [cv2.IMWRITE_JPEG_QUALITY, int(quality)])
        else:
            ok, buffer = cv2.imencode(target, image)
        # imencode reports failure through its first return value; the buffer
        # is not meaningful in that case.
        if not ok:
            logging.error(f"图像Base64编码失败: cv2.imencode 返回失败 ({target})")
            return ""
        return base64.b64encode(buffer).decode('utf-8')
    except Exception as e:
        logging.error(f"图像Base64编码失败: {e}")
        return ""
def base64_to_image(base64_str: str) -> Optional[np.ndarray]:
    """
    Decode a Base64 string into an OpenCV image.

    Args:
        base64_str: Base64 text of an encoded image (e.g. JPEG/PNG bytes).
    Returns:
        The decoded image loaded with cv2.IMREAD_COLOR, whatever
        cv2.imdecode yields for undecodable data, or None if an exception
        occurs (the error is logged).
    """
    import base64

    try:
        raw_bytes = base64.b64decode(base64_str)
        byte_array = np.frombuffer(raw_bytes, np.uint8)
        decoded = cv2.imdecode(byte_array, cv2.IMREAD_COLOR)
    except Exception as e:
        logging.error(f"Base64解码为图像失败: {e}")
        return None
    return decoded
def get_gpu_memory_info(device_id: int = 0) -> Dict[str, float]:
    """
    Query GPU memory usage through NVML (pynvml).

    Args:
        device_id: NVML device index.
    Returns:
        Dict with "total_mb", "used_mb", "free_mb" (in MiB) and
        "used_percent". If pynvml is unavailable or any NVML call fails, a
        warning is logged and every value is 0.
    """
    try:
        import pynvml

        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        finally:
            # Keep nvmlInit/nvmlShutdown balanced even when the device query
            # fails (e.g. an invalid device_id); previously shutdown was skipped.
            pynvml.nvmlShutdown()
        mib = 1024 ** 2
        total = mem_info.total / mib
        used = mem_info.used / mib
        free = mem_info.free / mib
        return {
            "total_mb": total,
            "used_mb": used,
            "free_mb": free,
            "used_percent": (used / total) * 100 if total > 0 else 0
        }
    except Exception as e:
        logging.warning(f"获取GPU显存信息失败: {e}")
        return {"total_mb": 0, "used_mb": 0, "free_mb": 0, "used_percent": 0}
class PerformanceTimer:
    """
    Named stopwatches based on ``time.perf_counter``.

    Each name has at most one running start time. Stopping a timer records
    its elapsed seconds, replacing any earlier measurement under that name.
    """

    def __init__(self):
        # name -> perf_counter() value captured by start()
        self.start_times = {}
        # name -> most recent elapsed seconds recorded by stop()
        self.elapsed_times = {}

    def start(self, name: str = "default"):
        """Start (or restart) the timer called *name*."""
        now = time.perf_counter()
        self.start_times[name] = now

    def stop(self, name: str = "default") -> float:
        """Stop *name* and return its elapsed seconds; 0.0 if it was not running."""
        if name not in self.start_times:
            return 0.0
        began = self.start_times.pop(name)
        duration = time.perf_counter() - began
        self.elapsed_times[name] = duration
        return duration

    def get_elapsed(self, name: str = "default") -> float:
        """Return the last recorded duration for *name*, or 0.0 if none."""
        try:
            return self.elapsed_times[name]
        except KeyError:
            return 0.0

    def reset(self):
        """Forget all running timers and recorded durations."""
        for store in (self.start_times, self.elapsed_times):
            store.clear()

    def get_average(self, name: str, count: int) -> float:
        """Return the last recorded duration of *name* divided by *count*.

        Returns 0.0 when nothing is recorded for *name* or *count* is not positive.
        """
        if name not in self.elapsed_times:
            return 0.0
        return self.elapsed_times[name] / count if count > 0 else 0.0