274 lines
10 KiB
Python
274 lines
10 KiB
Python
|
|
import os
|
||
|
|
import glob
|
||
|
|
import numpy as np
|
||
|
|
import cv2
|
||
|
|
import tensorrt as trt
|
||
|
|
import pycuda.driver as cuda
|
||
|
|
import pycuda.autoinit
|
||
|
|
from ultralytics import YOLO
|
||
|
|
import logging
|
||
|
|
|
||
|
|
# Configure logging
|
||
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
class DataLoader:
    """Loads and preprocesses calibration images in fixed-size batches.

    Images are letterbox-resized to the network input resolution,
    converted BGR->RGB, transposed HWC->CHW, and scaled to [0, 1].
    Intended as the data source for a TensorRT INT8 calibrator.
    """

    def __init__(self, data_dir, batch_size, input_shape):
        """
        Args:
            data_dir: Directory searched (recursively first, then flat)
                for *.jpg calibration images.
            batch_size: Number of images per calibration batch.
            input_shape: Network input shape as (C, H, W).
        """
        self.batch_size = batch_size
        self.input_shape = input_shape  # (C, H, W)
        self.img_paths = glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True)
        if not self.img_paths:
            # Fall back to a flat (non-recursive) directory layout.
            self.img_paths = glob.glob(os.path.join(data_dir, '*.jpg'), recursive=False)

        logger.info(f"Found {len(self.img_paths)} images for calibration in {data_dir}")
        self.batch_idx = 0
        # Trailing images that do not fill a whole batch are dropped.
        self.max_batches = len(self.img_paths) // self.batch_size

        # Pre-allocate the host-side batch buffer once and reuse it.
        self.calibration_data = np.zeros((self.batch_size, *self.input_shape), dtype=np.float32)

    def reset(self):
        """Rewind so iteration via next_batch() restarts from the first batch."""
        self.batch_idx = 0

    def next_batch(self):
        """Return the next batch as a flat contiguous float32 array, or None when exhausted."""
        if self.batch_idx >= self.max_batches:
            return None

        start = self.batch_idx * self.batch_size
        end = start + self.batch_size
        batch_paths = self.img_paths[start:end]

        for i, path in enumerate(batch_paths):
            img = cv2.imread(path)
            if img is None:
                # BUGFIX: previously this slot silently retained stale pixels
                # from the previous batch; zero it out and warn instead.
                logger.warning(f"Failed to read calibration image: {path}")
                self.calibration_data[i] = 0.0
                continue

            # Letterbox resize (keep aspect ratio, pad borders).
            img = self.preprocess(img, (self.input_shape[1], self.input_shape[2]))

            # BGR to RGB, HWC to CHW, normalize to [0, 1].
            img = img[:, :, ::-1].transpose(2, 0, 1)
            img = np.ascontiguousarray(img, dtype=np.float32) / 255.0

            self.calibration_data[i] = img

        self.batch_idx += 1
        return np.ascontiguousarray(self.calibration_data.ravel())

    def preprocess(self, img, new_shape=(640, 640), color=(114, 114, 114)):
        """Letterbox-resize `img` to `new_shape` (H, W), padding with `color`.

        Only scales down (never up), matching the YOLO validation pipeline.
        Returns the resized-and-padded image (HWC, BGR, uint8).
        """
        shape = img.shape[:2]  # current shape [height, width]
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)

        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        r = min(r, 1.0)  # only scale down

        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]
        dw /= 2  # split padding evenly between both sides
        dh /= 2

        if shape[::-1] != new_unpad:
            img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)

        # The +/-0.1 before rounding assigns an odd leftover pad pixel deterministically.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
        return img
