From a6130b510265fd76f6cb8e0669af0de2f2f4bd91 Mon Sep 17 00:00:00 2001 From: 16337 <1633794139@qq.com> Date: Fri, 30 Jan 2026 09:20:05 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E5=8A=A8=E6=80=81?= =?UTF-8?q?=E7=BB=B4=E5=BA=A6=E5=86=85=E5=AD=98=E5=88=86=E9=85=8D=E9=94=99?= =?UTF-8?q?=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 处理 TensorRT 引擎的负维度 (-1) - 将动态 Batch 维度替换为最小值 1 --- core/tensorrt_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/tensorrt_engine.py b/core/tensorrt_engine.py index d3a245e..677f5e8 100644 --- a/core/tensorrt_engine.py +++ b/core/tensorrt_engine.py @@ -148,11 +148,13 @@ class TensorRTEngine: if self._input_binding: shape = self._input_binding["shape"] + shape = tuple(max(1, s) if s < 0 else s for s in shape) dtype = self._get_numpy_dtype(self._input_binding["dtype"]) self._memory_pool["input"] = np.zeros(shape, dtype=dtype) for output in self._output_bindings: shape = output["shape"] + shape = tuple(max(1, s) if s < 0 else s for s in shape) dtype = self._get_numpy_dtype(output["dtype"]) self._memory_pool[output["name"]] = np.zeros(shape, dtype=dtype)