"""Generate a TensorRT INT8 calibration cache for a YOLO ONNX model.

The script parses an ONNX model, feeds letterboxed JPEG images through a
TensorRT entropy calibrator and lets the builder write the resulting
calibration cache to disk.
"""
import os
import glob
import numpy as np
import cv2
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- imported only for its side effect (CUDA initialisation)
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class DataLoader:
    """Serve fixed-size batches of preprocessed JPEG images for calibration.

    Images are discovered recursively under ``data_dir``; only complete
    batches are served, so up to ``batch_size - 1`` trailing images are
    ignored.

    Args:
        data_dir: Directory searched (recursively) for ``*.jpg`` files.
        batch_size: Number of images per batch; must be positive.
        input_shape: Per-image shape as ``(channels, height, width)``.

    Raises:
        ValueError: If ``batch_size`` is not positive or fewer than
            ``batch_size`` images are found.
    """

    def __init__(self, data_dir, batch_size, input_shape):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.batch_size = batch_size
        self.input_shape = input_shape

        # '**' with recursive=True also matches zero directories, so files
        # directly inside data_dir are included without a second glob.
        # Sorting makes the batch composition reproducible between runs.
        self.img_paths = sorted(
            glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True)
        )
        logger.info(f"Found {len(self.img_paths)} images for calibration in {data_dir}")

        self.batch_idx = 0
        self.max_batches = len(self.img_paths) // self.batch_size
        if self.max_batches == 0:
            raise ValueError(
                f"Not enough images for calibration! Found {len(self.img_paths)}, "
                f"need at least {batch_size}"
            )
        logger.info(f"Total batches for calibration: {self.max_batches}")

        # Host buffer reused for every batch.
        self.calibration_data = np.zeros(
            (self.batch_size, *self.input_shape), dtype=np.float32
        )

    def reset(self):
        """Rewind to the first batch."""
        self.batch_idx = 0

    def next_batch(self):
        """Return the next batch as a flat float32 array, or None when exhausted.

        Each image is letterboxed to ``(input_shape[1], input_shape[2])``,
        its channel order is reversed, it is transposed from HWC to CHW and
        scaled to the range [0, 1].

        NOTE: the returned array shares memory with ``calibration_data``, which
        is overwritten by the next call. If an image cannot be read, its
        slot keeps whatever the buffer already held (zeros on the first
        batch, otherwise the image from the previous batch).
        """
        if self.batch_idx >= self.max_batches:
            return None

        start = self.batch_idx * self.batch_size
        end = start + self.batch_size
        target_hw = (self.input_shape[1], self.input_shape[2])

        for i, path in enumerate(self.img_paths[start:end]):
            img = cv2.imread(path)
            if img is None:
                logger.warning(f"Failed to read image: {path}")
                continue
            img = self.preprocess(img, target_hw)
            # Reverse the channel axis, then HWC -> CHW.
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
            self.calibration_data[i] = img

        self.batch_idx += 1
        return np.ascontiguousarray(self.calibration_data.ravel())

    def preprocess(self, img, new_shape=(640, 640), color=(114, 114, 114)):
        """Letterbox ``img`` to ``new_shape`` without ever upscaling.

        The image is shrunk (aspect ratio preserved) only if it is larger
        than the target, then padded symmetrically with ``color``.

        Args:
            img: Image array whose first two axes are (height, width).
            new_shape: Target ``(height, width)``, or a single int for a
                square target.
            color: Border colour passed to ``cv2.copyMakeBorder``.

        Returns:
            The resized and padded image.
        """
        shape = img.shape[:2]  # (height, width)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        # Scale factor, capped at 1.0 so small images are only padded.
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        r = min(r, 1.0)

        # cv2.resize expects (width, height).
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw = (new_shape[1] - new_unpad[0]) / 2
        dh = (new_shape[0] - new_unpad[1]) / 2

        if shape[::-1] != new_unpad:
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

        # The -/+ 0.1 nudges split an odd padding total into floor/ceil halves.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(
            img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
        )
        return img


class YOLOEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """Entropy calibrator that streams batches from a ``DataLoader``.

    Args:
        data_loader: Object exposing ``batch_size``, ``calibration_data``,
            ``reset()`` and ``next_batch()``.
        cache_file: Path used to read and write the calibration cache.
    """

    def __init__(self, data_loader, cache_file='yolo_int8.cache'):
        super().__init__()
        self.data_loader = data_loader
        self.cache_file = cache_file
        # Device buffer sized to hold exactly one full batch.
        self.d_input = cuda.mem_alloc(data_loader.calibration_data.nbytes)
        self.data_loader.reset()

    def get_batch_size(self):
        """Return the number of images in each calibration batch."""
        return self.data_loader.batch_size

    def get_batch(self, names):
        """Copy the next batch to the device and return its pointer list.

        Returns None when the loader is exhausted or when loading/copying
        fails (the failure is logged with its traceback).
        """
        try:
            batch = self.data_loader.next_batch()
            if batch is None:
                return None
            cuda.memcpy_htod(self.d_input, batch)
            return [int(self.d_input)]
        except Exception:
            logger.exception("Error in get_batch")
            return None

    def read_calibration_cache(self):
        """Return the bytes of an existing cache file, or None if absent."""
        if os.path.exists(self.cache_file):
            logger.info(f"Reading calibration cache from {self.cache_file}")
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Write ``cache`` to the configured cache file."""
        logger.info(f"Writing calibration cache to {self.cache_file}")
        with open(self.cache_file, "wb") as f:
            f.write(cache)


def generate_calibration_cache(onnx_path, cache_file, data_dir, batch_size=8, input_shape=(3, 640, 640)):
    """Build a throw-away INT8 engine so TensorRT writes a calibration cache.

    Args:
        onnx_path: Path to the ONNX model.
        cache_file: Path the calibration cache is written to (an existing
            file at this path is offered to TensorRT as the cache).
        data_dir: Directory containing the calibration JPEG images.
        batch_size: Number of images per calibration batch.
        input_shape: Per-image ``(channels, height, width)``.

    Returns:
        True if the build succeeded and ``cache_file`` exists afterwards,
        False otherwise (missing model, parse failure, no INT8 support,
        build failure, or no cache file produced).

    Raises:
        ValueError: Propagated from ``DataLoader`` when ``batch_size`` is
            invalid or too few images are found.
    """
    logger.info("=" * 60)
    logger.info("Starting Calibration Cache Generation")
    logger.info("=" * 60)

    if not os.path.exists(onnx_path):
        logger.error(f"ONNX model not found: {onnx_path}")
        return False

    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    logger.info(f"Parsing ONNX model: {onnx_path}")
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            logger.error("Failed to parse ONNX model")
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            return False
    logger.info("ONNX model parsed successfully")

    # Without INT8 support no calibration runs and no cache can be produced,
    # so building an FP16 engine here would only report a false success.
    if not builder.platform_has_fast_int8:
        logger.error("INT8 not supported on this platform; cannot generate a calibration cache")
        return False

    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)  # 4 GiB

    logger.info("INT8 calibration enabled")
    config.set_flag(trt.BuilderFlag.INT8)
    logger.info(f"Loading calibration data from: {data_dir}")
    calib_loader = DataLoader(data_dir, batch_size=batch_size, input_shape=input_shape)
    calibrator = YOLOEntropyCalibrator(calib_loader, cache_file=cache_file)
    config.int8_calibrator = calibrator

    input_tensor = network.get_input(0)
    input_name = input_tensor.name
    network_dims = tuple(input_tensor.shape)
    logger.info(f"Input name: {input_name}")
    logger.info(f"Input shape: {input_tensor.shape}")

    if any(dim < 0 for dim in network_dims):
        # Dynamic input: TensorRT calibrates at the profile's OPT shape, so
        # OPT must equal the shape of the batches the calibrator supplies.
        image_dims = tuple(input_shape)
        batch_dims = (batch_size, *image_dims)
        profile = builder.create_optimization_profile()
        profile.set_shape(input_name, (1, *image_dims), batch_dims, batch_dims)
        config.add_optimization_profile(profile)
        if not config.set_calibration_profile(profile):
            logger.error("Failed to set calibration profile")
            return False
    else:
        # Static input: no optimization profile is needed.
        if network_dims and network_dims[0] != batch_size:
            logger.warning(
                f"Network has static batch size {network_dims[0]} but calibration "
                f"batches contain {batch_size} images"
            )

    logger.info("Building engine for calibration (this will generate cache)...")
    try:
        # build_serialized_network replaces the deprecated build_engine.
        serialized_engine = builder.build_serialized_network(network, config)
    except Exception:
        logger.exception("Error during calibration")
        return False

    if serialized_engine is None:
        logger.error("Failed to build engine")
        return False

    logger.info("Calibration engine built successfully")
    del serialized_engine  # only the cache is wanted, not the engine

    if not os.path.exists(cache_file):
        logger.error(f"Engine was built but no calibration cache was written to: {cache_file}")
        return False

    logger.info(f"Calibration cache written to: {cache_file}")
    return True


def main():
    """Generate the calibration cache for the default YOLO11n model."""
    onnx_path = 'yolo11n.onnx'
    cache_file = 'yolo11n_int8.cache'
    data_dir = 'data'
    batch_size = 8
    input_shape = (3, 640, 640)

    logger.info(f"ONNX model: {onnx_path}")
    logger.info(f"Output cache: {cache_file}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Batch size: {batch_size}")
    logger.info(f"Input shape: {input_shape}")

    success = generate_calibration_cache(onnx_path, cache_file, data_dir, batch_size, input_shape)

    if success:
        logger.info("=" * 60)
        logger.info("Calibration cache generated successfully!")
        logger.info(f"Cache file: {os.path.abspath(cache_file)}")
        logger.info("=" * 60)
    else:
        logger.error("Calibration cache generation failed!")
        # SystemExit instead of exit(): the latter is injected by the site
        # module and is not guaranteed to exist.
        raise SystemExit(1)


if __name__ == "__main__":
    main()