"""
Common utility functions.

Provides helpers that are used throughout the project.
"""

import base64
import functools
import os
import re
import time
import hashlib
import logging
import uuid
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import cv2
import numpy as np


def generate_unique_id(prefix: str = "") -> str:
    """Generate a unique identifier from a timestamp and a random suffix."""
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    unique_str = f"{timestamp}_{uuid.uuid4().hex[:8]}"
    return f"{prefix}_{unique_str}" if prefix else unique_str


def calculate_md5(data: Union[str, bytes]) -> str:
    """Compute the MD5 hex digest of a string (UTF-8 encoded) or bytes."""
    if isinstance(data, str):
        data = data.encode('utf-8')
    return hashlib.md5(data).hexdigest()


def get_current_timestamp(format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
    """Return the current local time as a formatted string."""
    return datetime.now().strftime(format_str)


def get_current_timestamp_ms() -> int:
    """Return the current Unix timestamp in milliseconds."""
    return int(time.time() * 1000)


def parse_rtsp_url(rtsp_url: str) -> Dict[str, Any]:
    """
    Parse an RTSP URL into its components.

    Returns:
        A dict with "protocol", "host", "port", "path" (strings) and
        "params" (a dict of query parameters). All fields stay empty if
        the URL does not match.
    """
    result: Dict[str, Any] = {
        "protocol": "",
        "host": "",
        "port": "",
        "path": "",
        "params": {}
    }

    try:
        # Optional "user:password@" credentials are skipped. The host stops
        # at ":", "/" or "?" so a URL without an explicit port still parses
        # (the previous pattern let the host swallow the path in that case).
        match = re.match(
            r"rtsp://(?:[^/@]*@)?([^:/?]+)(?::(\d+))?([^?]*)\??(.*)",
            rtsp_url
        )
        if match:
            result["protocol"] = "rtsp"
            result["host"] = match.group(1)
            result["port"] = match.group(2) or "554"  # default RTSP port
            result["path"] = match.group(3)

            params_str = match.group(4)
            if params_str:
                for param in params_str.split("&"):
                    if "=" in param:
                        k, v = param.split("=", 1)
                        result["params"][k] = v
    except Exception as e:
        logging.error(f"Failed to parse RTSP URL: {e}")

    return result


def retry_operation(
    func: Callable[..., Any],
    max_retries: int = 3,
    delay: float = 1.0,
    backoff: float = 2.0,
    exceptions: Tuple[Type[BaseException], ...] = (Exception,)
) -> Callable[..., Any]:
    """
    Wrap a function so that it is retried on failure.

    Args:
        func: The function to execute.
        max_retries: Maximum number of retries after the first attempt.
        delay: Initial delay between attempts, in seconds.
        backoff: Multiplier applied to the delay after each failed attempt.
        exceptions: Exception types that trigger a retry.

    Returns:
        A wrapper that calls ``func`` and returns its result, re-raising
        the last exception once all retries are exhausted.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        current_delay = delay

        for attempt in range(max_retries + 1):
            try:
                return func(*args, **kwargs)
            except exceptions as e:
                if attempt < max_retries:
                    logging.warning(
                        f"Operation failed, retrying in {current_delay:.1f}s "
                        f"({attempt + 1}/{max_retries}): {e}"
                    )
                    time.sleep(current_delay)
                    current_delay *= backoff
                else:
                    logging.error(
                        f"Operation failed after {max_retries} retries: {e}"
                    )
                    raise

    return wrapper


class ExponentialBackoff:
    """Exponential backoff manager."""

    def __init__(self, base_delay: float = 1.0, max_delay: float = 60.0, max_attempts: int = 5):
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.max_attempts = max_attempts
        self.current_attempt = 0
        self.current_delay = base_delay

    def reset(self):
        """Reset the attempt counter and the delay."""
        self.current_attempt = 0
        self.current_delay = self.base_delay

    def get_delay(self) -> float:
        """Return the delay for the current attempt, capped at max_delay."""
        delay = self.base_delay * (2 ** self.current_attempt)
        return min(delay, self.max_delay)

    def next_attempt(self) -> bool:
        """
        Advance to the next attempt.

        Returns:
            True if another attempt is allowed, False once max_attempts
            has been reached.
        """
        if self.current_attempt >= self.max_attempts:
            return False
        self.current_attempt += 1
        return True

    def sleep(self):
        """Sleep for the current delay."""
        time.sleep(self.get_delay())


def ensure_directory_exists(path: str) -> bool:
    """Create the directory (and any parents) if it does not exist."""
    try:
        os.makedirs(path, exist_ok=True)
        return True
    except Exception as e:
        logging.error(f"Failed to create directory {path}: {e}")
        return False


def list_files(directory: str, extension: Optional[str] = None) -> List[str]:
    """List the files directly inside a directory, optionally filtered by extension."""
    files = []
    try:
        for filename in os.listdir(directory):
            filepath = os.path.join(directory, filename)
            if os.path.isfile(filepath):
                if extension is None or filename.endswith(extension):
                    files.append(filepath)
    except Exception as e:
        logging.error(f"Failed to list files in {directory}: {e}")
    return files


def format_file_size(size_bytes: float) -> str:
    """Format a size in bytes as a human-readable string (B to TB)."""
    size = float(size_bytes)
    for unit in ['B', 'KB', 'MB', 'GB']:
        if size < 1024:
            return f"{size:.2f} {unit}"
        size /= 1024
    return f"{size:.2f} TB"


def draw_detection_results(
    image: np.ndarray,
    detections: List[Dict[str, Any]],
    class_names: Dict[int, str],
    colors: Optional[Dict[int, Tuple]] = None
) -> np.ndarray:
    """
    Draw detection results on an image.

    Args:
        image: Source image (BGR).
        detections: List of detections, each with "bbox" (x1, y1, x2, y2),
            "confidence" and "class_id".
        class_names: Mapping from class ID to class name.
        colors: Mapping from class ID to BGR color.

    Returns:
        A copy of the image with boxes and labels drawn on it.
    """
    result_image = image.copy()

    if colors is None:
        colors = {
            0: (0, 255, 0),    # person - green
            1: (0, 0, 255),    # vehicle - red
            2: (255, 0, 0),    # other - blue
        }

    for det in detections:
        bbox = det.get("bbox", [])
        conf = det.get("confidence", 0)
        class_id = det.get("class_id", 0)

        # Skip detections without a valid 4-value bounding box.
        if len(bbox) != 4:
            continue

        x1, y1, x2, y2 = map(int, bbox)
        color = colors.get(class_id, (0, 255, 0))

        cv2.rectangle(result_image, (x1, y1), (x2, y2), color, 2)

        label = f"{class_names.get(class_id, 'unknown')}: {conf:.2f}"
        # Keep the label inside the image when the box touches the top edge.
        label_y = max(y1 - 10, 10)
        cv2.putText(result_image, label, (x1, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    return result_image


def image_to_base64(image: np.ndarray, format: str = ".jpg") -> str:
    """
    Encode an image as a Base64 string.

    Args:
        image: OpenCV image (BGR).
        format: Image format (".jpg" or ".png"). Any value other than
            ".jpg"/".jpeg" is encoded as PNG.

    Returns:
        The Base64-encoded string, or "" on failure.
    """
    try:
        if format.lower() in (".jpg", ".jpeg"):
            encode_param = [cv2.IMWRITE_JPEG_QUALITY, 85]
            success, buffer = cv2.imencode(".jpg", image, encode_param)
        else:
            success, buffer = cv2.imencode(".png", image)

        # cv2.imencode reports failure through its first return value.
        if not success:
            logging.error("Failed to encode image as Base64: imencode returned False")
            return ""

        return base64.b64encode(buffer).decode('utf-8')
    except Exception as e:
        logging.error(f"Failed to encode image as Base64: {e}")
        return ""


def base64_to_image(base64_str: str) -> Optional[np.ndarray]:
    """
    Decode a Base64 string into an image.

    Args:
        base64_str: Base64-encoded image data.

    Returns:
        OpenCV image (BGR), or None if decoding fails.
    """
    try:
        buffer = base64.b64decode(base64_str)
        nparr = np.frombuffer(buffer, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        return image
    except Exception as e:
        logging.error(f"Failed to decode Base64 into an image: {e}")
        return None


def get_gpu_memory_info(device_id: int = 0) -> Dict[str, float]:
    """
    Query GPU memory usage.

    Args:
        device_id: GPU device index.

    Returns:
        A dict with total, used and free memory in MB plus the used
        percentage. All values are 0 if the query fails.
    """
    try:
        import pynvml  # optional dependency, imported lazily

        pynvml.nvmlInit()
        try:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        finally:
            # Release NVML even if the query above raises.
            pynvml.nvmlShutdown()

        total = mem_info.total / (1024 ** 2)
        used = mem_info.used / (1024 ** 2)
        free = mem_info.free / (1024 ** 2)

        return {
            "total_mb": total,
            "used_mb": used,
            "free_mb": free,
            "used_percent": (used / total) * 100 if total > 0 else 0
        }
    except Exception as e:
        logging.warning(f"Failed to query GPU memory info: {e}")
        return {"total_mb": 0, "used_mb": 0, "free_mb": 0, "used_percent": 0}


class PerformanceTimer:
    """Named performance timers."""

    def __init__(self):
        self.start_times = {}
        self.elapsed_times = {}

    def start(self, name: str = "default"):
        """Start the timer with the given name."""
        self.start_times[name] = time.perf_counter()

    def stop(self, name: str = "default") -> float:
        """Stop the timer and return the elapsed time in seconds (0.0 if it was never started)."""
        if name in self.start_times:
            self.elapsed_times[name] = time.perf_counter() - self.start_times[name]
            del self.start_times[name]
            return self.elapsed_times[name]
        return 0.0

    def get_elapsed(self, name: str = "default") -> float:
        """Return the last recorded elapsed time for the given name."""
        return self.elapsed_times.get(name, 0.0)

    def reset(self):
        """Clear all running and recorded timers."""
        self.start_times.clear()
        self.elapsed_times.clear()

    def get_average(self, name: str, count: int) -> float:
        """Return the last recorded elapsed time for ``name`` divided by ``count``."""
        if name in self.elapsed_times and count > 0:
            return self.elapsed_times[name] / count
        return 0.0