"""Configuration for Qwen3.5-9B benchmark."""
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional, List
|
|
from dotenv import load_dotenv
|
|
|
|
load_dotenv()
|
|
|
|
@dataclass
|
|
class ModelConfig:
|
|
"""Model configuration."""
|
|
model_id: str = "Qwen/Qwen3.5-9B"
|
|
local_path: str = "./models/Qwen3.5-9B"
|
|
cache_dir: str = "./cache"
|
|
device: str = "cuda"
|
|
torch_dtype: str = "float16" # float16, bfloat16, int8, int4
|
|
max_length: int = 8192
|
|
trust_remote_code: bool = True
|
|
|
|
@dataclass
|
|
class BenchmarkConfig:
|
|
"""Benchmark configuration."""
|
|
# Test sequences of different lengths
|
|
input_lengths: List[int] = field(default_factory=lambda: [128, 512, 1024, 2048])
|
|
output_lengths: List[int] = field(default_factory=lambda: [128, 256, 512, 1024])
|
|
|
|
# Concurrency levels
|
|
concurrency_levels: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16])
|
|
|
|
# Number of warmup runs
|
|
warmup_runs: int = 3
|
|
|
|
# Number of benchmark runs
|
|
benchmark_runs: int = 10
|
|
|
|
# Batch sizes for throughput testing
|
|
batch_sizes: List[int] = field(default_factory=lambda: [1, 2, 4, 8, 16])
|
|
|
|
# Results output
|
|
results_dir: str = "./results"
|
|
|
|
@dataclass
|
|
class GPUConfig:
|
|
"""GPU monitoring configuration."""
|
|
monitor_interval: float = 0.1 # seconds
|
|
log_memory: bool = True
|
|
log_utilization: bool = True
|
|
|
|
# Global configs
|
|
model_config = ModelConfig()
|
|
benchmark_config = BenchmarkConfig()
|
|
gpu_config = GPUConfig() |