Files
Test_AI/calibration_gen.py
16337 942244bd88 Add YOLO11 TensorRT quantization benchmark scripts
- Engine build scripts (FP16/INT8)
- Benchmark validation scripts
- Result parsing and analysis tools
- COCO dataset configuration
2026-01-29 13:59:42 +08:00

207 lines
7.4 KiB
Python

import os
import glob
import numpy as np
import cv2
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import logging
# Configure the root logger at import time: INFO level, with a timestamp and
# level name in front of every message.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by the classes and functions in this script.
logger = logging.getLogger(__name__)
class DataLoader:
    """Serve letterboxed images from a directory as INT8 calibration batches.

    Each batch returned by :meth:`next_batch` is a flat, C-contiguous float32
    array holding ``batch_size`` images laid out as ``(N, C, H, W)``, with
    channels in RGB order and pixel values scaled to ``[0, 1]``.

    Args:
        data_dir: Directory searched recursively for images.
        batch_size: Number of images per batch; must be >= 1.
        input_shape: Per-image shape ``(C, H, W)``; ``(H, W)`` is the
            letterbox target size.
        extensions: Filename suffixes to collect, matched the way ``glob``
            matches them (case-sensitive on POSIX). Defaults to ``('.jpg',)``.

    Raises:
        ValueError: If ``batch_size`` < 1 or fewer than ``batch_size`` images
            are found.
    """

    def __init__(self, data_dir, batch_size, input_shape, extensions=('.jpg',)):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if isinstance(extensions, str):
            # Accept a single suffix instead of iterating its characters.
            extensions = (extensions,)
        self.batch_size = batch_size
        self.input_shape = input_shape

        # With recursive=True, `**` also matches zero directories, so one
        # pattern per extension covers files directly inside data_dir as well.
        # glob.escape keeps characters like '[' in the directory name literal.
        # Sorting makes the calibration order (and thus the cache) reproducible;
        # glob's own order depends on the filesystem.
        root = glob.escape(data_dir)
        found = set()
        for ext in extensions:
            found.update(glob.glob(os.path.join(root, '**', f'*{ext}'), recursive=True))
        self.img_paths = sorted(found)
        logger.info(f"Found {len(self.img_paths)} images for calibration in {data_dir}")

        self.batch_idx = 0
        # Index into img_paths of the next file to try; it advances past
        # unreadable files, so it can run ahead of batch_idx * batch_size.
        self._cursor = 0
        # Upper bound only: skipped unreadable images can end calibration sooner.
        self.max_batches = len(self.img_paths) // self.batch_size
        if self.max_batches == 0:
            raise ValueError(f"Not enough images for calibration! Found {len(self.img_paths)}, need at least {batch_size}")
        logger.info(f"Total batches for calibration: {self.max_batches}")
        # Host buffer for one batch, reused (overwritten) by every next_batch call.
        self.calibration_data = np.zeros((self.batch_size, *self.input_shape), dtype=np.float32)

    def reset(self):
        """Rewind to the first image so batches are produced from the start again."""
        self.batch_idx = 0
        self._cursor = 0

    def next_batch(self):
        """Return the next batch as a flat float32 array, or None when exhausted.

        Unreadable images are logged and skipped; the slot is filled from the
        following file, so every returned batch holds ``batch_size`` images
        decoded for this batch. Returns None once ``max_batches`` batches have
        been served, or when the remaining readable images cannot fill a
        whole batch.

        The returned array is a view of an internal buffer that the next call
        overwrites; copy it if it must outlive that call.
        """
        if self.batch_idx >= self.max_batches:
            return None
        target_hw = (self.input_shape[1], self.input_shape[2])
        filled = 0
        while filled < self.batch_size:
            if self._cursor >= len(self.img_paths):
                logger.warning("Ran out of readable images before filling a full batch; ending calibration")
                return None
            path = self.img_paths[self._cursor]
            self._cursor += 1
            img = cv2.imread(path)
            if img is None:
                logger.warning(f"Failed to read image: {path}")
                continue
            img = self.preprocess(img, target_hw)
            # BGR -> RGB (reverse last axis), HWC -> CHW, then scale to [0, 1].
            img = img[:, :, ::-1].transpose(2, 0, 1)
            self.calibration_data[filled] = np.ascontiguousarray(img, dtype=np.float32) / 255.0
            filled += 1
        self.batch_idx += 1
        return np.ascontiguousarray(self.calibration_data.ravel())

    def preprocess(self, img, new_shape=(640, 640), color=(114, 114, 114)):
        """Letterbox *img* to *new_shape* ``(h, w)`` without ever upscaling it.

        The image is resized with its aspect ratio preserved (scale capped at
        1.0) and then padded with *color*. The padding is split between the two
        sides; when the total is odd, the extra pixel goes to the bottom/right.

        Args:
            img: Image array of shape ``(h, w, channels)``.
            new_shape: Target ``(h, w)``, or a single int for a square target.
            color: Border fill value.

        Returns:
            The letterboxed image with spatial size ``new_shape``.
        """
        height, width = img.shape[:2]
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        scale = min(new_shape[0] / height, new_shape[1] / width, 1.0)
        new_unpad = int(round(width * scale)), int(round(height * scale))  # (w, h) for cv2
        pad_w = (new_shape[1] - new_unpad[0]) / 2
        pad_h = (new_shape[0] - new_unpad[1]) / 2
        if (width, height) != new_unpad:
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The +/-0.1 nudges route a half-pixel remainder to bottom/right.
        top, bottom = int(round(pad_h - 0.1)), int(round(pad_h + 0.1))
        left, right = int(round(pad_w - 0.1)), int(round(pad_w + 0.1))
        return cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
class YOLOEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """TensorRT entropy calibrator (v2) fed by a ``DataLoader``.

    Each host batch is copied into one device buffer, and that buffer's
    pointer is handed to TensorRT. The calibration table is read from and
    written to ``cache_file``.

    Args:
        data_loader: Source of calibration batches; rewound on construction.
            Must expose ``batch_size``, ``calibration_data``, ``reset()`` and
            ``next_batch()``.
        cache_file: Path of the calibration cache.
    """

    def __init__(self, data_loader, cache_file='yolo_int8.cache'):
        # Initialise the TensorRT base class (required when subclassing it in Python).
        super().__init__()
        self.data_loader = data_loader
        self.cache_file = cache_file
        # One device buffer sized for a full host batch, reused by every get_batch call.
        self.d_input = cuda.mem_alloc(data_loader.calibration_data.nbytes)
        self.data_loader.reset()

    def get_batch_size(self):
        """Return the number of images in each calibration batch."""
        return self.data_loader.batch_size

    def get_batch(self, names):
        """Upload the next batch and return its device pointer, or None when done.

        ``names`` (the network input names) is not used; a single input is
        assumed. Any exception is logged with its traceback and treated as end
        of data rather than raised into TensorRT's callback.
        """
        try:
            batch = self.data_loader.next_batch()
            if batch is None:
                return None
            cuda.memcpy_htod(self.d_input, batch)
        except Exception as e:
            # logger.exception keeps the traceback; a bare message made early
            # termination of calibration very hard to diagnose.
            logger.exception(f"Error in get_batch: {e}")
            return None
        return [int(self.d_input)]

    def read_calibration_cache(self):
        """Return the cached calibration table bytes, or None if no cache file exists."""
        if os.path.exists(self.cache_file):
            logger.info(f"Reading calibration cache from {self.cache_file}")
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Persist the calibration table TensorRT produced to ``cache_file``."""
        logger.info(f"Writing calibration cache to {self.cache_file}")
        with open(self.cache_file, "wb") as f:
            f.write(cache)
def _parse_onnx(parser, onnx_path):
    """Parse the ONNX file into the parser's network; log any errors and return success."""
    logger.info(f"Parsing ONNX model: {onnx_path}")
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            logger.error("Failed to parse ONNX model")
            for idx in range(parser.num_errors):
                logger.error(parser.get_error(idx))
            return False
    logger.info("ONNX model parsed successfully")
    return True


def _add_calibration_profile(builder, config, input_name, batch_size, input_shape):
    """Add and return an optimization profile for a dynamic-shape input.

    TensorRT calibrates dynamic-shape networks at the opt dimensions of the
    calibration profile, so opt is ``(batch_size, *input_shape)`` to match
    the batches the calibrator uploads.
    """
    profile = builder.create_optimization_profile()
    profile.set_shape(
        input_name,
        (1, *input_shape),  # min
        (batch_size, *input_shape),  # opt: the calibration batch shape
        (max(batch_size, 8), *input_shape),  # max: never below 8
    )
    config.add_optimization_profile(profile)
    return profile


