
1.6 model freezing tutorial #1077

Merged
merged 25 commits into from
Jul 22, 2020
Changes from 1 commit

Commits (25)
a9bd604
Update feature classification labels
Jun 18, 2020
45d02c7
Update NVidia -> Nvidia
Jun 23, 2020
8d32bae
Merge branch 'master' into master
Jun 25, 2020
9a7250d
Merge branch 'master' into master
Jul 2, 2020
0d48b72
Merge pull request #1035 from jlin27/master
Jul 4, 2020
68c22a0
Bring back default filename_pattern so that by default we run all gal…
ezyang Jul 8, 2020
b6d1838
Add prototype_source directory
Jul 8, 2020
01fc130
Add prototype directory
Jul 8, 2020
26511cc
Add prototype
Jul 8, 2020
fb779e1
Remove extra "done"
Jul 8, 2020
494d037
Add REAME.txt
Jul 9, 2020
23fb4c7
Merge pull request #1058 from jlin27/master
Jul 9, 2020
d32aa04
Update for prototype instructions
Jul 9, 2020
67f76d3
Update for prototype feature
Jul 9, 2020
958aa33
refine torchvision_tutorial doc for windows
guyang3532 Jul 9, 2020
c83c23d
Merge pull request #1060 from guyang3532/fix_torchvision_tutorial_win
ezyang Jul 9, 2020
9b0635d
Update neural_style_tutorial.py (#1059)
hritikbhandari Jul 9, 2020
3740027
torch_script_custom_ops restructure (#1057)
ezyang Jul 9, 2020
3e32d22
Port custom ops tutorial to new registration API, increase testability.
ezyang Jul 9, 2020
999a029
Kill some other occurrences of RegisterOperators
ezyang Jul 9, 2020
f90f773
Update README.md
Jul 9, 2020
c6059ec
Make torch_script_custom_classes tutorial runnable
ezyang Jul 9, 2020
32e5407
Update torch_script_custom_classes to use TORCH_LIBRARY (#1062)
ezyang Jul 14, 2020
36c67f6
Add Model Freezing in TorchScript
Jul 21, 2020
36089ae
Merge branch 'release/1.6' into 1.6-model-freezing-tutorial
Jul 22, 2020
Port custom ops tutorial to new registration API, increase testability.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
ezyang committed Jul 9, 2020
commit 3e32d228c437a32dcfa4f70b09e14448c2832e18
230 changes: 105 additions & 125 deletions advanced_source/torch_script_custom_ops.rst
@@ -124,7 +124,7 @@ like this:
:end-before: END output_tensor

We use the ``.ptr<float>()`` method on the OpenCV ``Mat`` class to get a raw
pointer to the underlying data (just like ``.data<float>()`` for the PyTorch
pointer to the underlying data (just like ``.data_ptr<float>()`` for the PyTorch
tensor earlier). We also specify the output shape of the tensor, which we
hardcoded as ``8 x 8``. The output of ``torch::from_blob`` is then a
``torch::Tensor``, pointing to the memory owned by the OpenCV matrix.
@@ -145,40 +145,28 @@ Registering the Custom Operator with TorchScript
Now that we have implemented our custom operator in C++, we need to *register* it
with the TorchScript runtime and compiler. This will allow the TorchScript
compiler to resolve references to our custom operator in TorchScript code.
Registration is very simple. For our case, we need to write:
If you have ever used the pybind11 library, our registration syntax closely
resembles pybind11's. To register a single function, we write:

.. literalinclude:: ../advanced_source/torch_script_custom_ops/op.cpp
:language: cpp
:start-after: BEGIN registry
:end-before: END registry

somewhere in the global scope of our ``op.cpp`` file. This creates a global
variable ``registry``, which will register our operator with TorchScript in its
constructor (i.e. exactly once per program). We specify the name of the
operator, and a pointer to its implementation (the function we wrote earlier).
The name consists of two parts: a *namespace* (``my_ops``) and a name for the
particular operator we are registering (``warp_perspective``). The namespace and
operator name are separated by two colons (``::``).
somewhere at the top level of our ``op.cpp`` file. The ``TORCH_LIBRARY`` macro
creates a function that will be called when your program starts. The name
of your library (``my_ops``) is given as the first argument (it should not
be in quotes). The second argument (``m``) defines a variable of type
``torch::Library``, which is the main interface for registering your operators.
The method ``Library::def`` actually creates an operator named ``warp_perspective``,
exposing it to both Python and TorchScript. You can define as many operators
as you like by making multiple calls to ``def``.
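
For example, a minimal sketch of registering several operators in a single
``TORCH_LIBRARY`` block might look like this (``another_op`` and
``and_another_op`` are hypothetical functions, shown purely for illustration):

.. code-block:: cpp

   TORCH_LIBRARY(my_ops, m) {
     // Each call to def() registers one operator in the my_ops namespace.
     m.def("warp_perspective", warp_perspective);
     // another_op and and_another_op are hypothetical functions, included
     // only to illustrate registering multiple operators in one block.
     m.def("another_op", another_op);
     m.def("and_another_op", and_another_op);
   }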

.. tip::

If you want to register more than one operator, you can chain calls to
``.op()`` after the constructor:

.. code-block:: cpp

static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective)
.op("my_ops::another_op", &another_op)
.op("my_ops::and_another_op", &and_another_op);

Behind the scenes, ``RegisterOperators`` will perform a number of fairly
complicated C++ template metaprogramming magic tricks to infer the argument and
return value types of the function pointer we pass it (``&warp_perspective``).
This information is used to form a *function schema* for our operator. A
function schema is a structured representation of an operator -- a kind of
"signature" or "prototype" -- used by the TorchScript compiler to verify
correctness in TorchScript programs.
Behind the scenes, the ``def`` function is actually doing quite a bit of work:
it uses template metaprogramming to inspect the type signature of your
function and translate it into an operator schema, which specifies the
operator's type within TorchScript's type system.
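
To make the schema idea concrete, here is a sketch of the same registration
with the schema spelled out explicitly rather than inferred; the signature
shown is our assumption of what gets inferred for ``warp_perspective``:

.. code-block:: cpp

   TORCH_LIBRARY(my_ops, m) {
     // Equivalent registration with an explicit schema string instead of a
     // schema inferred from the C++ function signature (a sketch).
     m.def("warp_perspective(Tensor image, Tensor warp) -> Tensor", warp_perspective);
   }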

Building the Custom Operator
----------------------------
@@ -189,7 +177,16 @@ we can load into Python for research and experimentation, or into C++ for
inference in a no-Python environment. There exist multiple ways to build our
operator, using either pure CMake, or Python alternatives like ``setuptools``.
For brevity, the paragraphs below only discuss the CMake approach. The appendix
of this tutorial dives into the Python based alternatives.
of this tutorial dives into other alternatives.

Environment setup
*****************

We need an installation of PyTorch and OpenCV. The easiest and most
platform-independent way to get both is via Conda::

conda install -c pytorch pytorch
conda install opencv

Building with CMake
*******************
@@ -203,29 +200,11 @@ a directory structure that looks like this::
op.cpp
CMakeLists.txt

Also, make sure to grab the latest version of the LibTorch distribution, which
packages PyTorch's C++ libraries and CMake build files, from `pytorch.org
<https://pytorch.org/get-started/locally>`_. Place the unzipped distribution
somewhere accessible in your file system. The following paragraphs will refer to
that location as ``/path/to/libtorch``. The contents of our ``CMakeLists.txt``
file should then be the following:
The contents of our ``CMakeLists.txt`` file should then be the following:

.. literalinclude:: ../advanced_source/torch_script_custom_ops/CMakeLists.txt
:language: cpp

.. warning::

This setup makes some assumptions about the build environment, particularly
what pertains to the installation of OpenCV. The above ``CMakeLists.txt`` file
was tested inside a Docker container running Ubuntu Xenial with
``libopencv-dev`` installed via ``apt``. If it does not work for you and you
feel stuck, please use the ``Dockerfile`` in the `accompanying tutorial
repository <https://github.com/pytorch/extension-script>`_ to
build an isolated, reproducible environment in which to play around with the
code from this tutorial. If you run into further troubles, please file an
issue in the tutorial repository or post a question in `our forum
<https://discuss.pytorch.org/>`_.

To now build our operator, we can run the following commands from our
``warp_perspective`` folder:

@@ -268,24 +247,18 @@ To now build our operator, we can run the following commands from our
[100%] Built target warp_perspective

which will place a ``libwarp_perspective.so`` shared library file in the
``build`` folder. In the ``cmake`` command above, you should replace
``/path/to/libtorch`` with the path to your unzipped LibTorch distribution.
``build`` folder. In the ``cmake`` command above, we use the helper
variable ``torch.utils.cmake_prefix_path`` to conveniently tell us where
the CMake files for our PyTorch installation are.

We will explore how to use and call our operator in detail further below, but to
get an early sensation of success, we can try running the following code in
Python:

.. code-block:: python

>>> import torch
>>> torch.ops.load_library("/path/to/libwarp_perspective.so")
>>> print(torch.ops.my_ops.warp_perspective)

Here, ``/path/to/libwarp_perspective.so`` should be a relative or absolute path
to the ``libwarp_perspective.so`` shared library we just built. If all goes
well, this should print something like
.. literalinclude:: ../advanced_source/torch_script_custom_ops/smoke_test.py
:language: python

.. code-block:: python
If all goes well, this should print something like::

<built-in method my_ops::warp_perspective of PyCapsule object at 0x7f618fc6fa50>

@@ -302,10 +275,9 @@ TorchScript code.
You already saw how to import your operator into Python:
``torch.ops.load_library()``. This function takes the path to a shared library
containing custom operators, and loads it into the current process. Loading the
shared library will also execute the constructor of the global
``RegisterOperators`` object we placed into our custom operator implementation
file. This will register our custom operator with the TorchScript compiler and
allow us to use that operator in TorchScript code.
shared library will also execute the ``TORCH_LIBRARY`` block. This will register
our custom operator with the TorchScript compiler and allow us to use that
operator in TorchScript code.

You can refer to your loaded operator as ``torch.ops.<namespace>.<function>``,
where ``<namespace>`` is the namespace part of your operator name, and
@@ -316,11 +288,16 @@ While this function can be used in scripted or traced TorchScript modules, we
can also just use it in vanilla eager PyTorch and pass it regular PyTorch
tensors:

.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
:language: python
:prepend: import torch
:start-after: BEGIN preamble
:end-before: END preamble

producing:

.. code-block:: python

>>> import torch
>>> torch.ops.load_library("libwarp_perspective.so")
>>> torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3))
tensor([[0.0000, 0.3218, 0.4611, ..., 0.4636, 0.4636, 0.4636],
[0.3746, 0.0978, 0.5005, ..., 0.4636, 0.4636, 0.4636],
[0.3245, 0.0169, 0.0000, ..., 0.4458, 0.4458, 0.4458],
@@ -332,90 +309,92 @@ tensors:

.. note::

What happens behind the scenes is that the first time you access
``torch.ops.namespace.function`` in Python, the TorchScript compiler (in C++
land) will see if a function ``namespace::function`` has been registered, and
if so, return a Python handle to this function that we can subsequently use to
call into our C++ operator implementation from Python. This is one noteworthy
difference between TorchScript custom operators and C++ extensions: C++
extensions are bound manually using pybind11, while TorchScript custom ops are
bound on the fly by PyTorch itself. Pybind11 gives you more flexibility with
regards to what types and classes you can bind into Python and is thus
recommended for purely eager code, but it is not supported for TorchScript
ops.
What happens behind the scenes is that the first time you access
``torch.ops.namespace.function`` in Python, the TorchScript compiler (in C++
land) will see if a function ``namespace::function`` has been registered, and
if so, return a Python handle to this function that we can subsequently use to
call into our C++ operator implementation from Python. This is one noteworthy
difference between TorchScript custom operators and C++ extensions: C++
extensions are bound manually using pybind11, while TorchScript custom ops are
bound on the fly by PyTorch itself. Pybind11 gives you more flexibility with
regards to what types and classes you can bind into Python and is thus
recommended for purely eager code, but it is not supported for TorchScript
ops.
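
To illustrate the contrast drawn in this note, a pybind11-based C++ extension
binds the same function into Python by hand. The sketch below assumes a
hypothetical extension module name (``warp_ext``) and is not part of the
tutorial's actual code:

.. code-block:: cpp

   #include <torch/extension.h>

   // Declared in op.cpp earlier in this tutorial.
   torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp);

   // Manual pybind11 binding: this exposes warp_perspective to eager Python
   // only. Operators bound this way are invisible to the TorchScript
   // compiler, unlike operators registered via TORCH_LIBRARY.
   PYBIND11_MODULE(warp_ext, m) {
     m.def("warp_perspective", &warp_perspective);
   }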

From here on, you can use your custom operator in scripted or traced code just
as you would other functions from the ``torch`` package. In fact, "standard
library" functions like ``torch.matmul`` go through largely the same
registration path as custom operators, which makes custom operators really
first-class citizens when it comes to how and where they can be used in
TorchScript.
TorchScript. (One difference, however, is that standard library functions
have custom-written Python argument parsing logic that differs from
``torch.ops`` argument parsing.)

Using the Custom Operator with Tracing
**************************************

Let's start by embedding our operator in a traced function. Recall that for
tracing, we start with some vanilla PyTorch code:

.. code-block:: python

def compute(x, y, z):
    return x.matmul(y) + torch.relu(z)
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
:language: python
:start-after: BEGIN compute
:end-before: END compute

and then call ``torch.jit.trace`` on it. We further pass ``torch.jit.trace``
some example inputs, which it will forward to our implementation to record the
sequence of operations that occur as the inputs flow through it. The result of
this is effectively a "frozen" version of the eager PyTorch program, which the
TorchScript compiler can further analyze, optimize and serialize:

.. code-block:: python
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
:language: python
:start-after: BEGIN trace
:end-before: END trace

>>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
>>> trace = torch.jit.trace(compute, inputs)
>>> print(trace.graph)
graph(%x : Float(4, 8)
%y : Float(8, 5)
%z : Float(4, 5)) {
%3 : Float(4, 5) = aten::matmul(%x, %y)
%4 : Float(4, 5) = aten::relu(%z)
%5 : int = prim::Constant[value=1]()
%6 : Float(4, 5) = aten::add(%3, %4, %5)
return (%6);
}
Producing::

graph(%x : Float(4:8, 8:1),
%y : Float(8:5, 5:1),
%z : Float(4:5, 5:1)):
%3 : Float(4:5, 5:1) = aten::matmul(%x, %y) # test.py:10:0
%4 : Float(4:5, 5:1) = aten::relu(%z) # test.py:10:0
%5 : int = prim::Constant[value=1]() # test.py:10:0
%6 : Float(4:5, 5:1) = aten::add(%3, %4, %5) # test.py:10:0
return (%6)

Now, the exciting revelation is that we can simply drop our custom operator into
our PyTorch trace as if it were ``torch.relu`` or any other ``torch`` function:

.. code-block:: python

torch.ops.load_library("libwarp_perspective.so")

def compute(x, y, z):
    x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
    return x.matmul(y) + torch.relu(z)
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
:language: python
:start-after: BEGIN compute2
:end-before: END compute2

and then trace it as before:

.. code-block:: python

>>> inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(8, 5)]
>>> trace = torch.jit.trace(compute, inputs)
>>> print(trace.graph)
graph(%x.1 : Float(4, 8)
%y : Float(8, 5)
%z : Float(8, 5)) {
%3 : int = prim::Constant[value=3]()
%4 : int = prim::Constant[value=6]()
%5 : int = prim::Constant[value=0]()
%6 : int[] = prim::Constant[value=[0, -1]]()
%7 : Float(3, 3) = aten::eye(%3, %4, %5, %6)
%x : Float(8, 8) = my_ops::warp_perspective(%x.1, %7)
%11 : Float(8, 5) = aten::matmul(%x, %y)
%12 : Float(8, 5) = aten::relu(%z)
%13 : int = prim::Constant[value=1]()
%14 : Float(8, 5) = aten::add(%11, %12, %13)
return (%14);
}
.. literalinclude:: ../advanced_source/torch_script_custom_ops/test.py
:language: python
:start-after: BEGIN trace2
:end-before: END trace2

Producing::

graph(%x.1 : Float(4:8, 8:1),
%y : Float(8:5, 5:1),
%z : Float(8:5, 5:1)):
%3 : int = prim::Constant[value=3]() # test.py:25:0
%4 : int = prim::Constant[value=6]() # test.py:25:0
%5 : int = prim::Constant[value=0]() # test.py:25:0
%6 : Device = prim::Constant[value="cpu"]() # test.py:25:0
%7 : bool = prim::Constant[value=0]() # test.py:25:0
%8 : Float(3:3, 3:1) = aten::eye(%3, %4, %5, %6, %7) # test.py:25:0
%x : Float(8:8, 8:1) = my_ops::warp_perspective(%x.1, %8) # test.py:25:0
%10 : Float(8:5, 5:1) = aten::matmul(%x, %y) # test.py:26:0
%11 : Float(8:5, 5:1) = aten::relu(%z) # test.py:26:0
%12 : int = prim::Constant[value=1]() # test.py:26:0
%13 : Float(8:5, 5:1) = aten::add(%10, %11, %12) # test.py:26:0
return (%13)

Integrating TorchScript custom ops into traced PyTorch code is as easy as this!

@@ -947,8 +926,9 @@ custom TorchScript operator as a string. For this, use
return output.clone();
}

static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
TORCH_LIBRARY(my_ops, m) {
  m.def("warp_perspective", &warp_perspective);
}
"""

torch.utils.cpp_extension.load_inline(
5 changes: 3 additions & 2 deletions advanced_source/torch_script_custom_ops/op.cpp
@@ -30,6 +30,7 @@ torch::Tensor warp_perspective(torch::Tensor image, torch::Tensor warp) {
// END warp_perspective

// BEGIN registry
static auto registry =
torch::RegisterOperators("my_ops::warp_perspective", &warp_perspective);
TORCH_LIBRARY(my_ops, m) {
  m.def("warp_perspective", warp_perspective);
}
// END registry
3 changes: 3 additions & 0 deletions advanced_source/torch_script_custom_ops/smoke_test.py
@@ -0,0 +1,3 @@
import torch
torch.ops.load_library("build/libwarp_perspective.so")
print(torch.ops.my_ops.warp_perspective)
34 changes: 34 additions & 0 deletions advanced_source/torch_script_custom_ops/test.py
@@ -0,0 +1,34 @@
import torch


print("BEGIN preamble")
torch.ops.load_library("build/libwarp_perspective.so")
print(torch.ops.my_ops.warp_perspective(torch.randn(32, 32), torch.rand(3, 3)))
print("END preamble")


# BEGIN compute
def compute(x, y, z):
    return x.matmul(y) + torch.relu(z)
# END compute


print("BEGIN trace")
inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(4, 5)]
trace = torch.jit.trace(compute, inputs)
print(trace.graph)
print("END trace")


# BEGIN compute2
def compute(x, y, z):
    x = torch.ops.my_ops.warp_perspective(x, torch.eye(3))
    return x.matmul(y) + torch.relu(z)
# END compute2


print("BEGIN trace2")
inputs = [torch.randn(4, 8), torch.randn(8, 5), torch.randn(8, 5)]
trace = torch.jit.trace(compute, inputs)
print(trace.graph)
print("END trace2")