diff --git a/build_engine.py b/build_engine.py index cc2f62b..fe2a8d8 100644 --- a/build_engine.py +++ b/build_engine.py @@ -18,10 +18,15 @@ TensorRT Engine 生成脚本 (8GB显存优化版) --opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化 --max-batch 最大Batch大小 (默认: 8) --workspace 工作空间大小MB (默认: 6144,即6GB) - --tactics 启用优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN) + --tactics 优化策略 (默认: +7,等价于 +CUBLAS,+CUBLAS_LT,+CUDNN) --best 全局最优搜索 (默认: 启用) --preview 预览特性 (默认: +faster_dynamic_shapes_0805) -""" + + tactic_values: + CUBLAS = 1 + CUBLAS_LT = 2 + CUDNN = 4 + +7 = 全部启用""" import os import sys @@ -236,11 +241,27 @@ def build_engine( config.set_flag(trt.BuilderFlag.TF32) + tactic_value = 0 for source in tactic_sources.split(','): + source = source.strip() + if not source: + continue + if source.startswith('+'): - config.set_tactic_sources(int(source[1:])) - elif source.startswith('-'): - config.set_tactic_sources(~int(source[1:])) + name = source[1:] + if name.isdigit(): + tactic_value += int(name) + else: + name_upper = name.upper() + if name_upper == 'CUBLAS': + tactic_value += 1 + elif name_upper == 'CUBLAS_LT': + tactic_value += 2 + elif name_upper == 'CUDNN': + tactic_value += 4 + + if tactic_value > 0: + config.set_tactic_sources(tactic_value) if best: config.set_flag(trt.BuilderFlag.BENCHMARK)