GPU测试
This commit is contained in:
136
benchmark/utils.py
Normal file
136
benchmark/utils.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
工具函数模块
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def setup_logging(
|
||||
level: int = logging.INFO,
|
||||
log_file: Optional[str] = None,
|
||||
format_str: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
) -> logging.Logger:
|
||||
"""设置日志配置"""
|
||||
logger = logging.getLogger("benchmark")
|
||||
if logger.handlers:
|
||||
return logger
|
||||
|
||||
logger.setLevel(level)
|
||||
formatter = logging.Formatter(format_str)
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(level)
|
||||
console_handler.setFormatter(formatter)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
if log_file:
|
||||
file_handler = logging.FileHandler(log_file, encoding="utf-8")
|
||||
file_handler.setLevel(level)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def get_file_hash(file_path: str, algorithm: str = "md5") -> str:
|
||||
"""计算文件哈希值"""
|
||||
hash_func = hashlib.new(algorithm)
|
||||
with open(file_path, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
hash_func.update(chunk)
|
||||
return hash_func.hexdigest()
|
||||
|
||||
|
||||
def ensure_dir(path: str) -> str:
|
||||
"""确保目录存在"""
|
||||
Path(path).mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def get_timestamp() -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def format_duration(seconds: float) -> str:
|
||||
"""格式化时间长度"""
|
||||
if seconds < 60:
|
||||
return f"{seconds:.1f}s"
|
||||
elif seconds < 3600:
|
||||
minutes = int(seconds // 60)
|
||||
secs = seconds % 60
|
||||
return f"{minutes}m {secs:.1f}s"
|
||||
else:
|
||||
hours = int(seconds // 3600)
|
||||
minutes = int((seconds % 3600) // 60)
|
||||
return f"{hours}h {minutes}m"
|
||||
|
||||
|
||||
def calculate_statistics(data: List[float]) -> dict:
|
||||
"""计算统计值"""
|
||||
if not data:
|
||||
return {"avg": 0.0, "min": 0.0, "max": 0.0, "p95": 0.0, "p99": 0.0, "std": 0.0}
|
||||
|
||||
arr = np.array(data)
|
||||
return {
|
||||
"avg": float(np.mean(arr)),
|
||||
"min": float(np.min(arr)),
|
||||
"max": float(np.max(arr)),
|
||||
"p95": float(np.percentile(arr, 95)),
|
||||
"p99": float(np.percentile(arr, 99)),
|
||||
"std": float(np.std(arr)),
|
||||
}
|
||||
|
||||
|
||||
class Timer:
|
||||
"""计时器上下文管理器"""
|
||||
def __init__(self, name: str = ""):
|
||||
self.name = name
|
||||
self.elapsed = 0.0
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.elapsed = time.perf_counter() - self.start_time
|
||||
|
||||
@property
|
||||
def elapsed_ms(self) -> float:
|
||||
return self.elapsed * 1000
|
||||
|
||||
|
||||
class RateCounter:
|
||||
"""速率计数器"""
|
||||
def __init__(self, window_size: int = 100):
|
||||
self.window_size = window_size
|
||||
self.timestamps: List[float] = []
|
||||
self.counts: List[int] = []
|
||||
|
||||
def tick(self, count: int = 1):
|
||||
now = time.time()
|
||||
self.timestamps.append(now)
|
||||
self.counts.append(count)
|
||||
while len(self.timestamps) > self.window_size:
|
||||
self.timestamps.pop(0)
|
||||
self.counts.pop(0)
|
||||
|
||||
def get_rate(self) -> float:
|
||||
if len(self.timestamps) < 2:
|
||||
return 0.0
|
||||
duration = self.timestamps[-1] - self.timestamps[0]
|
||||
if duration <= 0:
|
||||
return 0.0
|
||||
return sum(self.counts) / duration
|
||||
|
||||
def reset(self):
|
||||
self.timestamps.clear()
|
||||
self.counts.clear()
|
||||
Reference in New Issue
Block a user