|
||
|
|
|
||
|
|
class YOLOEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """Entropy calibrator (v2) that streams DataLoader batches to TensorRT.

    Calibration results are cached on disk so subsequent engine builds
    can skip the expensive calibration pass.
    """

    def __init__(self, data_loader, cache_file='yolo_int8.cache'):
        super().__init__()
        self.data_loader = data_loader
        self.cache_file = cache_file
        # Device buffer sized for exactly one full host-side batch.
        self.d_input = cuda.mem_alloc(data_loader.calibration_data.nbytes)
        self.data_loader.reset()

    def get_batch_size(self):
        """Batch size TensorRT should expect from each get_batch() call."""
        return self.data_loader.batch_size

    def get_batch(self, names):
        """Copy the next batch to the device; returning None ends calibration."""
        try:
            host_batch = self.data_loader.next_batch()
            if host_batch is None:
                return None
            cuda.memcpy_htod(self.d_input, host_batch)
            return [int(self.d_input)]
        except Exception as e:
            logger.error(f"Error in get_batch: {e}")
            return None

    def read_calibration_cache(self):
        """Return the cached calibration blob if one exists, else None."""
        if not os.path.exists(self.cache_file):
            return None
        logger.info(f"Reading calibration cache from {self.cache_file}")
        with open(self.cache_file, "rb") as f:
            return f.read()

    def write_calibration_cache(self, cache):
        """Persist the calibration blob so future builds can reuse it."""
        logger.info(f"Writing calibration cache to {self.cache_file}")
        with open(self.cache_file, "wb") as f:
            f.write(cache)
|
||
|
|
|
||
|
|
def build_engine(onnx_path, engine_path, data_dir):
    """Build and serialize a TensorRT engine from an ONNX model.

    Enables INT8 (with entropy calibration on images from `data_dir`) when
    the platform supports it, plus FP16 as mixed-precision fallback for
    layers INT8 cannot handle. Supports dynamic batch sizes 1..8.

    Args:
        onnx_path: Path to the ONNX model to parse.
        engine_path: Destination path for the serialized engine.
        data_dir: Directory containing *.jpg calibration images.

    Returns:
        engine_path on success, None on any failure.
    """
    logger.info("Building TensorRT Engine...")
    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(TRT_LOGGER)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # 1. Parse ONNX
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            for error in range(parser.num_errors):
                logger.error(parser.get_error(error))
            return None

    # 2. Builder configuration
    # Workspace memory pool limit: 4 GiB.
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 << 30)

    # 3. Dynamic shapes — Min: 1, Opt: 4, Max: 8. Created before INT8 setup
    # so the same profile can be registered for calibration.
    profile = builder.create_optimization_profile()
    input_name = network.get_input(0).name
    logger.info(f"Input Name: {input_name}")
    profile.set_shape(input_name, (1, 3, 640, 640), (4, 3, 640, 640), (8, 3, 640, 640))
    config.add_optimization_profile(profile)

    # 4. INT8 mode with entropy calibration.
    if builder.platform_has_fast_int8:
        logger.info("INT8 mode enabled.")
        config.set_flag(trt.BuilderFlag.INT8)
        calib_loader = DataLoader(data_dir, batch_size=8, input_shape=(3, 640, 640))
        config.int8_calibrator = YOLOEntropyCalibrator(calib_loader)
        # BUGFIX: networks with dynamic input shapes need an explicit
        # calibration profile, otherwise INT8 calibration cannot determine
        # the shape to calibrate at.
        config.set_calibration_profile(profile)
    else:
        logger.warning("INT8 not supported on this platform. Falling back to FP16/FP32.")

    # FP16 mixed precision for layers that cannot run in INT8.
    if builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # 5. Build and serialize the engine.
    try:
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine:
            with open(engine_path, 'wb') as f:
                f.write(serialized_engine)
            logger.info(f"Engine saved to {engine_path}")
            return engine_path
        # BUGFIX: a None result previously fell through without any message.
        logger.error("build_serialized_network returned None (build failed).")
    except Exception as e:
        logger.error(f"Build failed: {e}")

    return None
|
||
|
|
|
||
|
|
def export_onnx(model_name='yolo11n.pt'):
    """Export a YOLO checkpoint to ONNX and return the exported file path."""
    model = YOLO(model_name)
    logger.info(f"Exporting {model_name} to ONNX...")
    # dynamic=True so the exported graph accepts variable batch sizes.
    return model.export(format='onnx', dynamic=True, simplify=True, opset=12)
|
||
|
|
|
||
|
|
def _relative_drop(baseline, quantized):
    """Percentage drop from `baseline` to `quantized`; 0.0 when baseline <= 0."""
    if baseline <= 0:
        return 0.0
    return (baseline - quantized) / baseline * 100


