Migration Guide
===============
Legacy Compatibility
--------------------
It is possible to continue using the legacy API by importing the TensorRT legacy
package: simply replace ``import tensorrt as trt`` with ``import tensorrt.legacy as trt``
in your scripts. Legacy support will eventually be dropped, so it is advisable to
migrate to the new API.
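For instance, a script written against the legacy API needs only its import changed:

::

    # Before (the legacy API under the top-level name):
    # import tensorrt as trt

    # After (the same legacy API, now under tensorrt.legacy):
    import tensorrt.legacy as trt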
Submodules
----------
The new API removes the ``infer``, ``parsers``, ``utils``, and ``lite`` submodules.
Instead, all functionality is now included in the top-level ``tensorrt`` module.
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
| Legacy API                                                            | New API                                                            |
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
| ``G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.WARNING)`` | ``TRT_LOGGER = trt.Logger(trt.Logger.WARNING)``                    |
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
| ``builder = trt.infer.create_infer_builder(logger)``                  | ``builder = trt.Builder(logger)``                                  |
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
| ``runtime = trt.infer.create_infer_runtime(logger)``                  | ``runtime = trt.Runtime(logger)``                                  |
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
| ``parser = trt.parsers.caffeparser.create_caffe_parser()``            | ``parser = trt.CaffeParser()``                                     |
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
| ``parser = trt.parsers.uffparser.create_uff_parser()``                | ``parser = trt.UffParser()``                                       |
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
| ``parser.destroy()`` (or any other TensorRT object)                   | ``del parser`` (``with ... as ...`` strongly preferred; see below) |
+-----------------------------------------------------------------------+--------------------------------------------------------------------+
For example, building an engine from an ONNX file using ``with ... as ...`` might
look something like this:

::

    import tensorrt as trt

    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    ONNX_MODEL = "mnist.onnx"

    def build_engine():
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network() as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            # Configure the builder here.
            builder.max_workspace_size = 2**30
            # Parse the model to create a network.
            with open(ONNX_MODEL, 'rb') as model:
                parser.parse(model.read())
            # Build and return the engine. Note that the builder, network,
            # and parser are destroyed when this function returns.
            return builder.build_cuda_engine(network)
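Since engines are also context managers, a hypothetical caller of ``build_engine``
(the binding inspection below is illustrative) might look like:

::

    with build_engine() as engine:
        for i in range(engine.num_bindings):
            print(engine.get_binding_name(i))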
Data Types
----------
Data types have been given shorthand aliases to better align with the naming
conventions established by Python libraries such as NumPy.
+---------------------------------+-------------------+
| Legacy API | New API |
+---------------------------------+-------------------+
| ``trt.infer.DataType.FLOAT`` | ``trt.float32`` |
+---------------------------------+-------------------+
| ``trt.infer.DataType.HALF`` | ``trt.float16`` |
+---------------------------------+-------------------+
| ``trt.infer.DataType.INT32`` | ``trt.int32`` |
+---------------------------------+-------------------+
| ``trt.infer.DataType.INT8`` | ``trt.int8`` |
+---------------------------------+-------------------+
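These aliases interoperate with NumPy. As a minimal sketch, assuming the
``trt.nptype`` helper in the new bindings:

::

    import numpy as np
    import tensorrt as trt

    # Map TensorRT dtypes to their NumPy equivalents.
    assert trt.nptype(trt.float32) == np.float32
    assert trt.nptype(trt.int8) == np.int8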
In general, attribute and function names align with NumPy naming conventions.
That is, ``shape`` is used in place of ``get_dimensions``, ``dtype`` instead of
``get_type``, ``nbytes`` instead of ``size``, ``num`` instead of ``nb``, and so on.
+---------------------------------------------+------------------------------------------+
| Legacy API                                  | New API                                  |
+---------------------------------------------+------------------------------------------+
| ``builder.set_fp16_mode(True)``             | ``builder.fp16_mode = True``             |
+---------------------------------------------+------------------------------------------+
| ``int8_mode = builder.get_int8_mode()``     | ``int8_mode = builder.int8_mode``        |
+---------------------------------------------+------------------------------------------+
| ``builder.set_max_workspace_size(1 << 20)`` | ``builder.max_workspace_size = 1 << 20`` |
+---------------------------------------------+------------------------------------------+
+---------------------------------------------------------+----------------------------------------------+
| Legacy API                                              | New API                                      |
+---------------------------------------------------------+----------------------------------------------+
| ``num_inputs = network.get_nb_inputs()``                | ``num_inputs = network.num_inputs``          |
+---------------------------------------------------------+----------------------------------------------+
| ``input_shape = network.get_input(0).get_dimensions()`` | ``input_shape = network.get_input(0).shape`` |
+---------------------------------------------------------+----------------------------------------------+
| ``input_type = network.get_input(0).get_type()``        | ``input_type = network.get_input(0).dtype``  |
+---------------------------------------------------------+----------------------------------------------+
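Putting these conventions together, a short sketch that inspects every network
input (assuming a ``network`` built as above):

::

    for i in range(network.num_inputs):
        tensor = network.get_input(i)
        # name, shape, and dtype are attributes rather than getter functions.
        print(tensor.name, tensor.shape, tensor.dtype)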
``Dims`` objects now behave like Python iterables, so shapes can be used directly
with NumPy and volumes computed without caring about the exact ``Dims`` subclass:

::

    # Imports used by the new-API example below.
    from functools import reduce

    import numpy as np

    # Legacy API
    dims2 = engine.get_binding_dimensions(1).to_DimsCHW()
    # Compute volume. This will differ between different subclasses of Dims.
    elt_count = dims2.vol()  # This will not work for DimsHW or DimsNCHW, for example.
    # Or equivalently,
    elt_count = dims2.C() * dims2.H() * dims2.W()

    # New API
    dims = network.get_input(0).shape  # All Dims subclasses behave like iterables,
                                       # so we don't care about the exact subclass.
    arr = np.random.randint(5, size=dims)  # Dims act exactly like any other iterable.
    dims2 = engine.get_binding_shape(1)
    # Compute volume. Works for any iterable, including lists and tuples.
    elt_count = trt.volume(dims2)
    # Or equivalently,
    elt_count = reduce(lambda x, y: x * y, dims2)
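Because ``trt.volume`` accepts any iterable, a quick illustrative check with a
plain tuple:

::

    assert trt.volume((2, 3, 4)) == 24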
Lightweight :class:`tensorrt.Weights`
-------------------------------------
Previously, the :class:`tensorrt.Weights` class would perform deep-copies of any
buffers used to create weights.
To better align with the C++ API, and for the sake of efficiency, the new bindings
no longer create
these deep copies, but instead increment the reference count of the existing
buffer.
Therefore, modifying the buffer used to create a :class:`tensorrt.Weights` object
will also modify the :class:`tensorrt.Weights` object.
Note that the :class:`tensorrt.ICudaEngine` will still create its own copies of
weights internally.
The above only applies to :class:`tensorrt.Weights` created before engine
construction (when using the Network API, for example).
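A minimal sketch of the new reference semantics, assuming the weights are created
from a NumPy buffer (the values here are illustrative):

::

    import numpy as np
    import tensorrt as trt

    buf = np.ones(4, dtype=np.float32)
    weights = trt.Weights(buf)  # No deep copy: weights references buf.
    buf[0] = 42.0               # The change is visible through weights.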
Function calls with several parameters now use keyword arguments. For example,
executing inference with an execution context:

::

    # Legacy API
    context.enqueue(batch_size, bindings, stream.handle, None)

    # New API
    context.execute_async(stream_handle=stream.handle, bindings=bindings,
                          batch_size=batch_size)
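A hedged sketch of the surrounding setup, assuming PyCUDA provides the stream and
that ``bindings`` already holds the device buffer addresses (neither is shown in
the original):

::

    import pycuda.driver as cuda
    import pycuda.autoinit  # Creates a CUDA context on import.

    stream = cuda.Stream()
    with engine.create_execution_context() as context:
        context.execute_async(stream_handle=stream.handle, bindings=bindings,
                              batch_size=1)
        stream.synchronize()  # Wait for the asynchronous launch to complete.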
Serializing An Engine
~~~~~~~~~~~~~~~~~~~~~
::

    # New API
    with open("sample.engine", "wb") as f:
        f.write(engine.serialize())
Deserializing An Engine
~~~~~~~~~~~~~~~~~~~~~~~
::

    # New API
    with open("sample.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
Migrating Plugins
-----------------
Using Pybind11
~~~~~~~~~~~~~~
SWIG-based wrappers are incompatible with TensorRT's pybind11-based bindings. In
order to migrate existing plugins, you need to write a pybind11-based wrapper for
the plugin. Typically, this involves writing bindings for the plugin itself, as
well as the PluginFactory. For more details, refer to the
`fc_plugin_caffe_mnist <https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#fc_plugin_caffe_mnist>`_
sample.
1. Build a shared object file for the plugin. Make sure the
   ``REGISTER_TENSORRT_PLUGIN`` macro is used in the plugin implementation.
2. In Python, load the file created above using ``ctypes.CDLL()``, as shown in
   the sketch below.
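A minimal sketch of step 2; the shared object name ``libmyplugin.so`` is
illustrative:

::

    import ctypes

    # Loading the library runs its static initializers, which is when
    # REGISTER_TENSORRT_PLUGIN registers the plugin creator.
    ctypes.CDLL("./libmyplugin.so")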
At this point, you can add plugins to your network using the
:class:`tensorrt.IPluginRegistry`, which can be retrieved with
:func:`tensorrt.get_plugin_registry`. For more details, refer to the
`uff_custom_plugin <https://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html#uff_custom_plugin>`_
sample.
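For instance, a hedged sketch of retrieving a plugin creator from the registry;
the plugin name and version strings are illustrative and must match whatever your
plugin registered:

::

    import tensorrt as trt

    registry = trt.get_plugin_registry()
    # Look up the creator registered by REGISTER_TENSORRT_PLUGIN.
    creator = registry.get_plugin_creator("MyPlugin", "1")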