# Files: qwen-test/config.py
# 52 lines, 1.4 KiB, Python

"""Configuration for Qwen3.5-9B benchmark."""
import os
from dataclasses import dataclass, field
from typing import Optional, List
from dotenv import load_dotenv
# Populate the process environment from a `.env` file (python-dotenv) as a
# side effect of importing this module.
# NOTE(review): no code visible in this file reads os.environ, so this call
# (and the `os` import) does not influence the config defaults below —
# confirm whether environment variables are consumed elsewhere.
load_dotenv()
@dataclass
class ModelConfig:
    """Describe which model to benchmark and how it should be loaded.

    Attributes:
        model_id: Model identifier; the default has the ``org/name`` form
            used by Hugging Face Hub repositories.
        local_path: Path to a local copy of the model (presumably used
            instead of downloading ``model_id`` — confirm in the loader).
        cache_dir: Cache directory (presumably for downloaded model files).
        device: Device string the model is presumably placed on.
        torch_dtype: Precision, given by name. Intended values are
            "float16", "bfloat16", "int8" and "int4".
        max_length: Maximum sequence length (presumably in tokens).
        trust_remote_code: Whether code shipped with the model repository
            may be executed.
    """

    # Where the weights come from.
    model_id: str = field(default="Qwen/Qwen3.5-9B")
    local_path: str = field(default="./models/Qwen3.5-9B")
    cache_dir: str = field(default="./cache")

    # How the weights are loaded.
    device: str = field(default="cuda")
    # NOTE(review): "int8"/"int4" are not floating-point torch dtypes; they
    # presumably select a quantized load path in the consumer — confirm.
    torch_dtype: str = field(default="float16")
    max_length: int = field(default=8192)
    # NOTE(review): if forwarded to transformers' from_pretrained, True lets
    # code from the model repository run locally; keep it only for trusted repos.
    trust_remote_code: bool = field(default=True)
@dataclass
class BenchmarkConfig:
    """Parameter sweep and repetition counts for the benchmark.

    Every list-valued field is built per instance (each default is a fresh
    copy), so mutating one instance's list never affects another instance.

    Attributes:
        input_lengths: Input sequence lengths to test.
        output_lengths: Output sequence lengths to test.
        concurrency_levels: Concurrency levels to test.
        warmup_runs: Number of warmup runs.
        benchmark_runs: Number of benchmark runs.
        batch_sizes: Batch sizes for throughput testing.
        results_dir: Directory that results are written to.
    """

    # Sequence-length sweep; ``<literal>.copy`` hands out a new list per instance.
    input_lengths: List[int] = field(default_factory=[128, 512, 1024, 2048].copy)
    output_lengths: List[int] = field(default_factory=[128, 256, 512, 1024].copy)

    concurrency_levels: List[int] = field(default_factory=[1, 2, 4, 8, 16].copy)

    # Repetition counts.
    warmup_runs: int = field(default=3)
    benchmark_runs: int = field(default=10)

    batch_sizes: List[int] = field(default_factory=[1, 2, 4, 8, 16].copy)

    results_dir: str = field(default="./results")
@dataclass
class GPUConfig:
    """Options for GPU monitoring.

    Attributes:
        monitor_interval: Sampling interval, in seconds.
        log_memory: Whether memory figures are logged.
        log_utilization: Whether utilization figures are logged.
    """

    monitor_interval: float = field(default=0.1)  # seconds
    log_memory: bool = field(default=True)
    log_utilization: bool = field(default=True)
# Global configs
# Shared default instances, constructed once when this module is imported.
# Every importer receives these same objects, so in-place changes to a field
# (e.g. appending to benchmark_config.batch_sizes) are seen by all of them.
model_config, benchmark_config, gpu_config = (
    ModelConfig(),
    BenchmarkConfig(),
    GPUConfig(),
)