feat(project): move edge_inference_service contents to root and update paths
- Moved all project files and directories (config, core, models, etc.) from edge_inference_service/ to the repository root ai_edge/ - Updated model path in config/settings.py to reflect new structure - Revised usage paths in __init__.py documentation
This commit is contained in:
340
utils/common.py
Normal file
340
utils/common.py
Normal file
@@ -0,0 +1,340 @@
|
||||
"""
|
||||
公共工具函数模块
|
||||
提供项目中常用的工具函数
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_unique_id(prefix: str = "") -> str:
|
||||
"""生成唯一标识符"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
unique_str = f"{timestamp}_{uuid.uuid4().hex[:8]}"
|
||||
return f"{prefix}_{unique_str}" if prefix else unique_str
|
||||
|
||||
|
||||
def calculate_md5(data: Union[str, bytes]) -> str:
|
||||
"""计算MD5哈希值"""
|
||||
if isinstance(data, str):
|
||||
data = data.encode('utf-8')
|
||||
return hashlib.md5(data).hexdigest()
|
||||
|
||||
|
||||
def get_current_timestamp(format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""获取当前时间戳"""
|
||||
return datetime.now().strftime(format_str)
|
||||
|
||||
|
||||
def get_current_timestamp_ms() -> int:
|
||||
"""获取当前时间戳(毫秒)"""
|
||||
return int(time.time() * 1000)
|
||||
|
||||
|
||||
def parse_rtsp_url(rtsp_url: str) -> Dict[str, str]:
|
||||
"""解析RTSP URL获取各组成部分"""
|
||||
result = {
|
||||
"protocol": "",
|
||||
"host": "",
|
||||
"port": "",
|
||||
"path": "",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
try:
|
||||
match = re.match(r"(rtsp://)([^:]+):?(\d+)?([^?]*)\??(.*)", rtsp_url)
|
||||
if match:
|
||||
result["protocol"] = "rtsp"
|
||||
result["host"] = match.group(2)
|
||||
result["port"] = match.group(3) or "554"
|
||||
result["path"] = match.group(4)
|
||||
|
||||
params_str = match.group(5)
|
||||
if params_str:
|
||||
for param in params_str.split("&"):
|
||||
if "=" in param:
|
||||
k, v = param.split("=", 1)
|
||||
result["params"][k] = v
|
||||
except Exception as e:
|
||||
logging.error(f"RTSP URL解析失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def retry_operation(func, max_retries: int = 3, delay: float = 1.0,
|
||||
backoff: float = 2.0, exceptions: Tuple = (Exception,)):
|
||||
"""
|
||||
带重试的函数执行装饰器
|
||||
|
||||
Args:
|
||||
func: 要执行的函数
|
||||
max_retries: 最大重试次数
|
||||
delay: 初始延迟时间(秒)
|
||||
backoff: 延迟时间递增倍数
|
||||
exceptions: 需要重试的异常类型
|
||||
|
||||
Returns:
|
||||
函数执行结果
|
||||
"""
|
||||
def wrapper(*args, **kwargs):
|
||||
current_delay = delay
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except exceptions as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries:
|
||||
logging.warning(f"操作失败,{current_delay:.1f}秒后重试 ({attempt + 1}/{max_retries}): {e}")
|
||||
time.sleep(current_delay)
|
||||
current_delay *= backoff
|
||||
else:
|
||||
logging.error(f"操作最终失败,已重试{max_retries}次: {e}")
|
||||
raise
|
||||
|
||||
raise last_exception
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ExponentialBackoff:
|
||||
"""指数退避管理器"""
|
||||
|
||||
def __init__(self, base_delay: float = 1.0, max_delay: float = 60.0,
|
||||
max_attempts: int = 5):
|
||||
self.base_delay = base_delay
|
||||
self.max_delay = max_delay
|
||||
self.max_attempts = max_attempts
|
||||
self.current_attempt = 0
|
||||
self.current_delay = base_delay
|
||||
|
||||
def reset(self):
|
||||
"""重置计数器"""
|
||||
self.current_attempt = 0
|
||||
self.current_delay = self.base_delay
|
||||
|
||||
def get_delay(self) -> float:
|
||||
"""获取当前延迟时间"""
|
||||
delay = self.base_delay * (2 ** self.current_attempt)
|
||||
return min(delay, self.max_delay)
|
||||
|
||||
def next_attempt(self) -> bool:
|
||||
"""
|
||||
执行下一次尝试
|
||||
|
||||
Returns:
|
||||
是否还可以继续尝试
|
||||
"""
|
||||
if self.current_attempt >= self.max_attempts:
|
||||
return False
|
||||
|
||||
self.current_attempt += 1
|
||||
return True
|
||||
|
||||
def sleep(self):
|
||||
"""休眠当前延迟时间"""
|
||||
time.sleep(self.get_delay())
|
||||
|
||||
|
||||
def ensure_directory_exists(path: str) -> bool:
|
||||
"""确保目录存在"""
|
||||
try:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"创建目录失败 {path}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def list_files(directory: str, extension: Optional[str] = None) -> List[str]:
|
||||
"""列出目录中的文件"""
|
||||
files = []
|
||||
try:
|
||||
for filename in os.listdir(directory):
|
||||
filepath = os.path.join(directory, filename)
|
||||
if os.path.isfile(filepath):
|
||||
if extension is None or filename.endswith(extension):
|
||||
files.append(filepath)
|
||||
except Exception as e:
|
||||
logging.error(f"列出文件失败 {directory}: {e}")
|
||||
return files
|
||||
|
||||
|
||||
def format_file_size(size_bytes: int) -> str:
|
||||
"""格式化文件大小"""
|
||||
for unit in ['B', 'KB', 'MB', 'GB']:
|
||||
if size_bytes < 1024:
|
||||
return f"{size_bytes:.2f} {unit}"
|
||||
size_bytes /= 1024
|
||||
return f"{size_bytes:.2f} TB"
|
||||
|
||||
|
||||
def draw_detection_results(
|
||||
image: np.ndarray,
|
||||
detections: List[Dict[str, Any]],
|
||||
class_names: Dict[int, str],
|
||||
colors: Optional[Dict[int, Tuple]] = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
在图像上绘制检测结果
|
||||
|
||||
Args:
|
||||
image: 原始图像 (BGR格式)
|
||||
detections: 检测结果列表
|
||||
class_names: 类别名称映射
|
||||
colors: 类别颜色映射
|
||||
|
||||
Returns:
|
||||
绘制后的图像
|
||||
"""
|
||||
result_image = image.copy()
|
||||
|
||||
if colors is None:
|
||||
colors = {
|
||||
0: (0, 255, 0), # 人员 - 绿色
|
||||
1: (0, 0, 255), # 车辆 - 红色
|
||||
2: (255, 0, 0), # 其他 - 蓝色
|
||||
}
|
||||
|
||||
for det in detections:
|
||||
bbox = det.get("bbox", [])
|
||||
conf = det.get("confidence", 0)
|
||||
class_id = det.get("class_id", 0)
|
||||
|
||||
if len(bbox) != 4:
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = map(int, bbox)
|
||||
color = colors.get(class_id, (0, 255, 0))
|
||||
|
||||
cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)
|
||||
|
||||
label = f"{class_names.get(class_id, 'unknown')}: {conf:.2f}"
|
||||
cv2.putText(result_image, label, (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
|
||||
|
||||
return result_image
|
||||
|
||||
|
||||
def image_to_base64(image: np.ndarray, format: str = ".jpg") -> str:
|
||||
"""
|
||||
将图像转换为Base64编码
|
||||
|
||||
Args:
|
||||
image: OpenCV图像 (BGR格式)
|
||||
format: 图像格式 (.jpg, .png)
|
||||
|
||||
Returns:
|
||||
Base64编码字符串
|
||||
"""
|
||||
try:
|
||||
if format == ".jpg":
|
||||
encode_param = [cv2.IMWRITE_JPEG_QUALITY, 85]
|
||||
_, buffer = cv2.imencode(".jpg", image, encode_param)
|
||||
else:
|
||||
_, buffer = cv2.imencode(".png", image)
|
||||
|
||||
import base64
|
||||
return base64.b64encode(buffer).decode('utf-8')
|
||||
except Exception as e:
|
||||
logging.error(f"图像Base64编码失败: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
def base64_to_image(base64_str: str) -> Optional[np.ndarray]:
|
||||
"""
|
||||
将Base64编码转换为图像
|
||||
|
||||
Args:
|
||||
base64_str: Base64编码字符串
|
||||
|
||||
Returns:
|
||||
OpenCV图像 (BGR格式)
|
||||
"""
|
||||
try:
|
||||
import base64
|
||||
buffer = base64.b64decode(base64_str)
|
||||
nparr = np.frombuffer(buffer, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
return image
|
||||
except Exception as e:
|
||||
logging.error(f"Base64解码为图像失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_gpu_memory_info(device_id: int = 0) -> Dict[str, float]:
|
||||
"""
|
||||
获取GPU显存信息
|
||||
|
||||
Args:
|
||||
device_id: GPU设备ID
|
||||
|
||||
Returns:
|
||||
显存信息字典 (总显存、已用显存、可用显存)
|
||||
"""
|
||||
try:
|
||||
import pynvml
|
||||
pynvml.nvmlInit()
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
||||
|
||||
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
total = mem_info.total / (1024 ** 2)
|
||||
used = mem_info.used / (1024 ** 2)
|
||||
free = mem_info.free / (1024 ** 2)
|
||||
|
||||
pynvml.nvmlShutdown()
|
||||
|
||||
return {
|
||||
"total_mb": total,
|
||||
"used_mb": used,
|
||||
"free_mb": free,
|
||||
"used_percent": (used / total) * 100 if total > 0 else 0
|
||||
}
|
||||
except Exception as e:
|
||||
logging.warning(f"获取GPU显存信息失败: {e}")
|
||||
return {"total_mb": 0, "used_mb": 0, "free_mb": 0, "used_percent": 0}
|
||||
|
||||
|
||||
class PerformanceTimer:
|
||||
"""性能计时器"""
|
||||
|
||||
def __init__(self):
|
||||
self.start_times = {}
|
||||
self.elapsed_times = {}
|
||||
|
||||
def start(self, name: str = "default"):
|
||||
"""开始计时"""
|
||||
self.start_times[name] = time.perf_counter()
|
||||
|
||||
def stop(self, name: str = "default") -> float:
|
||||
"""停止计时并返回耗时"""
|
||||
if name in self.start_times:
|
||||
self.elapsed_times[name] = time.perf_counter() - self.start_times[name]
|
||||
del self.start_times[name]
|
||||
return self.elapsed_times[name]
|
||||
return 0.0
|
||||
|
||||
def get_elapsed(self, name: str = "default") -> float:
|
||||
"""获取已记录的耗时"""
|
||||
return self.elapsed_times.get(name, 0.0)
|
||||
|
||||
def reset(self):
|
||||
"""重置所有计时"""
|
||||
self.start_times.clear()
|
||||
self.elapsed_times.clear()
|
||||
|
||||
def get_average(self, name: str, count: int) -> float:
|
||||
"""获取平均耗时"""
|
||||
if name in self.elapsed_times and count > 0:
|
||||
return self.elapsed_times[name] / count
|
||||
return 0.0
|
||||
Reference in New Issue
Block a user