Files
Test_AI/quantize_yolo.py
16337 942244bd88 Add YOLO11 TensorRT quantization benchmark scripts
- Engine build scripts (FP16/INT8)
- Benchmark validation scripts
- Result parsing and analysis tools
- COCO dataset configuration
2026-01-29 13:59:42 +08:00

274 lines
10 KiB
Python

import os
import glob
import numpy as np
import cv2
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
from ultralytics import YOLO
import logging
# Configure logging: timestamped INFO-level messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class DataLoader:
    """Load and preprocess JPEG images in fixed-size batches for INT8 calibration.

    Each image is letterbox-resized to the network input size, converted
    BGR -> RGB, transposed HWC -> CHW, and scaled to [0, 1] float32.
    """

    def __init__(self, data_dir, batch_size, input_shape):
        """
        Args:
            data_dir: Directory searched (recursively first, then flat)
                for ``*.jpg`` calibration images.
            batch_size: Number of images per calibration batch.
            input_shape: Network input as (C, H, W).
        """
        self.batch_size = batch_size
        self.input_shape = input_shape  # (C, H, W)
        self.img_paths = glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True)
        if not self.img_paths:
            self.img_paths = glob.glob(os.path.join(data_dir, '*.jpg'), recursive=False)
        logger.info(f"Found {len(self.img_paths)} images for calibration in {data_dir}")
        self.batch_idx = 0
        # Incomplete trailing batch is dropped.
        self.max_batches = len(self.img_paths) // self.batch_size
        # Pre-allocated host buffer, reused for every batch.
        self.calibration_data = np.zeros((self.batch_size, *self.input_shape), dtype=np.float32)

    def reset(self):
        """Rewind the loader to the first batch."""
        self.batch_idx = 0

    def next_batch(self):
        """Return the next batch flattened to 1-D float32, or None when exhausted.

        Unreadable images are replaced with zeros rather than silently
        reusing whatever the previous batch left in that buffer slot.
        """
        if self.batch_idx >= self.max_batches:
            return None
        start = self.batch_idx * self.batch_size
        batch_paths = self.img_paths[start:start + self.batch_size]
        for i, path in enumerate(batch_paths):
            img = cv2.imread(path)
            if img is None:
                # Bug fix: previously the slot kept stale pixels from the
                # prior batch, skewing calibration statistics.
                logger.warning(f"Failed to read {path}; filling slot with zeros")
                self.calibration_data[i].fill(0.0)
                continue
            # Letterbox resize (keep aspect ratio, pad with gray).
            img = self.preprocess(img, (self.input_shape[1], self.input_shape[2]))
            # BGR -> RGB, HWC -> CHW, normalize to [0, 1].
            img = img[:, :, ::-1].transpose(2, 0, 1)
            self.calibration_data[i] = np.ascontiguousarray(img, dtype=np.float32) / 255.0
        self.batch_idx += 1
        return np.ascontiguousarray(self.calibration_data.ravel())

    def preprocess(self, img, new_shape=(640, 640), color=(114, 114, 114)):
        """Letterbox `img` to `new_shape` (h, w): scale down only, pad with `color`."""
        shape = img.shape[:2]  # current (height, width)
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        r = min(r, 1.0)  # only scale down, never up
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        # Split padding evenly between the two sides of each axis.
        dw = (new_shape[1] - new_unpad[0]) / 2
        dh = (new_shape[0] - new_unpad[1]) / 2
        if shape[::-1] != new_unpad:
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The +/- 0.1 rounding keeps total padding exact for odd remainders.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
        return img
class YOLOEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """INT8 entropy calibrator that feeds DataLoader batches to TensorRT.

    Implements the IInt8EntropyCalibrator2 callback interface; host batches
    are staged through a single device buffer allocated once up front.
    """

    def __init__(self, data_loader, cache_file='yolo_int8.cache'):
        super().__init__()
        self.data_loader = data_loader
        self.cache_file = cache_file
        # One device buffer, sized for a full calibration batch.
        self.d_input = cuda.mem_alloc(data_loader.calibration_data.nbytes)
        self.data_loader.reset()

    def get_batch_size(self):
        """Batch size TensorRT should expect from every get_batch() call."""
        return self.data_loader.batch_size

    def get_batch(self, names):
        """Copy the next host batch to the GPU; returning None ends calibration."""
        try:
            host_batch = self.data_loader.next_batch()
            if host_batch is None:
                return None
            cuda.memcpy_htod(self.d_input, host_batch)
            return [int(self.d_input)]
        except Exception as e:
            # TensorRT calls this from native code; swallow and signal "done".
            logger.error(f"Error in get_batch: {e}")
            return None

    def read_calibration_cache(self):
        """Return a previously written calibration table, or None if absent."""
        if not os.path.exists(self.cache_file):
            return None
        logger.info(f"Reading calibration cache from {self.cache_file}")
        with open(self.cache_file, "rb") as f:
            return f.read()

    def write_calibration_cache(self, cache):
        """Persist the calibration table so later builds can skip calibration."""
        logger.info(f"Writing calibration cache to {self.cache_file}")
        with open(self.cache_file, "wb") as f:
            f.write(cache)
