From 4103b503db3b6b4cf60470bea2d95a64aa0542b0 Mon Sep 17 00:00:00 2001
From: 16337 <1633794139@qq.com>
Date: Thu, 29 Jan 2026 18:48:27 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20tactic=5Fsources=20?=
 =?UTF-8?q?=E5=8F=82=E6=95=B0=E8=A7=A3=E6=9E=90=E9=94=99=E8=AF=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- 支持数值格式 (+7) 和名称格式 (+CUBLAS,+CUDNN)
- 添加 tactic_values 常量说明文档
---
Review notes (this section is ignored by git am):
- Tactic tokens are OR-ed into the mask instead of summed, so a repeated or
  overlapping token ("+7,+CUBLAS") can no longer select a different source.
- Unknown names and "-X" tokens now print a warning instead of being
  dropped silently.
- The whitespace of this patch was reconstructed from a flattened copy and
  the post-image blob id on the index line was not regenerated; apply with
  --ignore-whitespace if a context line fails to match.

 build_engine.py | 35 +++++++++++++++++++++++++++++++----
 1 file changed, 31 insertions(+), 4 deletions(-)

diff --git a/build_engine.py b/build_engine.py
index cc2f62b..fe2a8d8 100644
--- a/build_engine.py
+++ b/build_engine.py
@@ -18,10 +18,16 @@ TensorRT Engine 生成脚本 (8GB显存优化版)
     --opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化
     --max-batch 最大Batch大小 (默认: 8)
     --workspace 工作空间大小MB (默认: 6144,即6GB)
-    --tactics 启用优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)
+    --tactics 优化策略 (默认: +7,等价于 +CUBLAS,+CUBLAS_LT,+CUDNN)
     --best 全局最优搜索 (默认: 启用)
     --preview 预览特性 (默认: +faster_dynamic_shapes_0805)
+
+    tactic_values:
+      CUBLAS = 1
+      CUBLAS_LT = 2
+      CUDNN = 4
+      +7 = 全部启用
 """
 
 import os
 import sys
@@ -236,11 +242,32 @@ def build_engine(
 
     config.set_flag(trt.BuilderFlag.TF32)
 
+    # Each tactic source is one bit of the mask handed to TensorRT, so
+    # tokens are combined with bitwise OR; summing would turn a repeated
+    # or overlapping token ("+7,+CUBLAS") into a different source.
+    tactic_bits = {'CUBLAS': 1, 'CUBLAS_LT': 2, 'CUDNN': 4}
+    tactic_value = 0
     for source in tactic_sources.split(','):
+        source = source.strip()
+        if not source:
+            continue
+
         if source.startswith('+'):
-            config.set_tactic_sources(int(source[1:]))
-        elif source.startswith('-'):
-            config.set_tactic_sources(~int(source[1:]))
+            name = source[1:].strip()
+            if name.isdigit():
+                tactic_value |= int(name)
+            elif name.upper() in tactic_bits:
+                tactic_value |= tactic_bits[name.upper()]
+            else:
+                print(f"[WARN] unknown tactic source ignored: {source}")
+        else:
+            # "-X" and bare tokens are not supported by this parser; report
+            # them instead of dropping them silently.
+            print(f"[WARN] unsupported tactic token ignored: {source}")
+
+    # Leave TensorRT's default sources untouched when nothing was enabled.
+    if tactic_value > 0:
+        config.set_tactic_sources(tactic_value)
 
     if best:
         config.set_flag(trt.BuilderFlag.BENCHMARK)