def _validate_map(model, data_yaml, label):
    """Run Ultralytics validation on `model` and log/return its mAP metrics.

    Returns a dict with 'map50_95' and 'map50'; both 0.0 if validation fails.
    """
    try:
        metrics = model.val(data=data_yaml, batch=1, imgsz=640, rect=False)
    except Exception as e:
        logger.warning(f"{label} Validation failed: {e}")
        return {'map50_95': 0.0, 'map50': 0.0}
    result = {'map50_95': metrics.box.map, 'map50': metrics.box.map50}
    logger.info(f"{label} mAP50-95: {result['map50_95']:.4f}")
    logger.info(f"{label} mAP50: {result['map50']:.4f}")
    return result


def validate_models(pt_model_path, engine_path, data_yaml='coco8.yaml'):
    """Compare FP32 (PyTorch) vs INT8 (TensorRT) accuracy and report the mAP drop.

    Args:
        pt_model_path: Path to the original .pt checkpoint.
        engine_path: Path to the serialized TensorRT engine.
        data_yaml: Ultralytics dataset yaml used for validation.

    Returns:
        dict with 'fp32', 'int8', and 'drop' metric dictionaries.
    """
    logger.info("=" * 60)
    logger.info("Model Validation and mAP Drop Calculation")
    logger.info("=" * 60)

    map_results = {}

    logger.info("=== Validating FP32 (PyTorch) ===")
    map_results['fp32'] = _validate_map(YOLO(pt_model_path), data_yaml, "FP32")

    logger.info("")
    logger.info("=== Validating INT8 (TensorRT) ===")
    map_results['int8'] = _validate_map(YOLO(engine_path, task='detect'), data_yaml, "INT8")

    logger.info("")
    logger.info("=" * 60)
    logger.info("mAP Drop Analysis")
    logger.info("=" * 60)

    fp32, int8 = map_results['fp32'], map_results['int8']
    if fp32['map50_95'] > 0 and int8['map50_95'] > 0:
        drop_50_95 = _relative_drop(fp32['map50_95'], int8['map50_95'])
        # BUGFIX: an fp32 mAP50 of exactly 0.0 previously raised
        # ZeroDivisionError because only mAP50-95 was guarded.
        drop_50 = _relative_drop(fp32['map50'], int8['map50'])

        logger.info(f"mAP50-95 Drop: {drop_50_95:.2f}%")
        logger.info(f"mAP50 Drop: {drop_50:.2f}%")
        logger.info("")
        logger.info("Score Comparison:")
        logger.info(f"  FP32 mAP50-95: {fp32['map50_95']:.4f} -> INT8 mAP50-95: {int8['map50_95']:.4f}")
        logger.info(f"  FP32 mAP50: {fp32['map50']:.4f} -> INT8 mAP50: {int8['map50']:.4f}")

        map_results['drop'] = {'map50_95': drop_50_95, 'map50': drop_50}
    else:
        logger.warning("Could not calculate mAP drop due to missing validation results")
        map_results['drop'] = {'map50_95': 0.0, 'map50': 0.0}

    logger.info("=" * 60)

    return map_results
|
||
|
|
|
||
|
|
def main():
    """Pipeline entry point: export ONNX, build the INT8 engine, validate accuracy."""
    model_name = 'yolo11n.pt'
    onnx_path = 'yolo11n.onnx'
    engine_path = 'yolo11n.engine'
    data_dir = 'data'  # calibration image directory
    val_data_yaml = 'coco8.yaml'  # validation dataset config for mAP computation

    # 1. Export ONNX (reuse an existing export if present).
    if os.path.exists(onnx_path):
        logger.info(f"Found existing ONNX: {onnx_path}")
    else:
        onnx_path = export_onnx(model_name)

    # 2. Build the TensorRT engine (includes INT8 calibration).
    if os.path.exists(engine_path):
        logger.info(f"Found existing Engine: {engine_path}. Skipping build.")
    else:
        # Ensure the calibration directory actually contains images.
        has_images = bool(
            glob.glob(os.path.join(data_dir, '**', '*.jpg'), recursive=True)
            or glob.glob(os.path.join(data_dir, '*.jpg'), recursive=False)
        )
        if not has_images:
            logger.error(f"No images found in {data_dir} for calibration! Please prepare data first.")
            return
        build_engine(onnx_path, engine_path, data_dir)

    # 3. Validate and compare FP32 vs INT8.
    logger.info("Starting Validation... (Ensure you have a valid dataset yaml)")
    validate_models(model_name, engine_path, val_data_yaml)


if __name__ == "__main__":
    main()
|