def generate_calibration_cache(onnx_path, cache_file, data_dir, batch_size=8, input_shape=(3, 640, 640)):
    """Run TensorRT INT8 calibration on an ONNX model to produce a calibration cache.

    An engine is built only to drive calibration and is then discarded; the
    artifact is the cache written to ``cache_file``. If ``cache_file``
    already exists, the calibrator hands it to TensorRT instead of
    recalibrating.

    Args:
        onnx_path: Path to the ONNX model.
        cache_file: Where the calibration cache is read from / written to.
        data_dir: Directory of calibration images.
        batch_size: Images per calibration batch.
        input_shape: Per-image ``(C, H, W)`` fed to the network.

    Returns:
        True if the build succeeded and ``cache_file`` exists afterwards;
        False on any failure (missing/unparsable model, no INT8 support,
        input-shape mismatch, too few images, build error, no cache written).
    """
    logger.info("=" * 60)
    logger.info("Starting Calibration Cache Generation")
    logger.info("=" * 60)
    if not os.path.exists(onnx_path):
        logger.error(f"ONNX model not found: {onnx_path}")
        return False

    trt_logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, trt_logger)
    if not _parse_onnx(parser, onnx_path):
        return False
    if network.num_inputs < 1:
        logger.error("ONNX model has no inputs")
        return False

    # 4 GiB of builder workspace.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)

    # Without INT8 no calibration runs and no cache can be produced, so building
    # an FP16 engine instead would only waste time and misreport success.
    if not builder.platform_has_fast_int8:
        logger.error("INT8 not supported on this platform; cannot generate a calibration cache")
        return False
    logger.info("INT8 calibration enabled")
    config.set_flag(trt.BuilderFlag.INT8)

    input_tensor = network.get_input(0)
    dims = tuple(input_tensor.shape)
    logger.info(f"Input name: {input_tensor.name}")
    logger.info(f"Input shape: {input_tensor.shape}")
    profile = None
    if any(dim < 0 for dim in dims):
        profile = _add_calibration_profile(builder, config, input_tensor.name, batch_size, input_shape)
    elif dims[:1] != (batch_size,):
        # A static input consumes exactly dims[0] images per calibration step;
        # a mismatch would make TensorRT read the wrong amount of device memory.
        logger.error(f"Static input shape {dims} does not match calibration batch_size={batch_size}")
        return False

    logger.info(f"Loading calibration data from: {data_dir}")
    try:
        calib_loader = DataLoader(data_dir, batch_size=batch_size, input_shape=input_shape)
    except ValueError as e:
        logger.error(f"Cannot load calibration data: {e}")
        return False
    calibrator = YOLOEntropyCalibrator(calib_loader, cache_file=cache_file)
    config.int8_calibrator = calibrator
    if profile is not None:
        config.set_calibration_profile(profile)

    logger.info("Building engine for calibration (this will generate cache)...")
    try:
        # build_engine was removed in TensorRT 10; build_serialized_network exists since 8.0.
        if hasattr(builder, "build_serialized_network"):
            built = builder.build_serialized_network(network, config)
        else:
            built = builder.build_engine(network, config)
    except Exception as e:
        logger.exception(f"Error during calibration: {e}")
        return False
    if built is None:
        logger.error("Failed to build engine")
        return False
    logger.info("Calibration engine built successfully")
    del built

    if not os.path.exists(cache_file):
        logger.error(f"No calibration cache found at: {cache_file}")
        return False
    logger.info(f"Calibration cache written to: {cache_file}")
    return True
def build_arg_parser():
    """Return the command-line parser for this script.

    Defaults: model ``yolo11n.onnx``, cache ``yolo11n_int8.cache``, images in
    ``data``, batch size 8, per-image input shape ``(3, 640, 640)``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate a TensorRT INT8 calibration cache for a YOLO ONNX model."
    )
    parser.add_argument("--onnx", default="yolo11n.onnx", help="Path to the ONNX model.")
    parser.add_argument("--cache", default="yolo11n_int8.cache", help="Output calibration cache path.")
    parser.add_argument("--data", default="data", help="Directory of calibration images.")
    parser.add_argument("--batch-size", type=int, default=8, help="Images per calibration batch.")
    parser.add_argument(
        "--input-shape",
        type=int,
        nargs=3,
        default=(3, 640, 640),
        metavar=("C", "H", "W"),
        help="Per-image network input shape.",
    )
    return parser


def main(argv=None):
    """Generate the calibration cache and exit with status 1 on failure.

    Args:
        argv: Command-line arguments excluding the program name. ``None``
            uses the parser defaults without reading ``sys.argv``, so a
            programmatic ``main()`` call ignores the caller's argv.
    """
    import sys

    args = build_arg_parser().parse_args([] if argv is None else argv)
    onnx_path = args.onnx
    cache_file = args.cache
    data_dir = args.data
    batch_size = args.batch_size
    input_shape = tuple(args.input_shape)

    logger.info(f"ONNX model: {onnx_path}")
    logger.info(f"Output cache: {cache_file}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Batch size: {batch_size}")
    logger.info(f"Input shape: {input_shape}")

    if not generate_calibration_cache(onnx_path, cache_file, data_dir, batch_size, input_shape):
        logger.error("Calibration cache generation failed!")
        # sys.exit rather than exit(): the latter is injected by `site` for
        # interactive use and is not guaranteed to exist in programs.
        sys.exit(1)

    logger.info("=" * 60)
    logger.info("Calibration cache generated successfully!")
    logger.info(f"Cache file: {os.path.abspath(cache_file)}")
    logger.info("=" * 60)


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])