# -*- coding: utf-8 -*-
"""
.. _python-custom-ops-tutorial:

Custom Python Operators
=======================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in Python with PyTorch
       * How to test custom operators using ``torch.library.opcheck``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (e.g.
``torch.add``, ``torch.sum``, etc). However, you might wish to use a new customized
operator with PyTorch, perhaps written by a third-party library. This tutorial
shows how to wrap Python functions so that they behave like PyTorch native
operators. Reasons why you may wish to create a custom operator in PyTorch include:

- Treating an arbitrary Python function as an opaque callable with respect
  to ``torch.compile`` (that is, prevent ``torch.compile`` from tracing
  into the function).
- Adding training support to an arbitrary Python function

Use :func:`torch.library.custom_op` to create Python custom operators.
Use the C++ ``TORCH_LIBRARY`` APIs to create C++ custom operators (these
work in Python-less environments).
See the `Custom Operators Landing Page <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html>`_
for more details.

Please note that if your operation can be expressed as a composition of
existing PyTorch operators, then there is usually no need to use the custom operator
API -- everything (for example ``torch.compile``, training support) should
just work.
"""
######################################################################
# Example: Wrapping PIL's crop into a custom operator
# ----------------------------------------------------
# Let's say that we are using PIL's ``crop`` operation.
import torch
from torchvision.transforms.functional import to_pil_image, pil_to_tensor
import PIL
import IPython
import matplotlib.pyplot as plt
def crop(pic, box):
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return pil_to_tensor(cropped_img).to(pic.device) / 255.


def display(img):
    plt.imshow(img.numpy().transpose((1, 2, 0)))

img = torch.ones(3, 64, 64)
img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
display(img)
######################################################################
cropped_img = crop(img, (10, 10, 50, 50))
display(cropped_img)
######################################################################
# ``crop`` is not handled effectively out-of-the-box by
# ``torch.compile``: ``torch.compile`` induces a
# `"graph break" <https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks>`_
# on functions it is unable to handle, and graph breaks are bad for performance.
# The following code demonstrates this by raising an error
# (``torch.compile`` with ``fullgraph=True`` raises an error if a
# graph break occurs).
@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

# The following raises an error. Uncomment the line to see it.
# cropped_img = f(img)
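
# Without ``fullgraph=True``, compilation still succeeds, but ``torch.compile``
# falls back to eager execution around the graph breaks inside ``crop``, which
# hurts performance (a small sketch, not part of the original tutorial):
compiled_crop_with_breaks = torch.compile(crop)
cropped_with_breaks = compiled_crop_with_breaks(img, (10, 10, 50, 50))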
######################################################################
# In order to black-box ``crop`` for use with ``torch.compile``, we need to
# do two things:
#
# 1. wrap the function into a PyTorch custom operator.
# 2. add a "``FakeTensor`` kernel" (aka "meta kernel") to the operator.
#    Given some ``FakeTensor`` inputs (dummy Tensors that don't have storage),
#    this function should return dummy Tensors of your choice with the correct
#    Tensor metadata (shape/strides/``dtype``/device).
from typing import Sequence
# Use torch.library.custom_op to define a new custom operator.
# If your operator mutates any input Tensors, their names must be specified
# in the ``mutates_args`` argument.
@torch.library.custom_op("mylib::crop", mutates_args=())
def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
    img = to_pil_image(pic.cpu())
    cropped_img = img.crop(box)
    return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@crop.register_fake
def _(pic, box):
    channels = pic.shape[0]
    x0, y0, x1, y1 = box
    result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1)
    # The result should have the same metadata (shape/strides/``dtype``/device)
    # as running the ``crop`` function above.
    return result
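
# A quick check of the fake kernel (a sketch, not part of the original tutorial;
# ``FakeTensorMode`` lives in a private module and may change between releases):
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    fake_pic = torch.empty(3, 64, 64)
    fake_out = crop(fake_pic, (10, 10, 50, 50))
    # The fake kernel reports the cropped shape without doing any real work.
    assert fake_out.shape == (3, 40, 40)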
######################################################################
# After this, ``crop`` now works without graph breaks:
@torch.compile(fullgraph=True)
def f(img):
    return crop(img, (10, 10, 50, 50))

cropped_img = f(img)
display(img)
######################################################################
display(cropped_img)
######################################################################
# Adding training support for crop
# --------------------------------
# Use ``torch.library.register_autograd`` to add training support for an operator.
# Prefer this over directly using ``torch.autograd.Function``; some compositions of
# ``autograd.Function`` with PyTorch operator registration APIs can lead to (and
# have led to) silent incorrectness when composed with ``torch.compile``.
#
# If you don't need training support, there is no need to use
# ``torch.library.register_autograd``.
# If you end up training with a ``custom_op`` that doesn't have an autograd
# registration, PyTorch will raise an error.
#
# The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave the
# derivation as an exercise to the reader). Let's first wrap ``paste`` into a
# custom operator:
@torch.library.custom_op("mylib::paste", mutates_args=())
def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    im1_pil = to_pil_image(im1.cpu())
    im2_pil = to_pil_image(im2.cpu())
    PIL.Image.Image.paste(im1_pil, im2_pil, coord)
    return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

@paste.register_fake
def _(im1, im2, coord):
    assert im1.device == im2.device
    assert im1.dtype == im2.dtype
    return torch.empty_like(im1)
######################################################################
# And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:
def backward(ctx, grad_output):
    grad_input = grad_output.new_zeros(ctx.pic_shape)
    grad_input = paste(grad_input, grad_output, ctx.coords)
    return grad_input, None


def setup_context(ctx, inputs, output):
    pic, box = inputs
    ctx.coords = box[:2]
    ctx.pic_shape = pic.shape

crop.register_autograd(backward, setup_context=setup_context)
######################################################################
# Note that the backward must be a composition of PyTorch-understood operators,
# which is why we wrapped paste into a custom operator instead of directly using
# PIL's paste.
img = img.requires_grad_()
result = crop(img, (10, 10, 50, 50))
result.sum().backward()
display(img.grad)
######################################################################
# This is the correct gradient, with 1s (white) in the cropped region and 0s
# (black) in the unused region.
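#
# A quick sanity check of that claim (a small sketch, not part of the original
# tutorial): the gradient should be exactly one inside the cropped box and zero
# everywhere else.
expected_grad = torch.zeros_like(img)
expected_grad[:, 10:50, 10:50] = 1.0
assert torch.allclose(img.grad, expected_grad)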
######################################################################
# Testing Python Custom operators
# -------------------------------
# Use ``torch.library.opcheck`` to test that the custom operator was registered
# correctly. This does not test that the gradients are mathematically correct;
# please write separate tests for that (either manual ones or ``torch.autograd.gradcheck``).
#
# To use ``opcheck``, pass it a set of example inputs to test against. If your
# operator supports training, then the examples should include Tensors that
# require grad. If your operator supports multiple devices, then the examples
# should include Tensors from each device.
examples = [
    [torch.randn(3, 64, 64), [0, 0, 10, 10]],
    [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
    [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
    [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
]

for example in examples:
    torch.library.opcheck(crop, example)
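
# ``opcheck`` also accepts a ``test_utils`` argument for running only a subset
# of its checks; for example (a sketch; see the ``torch.library.opcheck``
# documentation for the supported test names):
torch.library.opcheck(crop, examples[0], test_utils="test_faketensor")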
######################################################################
# Mutable Python Custom operators
# -------------------------------
# You can also wrap a Python function that mutates its inputs into a custom
# operator.
# Functions that mutate inputs are common because that is how many low-level
# kernels are written; for example, a kernel that computes ``sin`` may take in
# the input and an output tensor and write ``input.sin()`` to the output tensor.
#
# We'll use ``numpy.sin`` to demonstrate an example of a mutable Python
# custom operator.
import numpy as np
@torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)
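
# Eager usage (a small sketch, not part of the original tutorial): the operator
# writes its result into ``output`` in place and returns nothing.
x = torch.randn(3)
out = torch.empty(3)
numpy_sin(x, out)
assert torch.allclose(out, x.sin())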
######################################################################
# Because the operator doesn't return anything, there is no need to register
# a ``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.
@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
assert torch.allclose(y, x.sin())
######################################################################
# And here's an ``opcheck`` run telling us that we did indeed register the operator correctly.
# ``opcheck`` would error out if we forgot to add the output to ``mutates_args``, for example.
example_inputs = [
    [torch.randn(3), torch.empty(3)],
    [torch.randn(0, 3), torch.empty(0, 3)],
    [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
]

for example in example_inputs:
    torch.library.opcheck(numpy_sin, example)
######################################################################
# Conclusion
# ----------
# In this tutorial, we learned how to use ``torch.library.custom_op`` to
# create a custom operator in Python that works with PyTorch subsystems
# such as ``torch.compile`` and autograd.
#
# This tutorial provides a basic introduction to custom operators.
# For more detailed information, see:
#
# - `the torch.library documentation <https://pytorch.org/docs/stable/library.html>`_
# - `the Custom Operators Manual <https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual>`_
#