def build_engine(onnx_path, engine_path, data_dir, batch_size=8, input_shape=(3, 640, 640)):
    """Build and serialize a TensorRT engine from an ONNX model.

    Enables INT8 (with entropy calibration over images in `data_dir`) when
    the platform supports it, plus FP16 for mixed precision. Registers a
    dynamic-batch optimization profile of (1, batch_size // 2, batch_size).

    Args:
        onnx_path: Path to the ONNX model to parse.
        engine_path: Output path for the serialized engine.
        data_dir: Directory of ``*.jpg`` calibration images.
        batch_size: Calibration batch size and profile max batch (default 8,
            matching the original hard-coded profile).
        input_shape: Network input as (C, H, W) (default (3, 640, 640)).

    Returns:
        `engine_path` on success, None on parse or build failure.
    """
    logger.info("Building TensorRT Engine...")
    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # 1. Parse ONNX
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            return None

    # 2. Configure builder: 4 GiB workspace memory pool.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)
    # INT8 mode with entropy calibration, when the hardware supports it.
    if builder.platform_has_fast_int8:
        logger.info("INT8 mode enabled.")
        config.set_flag(trt.BuilderFlag.INT8)
        calib_loader = DataLoader(data_dir, batch_size=batch_size, input_shape=input_shape)
        config.int8_calibrator = YOLOEntropyCalibrator(calib_loader)
    else:
        logger.warning("INT8 not supported on this platform. Falling back to FP16/FP32.")
    # FP16 mode (mixed precision) where beneficial.
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # 3. Dynamic-batch optimization profile.
    profile = builder.create_optimization_profile()
    input_name = network.get_input(0).name
    logger.info(f"Input Name: {input_name}")
    c, h, w = input_shape
    opt_batch = max(1, batch_size // 2)
    profile.set_shape(input_name, (1, c, h, w), (opt_batch, c, h, w), (batch_size, c, h, w))
    config.add_optimization_profile(profile)

    # 4. Build and serialize the engine.
    try:
        serialized_engine = builder.build_serialized_network(network, config)
    except Exception as e:
        logger.error(f"Build failed: {e}")
        return None
    if serialized_engine is None:
        # Bug fix: a None result previously fell through silently with no log.
        logger.error("build_serialized_network returned None (build failed).")
        return None
    with open(engine_path, 'wb') as f:
        f.write(serialized_engine)
    logger.info(f"Engine saved to {engine_path}")
    return engine_path
def export_onnx(model_name='yolo11n.pt'):
    """Export an Ultralytics YOLO checkpoint to ONNX and return the output path.

    Args:
        model_name: Path or name of the ``.pt`` checkpoint to export.
    """
    yolo = YOLO(model_name)
    logger.info(f"Exporting {model_name} to ONNX...")
    # dynamic=True so the exported graph accepts variable batch sizes.
    return yolo.export(format='onnx', dynamic=True, simplify=True, opset=12)
def _validate_one(model, label, data_yaml):
    """Run one model validation pass; return {'map50_95', 'map50'} (zeros on failure)."""
    try:
        metrics = model.val(data=data_yaml, batch=1, imgsz=640, rect=False)
        map50_95 = metrics.box.map
        map50 = metrics.box.map50
        logger.info(f"{label} mAP50-95: {map50_95:.4f}")
        logger.info(f"{label} mAP50: {map50:.4f}")
        return {'map50_95': map50_95, 'map50': map50}
    except Exception as e:
        logger.warning(f"{label} Validation failed: {e}")
        return {'map50_95': 0.0, 'map50': 0.0}


def validate_models(pt_model_path, engine_path, data_yaml='coco8.yaml'):
    """Validate the FP32 PyTorch model and the INT8 TensorRT engine on the
    same dataset, and report the relative mAP drop.

    Args:
        pt_model_path: Path to the ``.pt`` checkpoint (FP32 baseline).
        engine_path: Path to the TensorRT ``.engine`` file.
        data_yaml: Ultralytics dataset YAML used for validation.

    Returns:
        dict with 'fp32', 'int8' and 'drop' entries, each mapping to
        {'map50_95', 'map50'} (0.0 where validation failed).
    """
    logger.info("="*60)
    logger.info("Model Validation and mAP Drop Calculation")
    logger.info("="*60)
    map_results = {}

    logger.info("=== Validating FP32 (PyTorch) ===")
    map_results['fp32'] = _validate_one(YOLO(pt_model_path), "FP32", data_yaml)

    logger.info("")
    logger.info("=== Validating INT8 (TensorRT) ===")
    map_results['int8'] = _validate_one(YOLO(engine_path, task='detect'), "INT8", data_yaml)

    logger.info("")
    logger.info("="*60)
    logger.info("mAP Drop Analysis")
    logger.info("="*60)
    fp32, int8 = map_results['fp32'], map_results['int8']
    # Robustness fix: also require map50 > 0 — the original divided by
    # fp32['map50'] while only checking map50_95, a latent ZeroDivisionError.
    if fp32['map50_95'] > 0 and int8['map50_95'] > 0 and fp32['map50'] > 0:
        drop_50_95 = (fp32['map50_95'] - int8['map50_95']) / fp32['map50_95'] * 100
        drop_50 = (fp32['map50'] - int8['map50']) / fp32['map50'] * 100
        logger.info(f"mAP50-95 Drop: {drop_50_95:.2f}%")
        logger.info(f"mAP50 Drop: {drop_50:.2f}%")
        logger.info("")
        logger.info("Score Comparison:")
        logger.info(f"  FP32 mAP50-95: {fp32['map50_95']:.4f} -> INT8 mAP50-95: {int8['map50_95']:.4f}")
        logger.info(f"  FP32 mAP50: {fp32['map50']:.4f} -> INT8 mAP50: {int8['map50']:.4f}")
        map_results['drop'] = {'map50_95': drop_50_95, 'map50': drop_50}
    else:
        logger.warning("Could not calculate mAP drop due to missing validation results")
        map_results['drop'] = {'map50_95': 0.0, 'map50': 0.0}
    logger.info("="*60)
    return map_results
def main():
    """End-to-end pipeline: export ONNX, build INT8 engine, validate mAP drop."""
    model_name = 'yolo11n.pt'
    onnx_path = 'yolo11n.onnx'
    engine_path = 'yolo11n.engine'
    data_dir = 'data'  # calibration image directory
    val_data_yaml = 'coco8.yaml'  # validation dataset config, used for mAP

    # 1. Export ONNX (reuse an existing export when present).
    if os.path.exists(onnx_path):
        logger.info(f"Found existing ONNX: {onnx_path}")
    else:
        onnx_path = export_onnx(model_name)

    # 2. Build the TensorRT engine (includes INT8 calibration).
    if os.path.exists(engine_path):
        logger.info(f"Found existing Engine: {engine_path}. Skipping build.")
    else:
        # Require at least one calibration image before attempting a build.
        has_images = bool(
            glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True)
            or glob.glob(os.path.join(data_dir, '*.jpg'), recursive=False)
        )
        if not has_images:
            logger.error(f"No images found in {data_dir} for calibration! Please prepare data first.")
            return
        build_engine(onnx_path, engine_path, data_dir)

    # 3. Validate and compare FP32 vs INT8.
    logger.info("Starting Validation... (Ensure you have a valid dataset yaml)")
    validate_models(model_name, engine_path, val_data_yaml)


if __name__ == "__main__":
    main()