diff --git a/core/tensorrt_engine.py b/core/tensorrt_engine.py index d3a245e..677f5e8 100644 --- a/core/tensorrt_engine.py +++ b/core/tensorrt_engine.py @@ -148,11 +148,13 @@ class TensorRTEngine: if self._input_binding: shape = self._input_binding["shape"] + shape = tuple(max(1, s) if s < 0 else s for s in shape) dtype = self._get_numpy_dtype(self._input_binding["dtype"]) self._memory_pool["input"] = np.zeros(shape, dtype=dtype) for output in self._output_bindings: shape = output["shape"] + shape = tuple(max(1, s) if s < 0 else s for s in shape) dtype = self._get_numpy_dtype(output["dtype"]) self._memory_pool[output["name"]] = np.zeros(shape, dtype=dtype)