fix: 修复 tactic_sources 参数解析错误
- 支持数值格式 (+7) 和名称格式 (+CUBLAS,+CUDNN) - 添加 tactic_values 常量说明文档
This commit is contained in:
@@ -18,10 +18,15 @@ TensorRT Engine 生成脚本 (8GB显存优化版)
|
|||||||
--opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化
|
--opt-batch 优化Batch大小 (默认: 4) <-- TensorRT会针对此尺寸专门优化
|
||||||
--max-batch 最大Batch大小 (默认: 8)
|
--max-batch 最大Batch大小 (默认: 8)
|
||||||
--workspace 工作空间大小MB (默认: 6144,即6GB)
|
--workspace 工作空间大小MB (默认: 6144,即6GB)
|
||||||
--tactics 启用优化策略 (默认: +CUBLAS,+CUBLAS_LT,+CUDNN)
|
--tactics 优化策略 (默认: +7,等价于 +CUBLAS,+CUBLAS_LT,+CUDNN)
|
||||||
--best 全局最优搜索 (默认: 启用)
|
--best 全局最优搜索 (默认: 启用)
|
||||||
--preview 预览特性 (默认: +faster_dynamic_shapes_0805)
|
--preview 预览特性 (默认: +faster_dynamic_shapes_0805)
|
||||||
"""
|
|
||||||
|
tactic_values:
|
||||||
|
CUBLAS = 1
|
||||||
|
CUBLAS_LT = 2
|
||||||
|
CUDNN = 4
|
||||||
|
+7 = 全部启用"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -236,11 +241,27 @@ def build_engine(
|
|||||||
|
|
||||||
config.set_flag(trt.BuilderFlag.TF32)
|
config.set_flag(trt.BuilderFlag.TF32)
|
||||||
|
|
||||||
|
tactic_value = 0
|
||||||
for source in tactic_sources.split(','):
|
for source in tactic_sources.split(','):
|
||||||
|
source = source.strip()
|
||||||
|
if not source:
|
||||||
|
continue
|
||||||
|
|
||||||
if source.startswith('+'):
|
if source.startswith('+'):
|
||||||
config.set_tactic_sources(int(source[1:]))
|
name = source[1:]
|
||||||
elif source.startswith('-'):
|
if name.isdigit():
|
||||||
config.set_tactic_sources(~int(source[1:]))
|
tactic_value += int(name)
|
||||||
|
else:
|
||||||
|
name_upper = name.upper()
|
||||||
|
if name_upper == 'CUBLAS':
|
||||||
|
tactic_value += 1
|
||||||
|
elif name_upper == 'CUBLAS_LT':
|
||||||
|
tactic_value += 2
|
||||||
|
elif name_upper == 'CUDNN':
|
||||||
|
tactic_value += 4
|
||||||
|
|
||||||
|
if tactic_value > 0:
|
||||||
|
config.set_tactic_sources(tactic_value)
|
||||||
|
|
||||||
if best:
|
if best:
|
||||||
config.set_flag(trt.BuilderFlag.BENCHMARK)
|
config.set_flag(trt.BuilderFlag.BENCHMARK)
|
||||||
|
|||||||
Reference in New Issue
Block a user