fix: 修复 tactic_sources 参数解析错误
- 支持数值格式 (+7) 和名称格式 (+CUBLAS,+CUDNN) - 添加 tactic_values 常量说明文档
This commit is contained in:
@@ -18,10 +18,15 @@ TensorRT Engine 生成脚本 (8GB显存优化版)
|
||||
--opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化
|
||||
--max-batch 最大Batch大小 (默认: 8)
|
||||
--workspace 工作空间大小MB (默认: 6144,即6GB)
|
||||
--tactics 启用优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)
|
||||
--tactics 优化策略 (默认: +7,等价于 +CUBLAS,+CUBLAS_LT,+CUDNN)
|
||||
--best 全局最优搜索 (默认: 启用)
|
||||
--preview 预览特性 (默认: +faster_dynamic_shapes_0805)
|
||||
"""
|
||||
|
||||
tactic_values:
|
||||
CUBLAS = 1
|
||||
CUBLAS_LT = 2
|
||||
CUDNN = 4
|
||||
+7 = 全部启用"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
@@ -236,11 +241,27 @@ def build_engine(
|
||||
|
||||
config.set_flag(trt.BuilderFlag.TF32)
|
||||
|
||||
tactic_value = 0
|
||||
for source in tactic_sources.split(','):
|
||||
source = source.strip()
|
||||
if not source:
|
||||
continue
|
||||
|
||||
if source.startswith('+'):
|
||||
config.set_tactic_sources(int(source[1:]))
|
||||
elif source.startswith('-'):
|
||||
config.set_tactic_sources(~int(source[1:]))
|
||||
name = source[1:]
|
||||
if name.isdigit():
|
||||
tactic_value += int(name)
|
||||
else:
|
||||
name_upper = name.upper()
|
||||
if name_upper == 'CUBLAS':
|
||||
tactic_value += 1
|
||||
elif name_upper == 'CUBLAS_LT':
|
||||
tactic_value += 2
|
||||
elif name_upper == 'CUDNN':
|
||||
tactic_value += 4
|
||||
|
||||
if tactic_value > 0:
|
||||
config.set_tactic_sources(tactic_value)
|
||||
|
||||
if best:
|
||||
config.set_flag(trt.BuilderFlag.BENCHMARK)
|
||||
|
||||
Reference in New Issue
Block a user