import os
import glob
import logging

import numpy as np
import cv2
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit  # noqa: F401 -- imported for its side effect: initializes the CUDA context
from ultralytics import YOLO

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class DataLoader:
    """Loads and preprocesses JPEG calibration images in fixed-size batches.

    Each image is letterboxed to the network input size, converted to RGB
    CHW float32 in [0, 1], and packed into a preallocated batch buffer that
    is reused for every batch.
    """

    def __init__(self, data_dir, batch_size, input_shape):
        """
        Args:
            data_dir: directory searched (recursively first, then flat) for '*.jpg'.
            batch_size: number of images per calibration batch.
            input_shape: network input shape as (C, H, W).
        """
        self.batch_size = batch_size
        self.input_shape = input_shape  # (C, H, W)
        self.img_paths = glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True)
        if not self.img_paths:
            self.img_paths = glob.glob(os.path.join(data_dir, '*.jpg'), recursive=False)
        logger.info(f"Found {len(self.img_paths)} images for calibration in {data_dir}")
        self.batch_idx = 0
        # Integer division: a partial trailing batch is dropped.
        self.max_batches = len(self.img_paths) // self.batch_size
        # Preallocate the batch buffer once; next_batch() writes into it in place.
        self.calibration_data = np.zeros((self.batch_size, *self.input_shape), dtype=np.float32)

    def reset(self):
        """Rewind the loader to the first batch."""
        self.batch_idx = 0

    def next_batch(self):
        """Return the next batch as a flat, contiguous float32 array, or None when exhausted."""
        if self.batch_idx >= self.max_batches:
            return None
        start = self.batch_idx * self.batch_size
        end = start + self.batch_size
        batch_paths = self.img_paths[start:end]
        for i, path in enumerate(batch_paths):
            img = cv2.imread(path)
            if img is None:
                # BUG FIX: previously this slot kept stale pixels from the
                # previous batch when a file was unreadable; zero it so bad
                # files cannot leak old data into calibration statistics.
                self.calibration_data[i].fill(0.0)
                logger.warning(f"Could not read image, using zeros: {path}")
                continue
            # Letterbox resize (keep aspect ratio, pad with gray)
            img = self.preprocess(img, (self.input_shape[1], self.input_shape[2]))
            # BGR -> RGB, HWC -> CHW, normalize to [0, 1]
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
            self.calibration_data[i] = img
        self.batch_idx += 1
        return np.ascontiguousarray(self.calibration_data.ravel())

    def preprocess(self, img, new_shape=(640, 640), color=(114, 114, 114)):
        """Letterbox `img` to `new_shape` (h, w), padding with `color`.

        Scales down only (never up), so that calibration preprocessing
        matches typical YOLO validation preprocessing.
        """
        shape = img.shape[:2]  # current shape [height, width]
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        r = min(r, 1.0)  # only scale down
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (w, h)
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
        dw /= 2  # split padding evenly between both sides
        dh /= 2
        if shape[::-1] != new_unpad:
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The +/-0.1 rounding trick deterministically assigns the odd padding pixel.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
        return img


class YOLOEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """TensorRT INT8 entropy calibrator backed by a DataLoader.

    Feeds preprocessed batches to TensorRT during calibration and caches
    the resulting calibration table on disk so rebuilds can skip the
    (slow) calibration pass.
    """

    def __init__(self, data_loader, cache_file='yolo_int8.cache'):
        super().__init__()
        self.data_loader = data_loader
        self.cache_file = cache_file
        # Device buffer sized for exactly one host batch; reused every call.
        self.d_input = cuda.mem_alloc(data_loader.calibration_data.nbytes)
        self.data_loader.reset()

    def get_batch_size(self):
        return self.data_loader.batch_size

    def get_batch(self, names):
        """Copy the next batch to the GPU; return device pointers, or None when done."""
        try:
            batch = self.data_loader.next_batch()
            if batch is None:
                return None  # tells TensorRT the calibration data is exhausted
            cuda.memcpy_htod(self.d_input, batch)
            return [int(self.d_input)]
        except Exception as e:
            # Returning None aborts calibration cleanly instead of crashing TRT.
            logger.error(f"Error in get_batch: {e}")
            return None

    def read_calibration_cache(self):
        """Return the cached calibration table bytes if present, else None."""
        if os.path.exists(self.cache_file):
            logger.info(f"Reading calibration cache from {self.cache_file}")
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Persist the calibration table produced by TensorRT."""
        logger.info(f"Writing calibration cache to {self.cache_file}")
        with open(self.cache_file, "wb") as f:
            f.write(cache)


def build_engine(onnx_path, engine_path, data_dir):
    """Build a TensorRT engine from an ONNX model, with INT8 calibration when available.

    Args:
        onnx_path: path to the exported ONNX model.
        engine_path: output path for the serialized engine.
        data_dir: directory containing calibration JPEGs.

    Returns:
        engine_path on success, None on any failure.
    """
    logger.info("Building TensorRT Engine...")
    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # 1. Parse ONNX
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            return None

    # 2. Configure builder
    # Workspace memory pool limit: 4 GiB. (The original `4 * 1 << 30` is the
    # same value -- `*` binds tighter than `<<` -- but reads ambiguously.)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)

    # INT8 mode
    if builder.platform_has_fast_int8:
        logger.info("INT8 mode enabled.")
        config.set_flag(trt.BuilderFlag.INT8)
        # Calibration (batch of 8 matches the profile's max batch below)
        calib_loader = DataLoader(data_dir, batch_size=8, input_shape=(3, 640, 640))
        config.int8_calibrator = YOLOEntropyCalibrator(calib_loader)
    else:
        logger.warning("INT8 not supported on this platform. Falling back to FP16/FP32.")

    # FP16 mode (mixed precision; layers that quantize poorly can fall back to FP16)
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # 3. Dynamic shapes: batch min=1, opt=4, max=8
    profile = builder.create_optimization_profile()
    input_name = network.get_input(0).name
    logger.info(f"Input Name: {input_name}")
    profile.set_shape(input_name, (1, 3, 640, 640), (4, 3, 640, 640), (8, 3, 640, 640))
    config.add_optimization_profile(profile)

    # 4. Build and serialize the engine
    try:
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine:
            with open(engine_path, 'wb') as f:
                f.write(serialized_engine)
            logger.info(f"Engine saved to {engine_path}")
            return engine_path
        # BUG FIX: a None result previously fell through with no diagnostic.
        logger.error("build_serialized_network returned None (build failed).")
    except Exception as e:
        logger.error(f"Build failed: {e}")
    return None


def export_onnx(model_name='yolo11n.pt'):
    """Export an Ultralytics YOLO model to ONNX; returns the exported path."""
    model = YOLO(model_name)
    logger.info(f"Exporting {model_name} to ONNX...")
    # dynamic=True so the exported graph supports dynamic batch sizes
    path = model.export(format='onnx', dynamic=True, simplify=True, opset=12)
    return path


def validate_models(pt_model_path, engine_path, data_yaml='coco8.yaml'):
    """Validate FP32 (PyTorch) and INT8 (TensorRT) models and report the mAP drop.

    Args:
        pt_model_path: path to the original .pt model.
        engine_path: path to the built TensorRT engine.
        data_yaml: Ultralytics dataset yaml used for validation.

    Returns:
        dict with 'fp32', 'int8' and 'drop' entries, each containing
        'map50_95' and 'map50' (0.0 where validation failed).
    """
    logger.info("=" * 60)
    logger.info("Model Validation and mAP Drop Calculation")
    logger.info("=" * 60)
    map_results = {}

    logger.info("=== Validating FP32 (PyTorch) ===")
    model_pt = YOLO(pt_model_path)
    try:
        metrics_pt = model_pt.val(data=data_yaml, batch=1, imgsz=640, rect=False)
        map50_95_pt = metrics_pt.box.map
        map50_pt = metrics_pt.box.map50
        logger.info(f"FP32 mAP50-95: {map50_95_pt:.4f}")
        logger.info(f"FP32 mAP50: {map50_pt:.4f}")
        map_results['fp32'] = {'map50_95': map50_95_pt, 'map50': map50_pt}
    except Exception as e:
        logger.warning(f"FP32 Validation failed: {e}")
        map_results['fp32'] = {'map50_95': 0.0, 'map50': 0.0}

    logger.info("")
    logger.info("=== Validating INT8 (TensorRT) ===")
    model_trt = YOLO(engine_path, task='detect')
    try:
        metrics_trt = model_trt.val(data=data_yaml, batch=1, imgsz=640, rect=False)
        map50_95_trt = metrics_trt.box.map
        map50_trt = metrics_trt.box.map50
        logger.info(f"INT8 mAP50-95: {map50_95_trt:.4f}")
        logger.info(f"INT8 mAP50: {map50_trt:.4f}")
        map_results['int8'] = {'map50_95': map50_95_trt, 'map50': map50_trt}
    except Exception as e:
        logger.warning(f"INT8 Validation failed: {e}")
        map_results['int8'] = {'map50_95': 0.0, 'map50': 0.0}

    logger.info("")
    logger.info("=" * 60)
    logger.info("mAP Drop Analysis")
    logger.info("=" * 60)
    # Relative drop is only meaningful when both validations produced scores.
    if map_results['fp32']['map50_95'] > 0 and map_results['int8']['map50_95'] > 0:
        drop_50_95 = (map_results['fp32']['map50_95'] - map_results['int8']['map50_95']) / map_results['fp32']['map50_95'] * 100
        drop_50 = (map_results['fp32']['map50'] - map_results['int8']['map50']) / map_results['fp32']['map50'] * 100
        logger.info(f"mAP50-95 Drop: {drop_50_95:.2f}%")
        logger.info(f"mAP50 Drop: {drop_50:.2f}%")
        logger.info("")
        logger.info("Score Comparison:")
        logger.info(f" FP32 mAP50-95: {map_results['fp32']['map50_95']:.4f} -> INT8 mAP50-95: {map_results['int8']['map50_95']:.4f}")
        logger.info(f" FP32 mAP50: {map_results['fp32']['map50']:.4f} -> INT8 mAP50: {map_results['int8']['map50']:.4f}")
        map_results['drop'] = {'map50_95': drop_50_95, 'map50': drop_50}
    else:
        logger.warning("Could not calculate mAP drop due to missing validation results")
        map_results['drop'] = {'map50_95': 0.0, 'map50': 0.0}
    logger.info("=" * 60)
    return map_results


def main():
    model_name = 'yolo11n.pt'
    onnx_path = 'yolo11n.onnx'
    engine_path = 'yolo11n.engine'
    data_dir = 'data'  # calibration data directory
    val_data_yaml = 'coco8.yaml'  # validation dataset config, used for mAP

    # 1. Export ONNX
    if not os.path.exists(onnx_path):
        onnx_path = export_onnx(model_name)
    else:
        logger.info(f"Found existing ONNX: {onnx_path}")

    # 2. Build the TensorRT engine (with INT8 calibration)
    if not os.path.exists(engine_path):
        # Verify that the data directory actually contains calibration images
        if not glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True) and \
           not glob.glob(os.path.join(data_dir, '*.jpg'), recursive=False):
            logger.error(f"No images found in {data_dir} for calibration! Please prepare data first.")
            return
        build_engine(onnx_path, engine_path, data_dir)
    else:
        logger.info(f"Found existing Engine: {engine_path}. Skipping build.")

    # 3. Validate and compare FP32 vs INT8
    logger.info("Starting Validation... (Ensure you have a valid dataset yaml)")
    validate_models(model_name, engine_path, val_data_yaml)


if __name__ == "__main__":
    main()