# Source file: Test_AI/calibration_gen.py
# (Captured from a repository web view: 207 lines, 7.4 KiB, Python.)

import os
import glob
import numpy as np
import cv2
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import logging
# Configure the root logger at import time (INFO level, "time - level - message"
# format) and create the module-level logger used by everything below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class DataLoader:
    """Serve letterboxed, normalised image batches for INT8 calibration.

    Images are discovered as ``*.jpg`` files anywhere under ``data_dir``.
    Each batch is written into one reusable float32 buffer of shape
    ``(batch_size, *input_shape)`` and handed out as a flat view of it.

    Attributes:
        batch_size: Number of images per batch.
        input_shape: Per-image shape, used as ``(channels, height, width)``.
        img_paths: Sorted list of discovered image paths.
        batch_idx: Number of batches served since the last ``reset()``.
        max_batches: Upper bound on the number of batches that can be served.
        calibration_data: The reusable host buffer holding the current batch.
    """

    def __init__(self, data_dir, batch_size, input_shape):
        """Discover the images and allocate the host batch buffer.

        Args:
            data_dir: Directory searched recursively for ``*.jpg`` files.
            batch_size: Images per batch; must be positive.
            input_shape: ``(channels, height, width)`` of one network input.

        Raises:
            ValueError: If ``batch_size`` is not positive, or fewer than
                ``batch_size`` images were found.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.batch_size = batch_size
        self.input_shape = input_shape
        # With recursive=True, '**' also matches zero directories, so files that
        # sit directly in data_dir are found too. Sorting makes the calibration
        # set independent of filesystem enumeration order.
        pattern = os.path.join(data_dir, '**', '*.jpg')
        self.img_paths = sorted(glob.glob(pattern, recursive=True))
        logger.info(f"Found {len(self.img_paths)} images for calibration in {data_dir}")
        self.batch_idx = 0
        # Index of the next path to try; advances past unreadable files.
        self._cursor = 0
        self.max_batches = len(self.img_paths) // self.batch_size
        if self.max_batches == 0:
            raise ValueError(f"Not enough images for calibration! Found {len(self.img_paths)}, need at least {batch_size}")
        logger.info(f"Total batches for calibration: {self.max_batches}")
        self.calibration_data = np.zeros((self.batch_size, *self.input_shape), dtype=np.float32)

    def reset(self):
        """Rewind to the first image so batches can be served again."""
        self.batch_idx = 0
        self._cursor = 0

    def next_batch(self):
        """Return the next batch as a flat float32 array, or None when done.

        Unreadable images are skipped and replaced by the following readable
        ones, so a returned batch never contains leftover pixels from an
        earlier batch. If the remaining images cannot fill a whole batch,
        None is returned.

        The returned array is a view of ``calibration_data`` and is
        overwritten by the next call.
        """
        if self.batch_idx >= self.max_batches:
            return None

        target_hw = (self.input_shape[1], self.input_shape[2])
        filled = 0
        while filled < self.batch_size and self._cursor < len(self.img_paths):
            path = self.img_paths[self._cursor]
            self._cursor += 1
            img = cv2.imread(path)
            if img is None:
                logger.warning(f"Failed to read image: {path}")
                continue
            img = self.preprocess(img, target_hw)
            # Reverse the channel axis (OpenCV loads BGR) and go HWC -> CHW.
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
            self.calibration_data[filled] = img
            filled += 1

        if filled < self.batch_size:
            logger.warning(
                f"Only {filled} readable images left; cannot fill a batch of {self.batch_size}"
            )
            return None

        self.batch_idx += 1
        return np.ascontiguousarray(self.calibration_data.ravel())

    def preprocess(self, img, new_shape=(640, 640), color=(114, 114, 114)):
        """Letterbox ``img`` to ``new_shape`` without upscaling.

        The image is shrunk (never enlarged) to fit inside ``new_shape`` while
        keeping its aspect ratio, then padded symmetrically with ``color``.

        Args:
            img: HxWxC image array as returned by ``cv2.imread``.
            new_shape: Target ``(height, width)``, or a single int for a square.
            color: Constant border colour used for the padding.

        Returns:
            The resized and padded image.
        """
        shape = img.shape[:2]  # (height, width)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        r = min(r, 1.0)  # only ever scale down
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (width, height)
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
        dw /= 2
        dh /= 2
        if shape[::-1] != new_unpad:
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The +/-0.1 nudges split an odd padding total as floor/ceil so the two
        # sides always add up to the full padding.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
        return img
class YOLOEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """Entropy calibrator that streams batches from a data loader to the GPU.

    Batches come from ``data_loader.next_batch()`` and are copied into a single
    device buffer. The calibration table is read from and written to
    ``cache_file``.
    """

    def __init__(self, data_loader, cache_file='yolo_int8.cache'):
        """Allocate the device buffer and rewind the loader.

        Args:
            data_loader: Object exposing ``batch_size``, ``calibration_data``,
                ``reset()`` and ``next_batch()``.
            cache_file: Path of the calibration cache file.
        """
        super().__init__()
        self.cache_file = cache_file
        self.data_loader = data_loader
        # One device allocation, sized like the loader's host batch buffer and
        # reused for every batch.
        self.d_input = cuda.mem_alloc(data_loader.calibration_data.nbytes)
        self.data_loader.reset()

    def get_batch_size(self):
        """Return the number of images in each calibration batch."""
        return self.data_loader.batch_size

    def get_batch(self, names):
        """Upload the next batch and return its device pointer in a list.

        Returns None when the loader has no more batches, or when any error
        occurs while fetching or copying (the error is logged).
        """
        try:
            host_batch = self.data_loader.next_batch()
            if host_batch is not None:
                cuda.memcpy_htod(self.d_input, host_batch)
                return [int(self.d_input)]
        except Exception as e:
            logger.error(f"Error in get_batch: {e}")
        return None

    def read_calibration_cache(self):
        """Return the bytes of the cache file, or None if it does not exist."""
        if not os.path.exists(self.cache_file):
            return None
        logger.info(f"Reading calibration cache from {self.cache_file}")
        with open(self.cache_file, "rb") as cache_fp:
            return cache_fp.read()

    def write_calibration_cache(self, cache):
        """Write the calibration table ``cache`` to the cache file."""
        logger.info(f"Writing calibration cache to {self.cache_file}")
        with open(self.cache_file, "wb") as cache_fp:
            cache_fp.write(cache)
def generate_calibration_cache(onnx_path, cache_file, data_dir, batch_size=8, input_shape=(3, 640, 640)):
    """Build a throw-away INT8 engine so TensorRT writes a calibration cache.

    The ONNX model is parsed, an INT8 entropy calibrator fed from ``data_dir``
    is attached, and the network is built once. The engine is discarded; the
    cache file written by the calibrator during the build is the product.

    Args:
        onnx_path: Path to the ONNX model.
        cache_file: Path the calibration cache is written to.
        data_dir: Directory containing the calibration images.
        batch_size: Number of images per calibration batch.
        input_shape: ``(channels, height, width)`` of the network input.

    Returns:
        True if the build succeeded and ``cache_file`` exists afterwards.
        False if the model is missing or unparsable, the platform lacks fast
        INT8, the build fails, or no cache file was produced.

    Raises:
        ValueError: Propagated from ``DataLoader`` when ``data_dir`` holds
            fewer than ``batch_size`` images.
    """
    logger.info("="*60)
    logger.info("Starting Calibration Cache Generation")
    logger.info("="*60)

    if not os.path.exists(onnx_path):
        logger.error(f"ONNX model not found: {onnx_path}")
        return False

    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    logger.info(f"Parsing ONNX model: {onnx_path}")
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            logger.error("Failed to parse ONNX model")
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            return False
    logger.info("ONNX model parsed successfully")

    # 4 GiB workspace; parenthesised so the value does not depend on the
    # relative precedence of '*' and '<<'.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 * (1 << 30))

    # Without INT8 support no calibration runs and no cache can be produced,
    # so building an FP16 engine here would only report a false success.
    if not builder.platform_has_fast_int8:
        logger.error("INT8 not supported on this platform; cannot generate a calibration cache")
        return False

    logger.info("INT8 calibration enabled")
    config.set_flag(trt.BuilderFlag.INT8)
    logger.info(f"Loading calibration data from: {data_dir}")
    calib_loader = DataLoader(data_dir, batch_size=batch_size, input_shape=input_shape)
    # Keep a local reference so the calibrator outlives the build call.
    calibrator = YOLOEntropyCalibrator(calib_loader, cache_file=cache_file)
    config.int8_calibrator = calibrator

    input_tensor = network.get_input(0)
    input_name = input_tensor.name
    logger.info(f"Input name: {input_name}")
    logger.info(f"Input shape: {input_tensor.shape}")

    chw = tuple(input_shape)
    max_batch = max(batch_size, 8)
    opt_batch = max(1, max_batch // 2)

    # Build profile: batch range 1..max_batch, following input_shape instead of
    # a hard-coded 3x640x640.
    profile = builder.create_optimization_profile()
    profile.set_shape(input_name,
                      (1, *chw),
                      (opt_batch, *chw),
                      (max_batch, *chw))
    config.add_optimization_profile(profile)

    # Calibration profile: pinned to the exact shape the calibrator supplies.
    # Without it TensorRT falls back to the build profile, whose opt batch
    # does not match batch_size.
    calib_shape = (batch_size, *chw)
    calib_profile = builder.create_optimization_profile()
    calib_profile.set_shape(input_name, calib_shape, calib_shape, calib_shape)
    config.set_calibration_profile(calib_profile)

    logger.info("Building engine for calibration (this will generate cache)...")
    try:
        # build_engine() was removed in TensorRT 10; build_serialized_network()
        # exists in every release that also has set_memory_pool_limit().
        serialized_engine = builder.build_serialized_network(network, config)
    except Exception as e:
        logger.error(f"Error during calibration: {e}")
        return False

    if serialized_engine is None:
        logger.error("Failed to build engine")
        return False

    logger.info("Calibration engine built successfully")
    del serialized_engine

    # Only report success if the cache file is actually on disk.
    if not os.path.exists(cache_file):
        logger.error(f"Calibration cache was not written: {cache_file}")
        return False
    logger.info(f"Calibration cache written to: {cache_file}")
    return True
def main():
    """Generate the INT8 calibration cache for the default YOLO11n model.

    Logs the fixed configuration, runs ``generate_calibration_cache`` and
    reports the outcome.

    Raises:
        SystemExit: With exit status 1 when cache generation fails.
    """
    onnx_path = 'yolo11n.onnx'
    cache_file = 'yolo11n_int8.cache'
    data_dir = 'data'
    batch_size = 8
    input_shape = (3, 640, 640)

    logger.info(f"ONNX model: {onnx_path}")
    logger.info(f"Output cache: {cache_file}")
    logger.info(f"Data directory: {data_dir}")
    logger.info(f"Batch size: {batch_size}")
    logger.info(f"Input shape: {input_shape}")

    success = generate_calibration_cache(onnx_path, cache_file, data_dir, batch_size, input_shape)
    if not success:
        logger.error("Calibration cache generation failed!")
        # The builtin exit() is injected by the `site` module for interactive
        # use and may be missing; raising SystemExit always works.
        raise SystemExit(1)

    logger.info("="*60)
    logger.info("Calibration cache generated successfully!")
    logger.info(f"Cache file: {os.path.abspath(cache_file)}")
    logger.info("="*60)


if __name__ == "__main__":
    main()