fix: 修复 tactic_sources 参数解析错误

- 支持数值格式 (+7) 和名称格式 (+CUBLAS,+CUDNN)
- 添加 tactic_values 常量说明文档
This commit is contained in:
2026-01-29 18:48:27 +08:00
parent b0ddb6ee1a
commit 4103b503db

View File

@@ -18,10 +18,15 @@ TensorRT Engine 生成脚本 (8GB显存优化版)
--opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化
--max-batch 最大Batch大小 (默认: 8)
--workspace 工作空间大小MB (默认: 6144即6GB)
--tactics 启用优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)
--tactics 优化策略 (默认: +7等价于 +CUBLAS,+CUBLAS_LT,+CUDNN)
--best 全局最优搜索 (默认: 启用)
--preview 预览特性 (默认: +faster_dynamic_shapes_0805)
"""
tactic_values:
CUBLAS = 1
CUBLAS_LT = 2
CUDNN = 4
+7 = 全部启用"""
import os
import sys
@@ -236,11 +241,27 @@ def build_engine(
config.set_flag(trt.BuilderFlag.TF32)
tactic_value = 0
for source in tactic_sources.split(','):
source = source.strip()
if not source:
continue
if source.startswith('+'):
config.set_tactic_sources(int(source[1:]))
elif source.startswith('-'):
config.set_tactic_sources(~int(source[1:]))
name = source[1:]
if name.isdigit():
tactic_value += int(name)
else:
name_upper = name.upper()
if name_upper == 'CUBLAS':
tactic_value += 1
elif name_upper == 'CUBLAS_LT':
tactic_value += 2
elif name_upper == 'CUDNN':
tactic_value += 4
if tactic_value > 0:
config.set_tactic_sources(tactic_value)
if best:
config.set_flag(trt.BuilderFlag.BENCHMARK)