
Commit 1ca0ab5

Add PrivateUse1 Tutorial on integrating a new backend to PyTorch via PrivateUse1 (#2526)

* Add PrivateUse1 Tutorial --------- Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent 420399d commit 1ca0ab5

File tree

2 files changed: +317 −0 lines changed


advanced_source/privateuseone.rst

+309
@@ -0,0 +1,309 @@
Facilitating New Backend Integration by PrivateUse1
===================================================

In this tutorial we will walk through some necessary steps to integrate a new backend
living outside the ``pytorch/pytorch`` repo via ``PrivateUse1``. Note that this tutorial assumes
that you already have a basic understanding of PyTorch and that you are an advanced user of PyTorch.

.. note::

   This tutorial only involves the parts related to the PrivateUse1 mechanism that facilitates the integration of new devices,
   and other parts will not be covered. At the same time, not all the modules involved in this tutorial are required,
   and you can choose the modules that are helpful to you according to your actual needs.


What is PrivateUse1?
--------------------

Prior to PyTorch 2.0, PyTorch provided three reserved dispatch keys (and their corresponding Autograd keys)
for prototyping out-of-tree backend extensions. The three dispatch keys are as follows:

* ``PrivateUse1/AutogradPrivateUse1``
* ``PrivateUse2/AutogradPrivateUse2``
* ``PrivateUse3/AutogradPrivateUse3``

Once the prototype verification passed, you could apply for a private dispatch key for the new backend, as was done for CUDA, XLA, MPS, and so on.

However, with the rapid development of PyTorch, more and more hardware manufacturers are trying to
integrate their backends into PyTorch, which might cause the following problems:

* Every new backend integration involves a lot of file modification
* There is currently a hard limit on the number of Dispatch Keys (``DispatchKeySet`` 64-bit limit)

.. note::

   There is also a problem with integrating the new backend into PyTorch through the PrivateUse1 Key, as it is impossible
   to integrate many backends at the same time. Fortunately, these out-of-tree backends are rarely used simultaneously.

In view of the above reasons, the community began to recommend that new backends be integrated
into PyTorch via ``PrivateUse1``.

However, the previous ``PrivateUse1`` mechanism was not fully capable of integrating with a new backend, because it
lacked some related support in certain modules, such as Storage, AMP, Distributed, and so on.

With the arrival of PyTorch 2.1.0, a series of optimizations and enhancements have been made
for ``PrivateUse1`` in terms of new backend integration, and it is now possible to support the integration
of new devices rapidly and efficiently.

How to integrate new backend via PrivateUse1
--------------------------------------------

In this section, we will discuss the details of integrating the new backend into PyTorch via ``PrivateUse1``,
which mainly consists of the following parts:

1. Register kernels for the new backend.
2. Register generator for the new backend.
3. Register device guard for the new backend.
4. Register serialization and deserialization functions for new backend metadata.
5. Other Modules.

Register kernels for the new backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The new backend may have some high-performance implementations of operators, which can be registered to the dispatcher
via the ``TORCH_LIBRARY_IMPL`` API described in `Registering a Dispatched Operator in C++ <dispatcher>`_. This involves
several situations:

1. Register all the forward operators supported by the new backend to the dispatcher, and register the fallback
   at the same time, so that when the new backend does not support some operators, those operators can fall back
   to the CPU for execution to ensure the availability of functionality.

.. code-block:: cpp

    at::Tensor wrapper_Custom_Tensor_add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
      // Implementation of add kernel in new backend
      ...
    }

    TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
      ...
      m.impl("add.Tensor", TORCH_FN(wrapper_Custom_Tensor_add));
      ...
    }

    void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
      // Add some hints about operators that the new device does not support and that need to fall back to the CPU
      at::native::cpu_fallback(op, stack);
    }

    TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
      m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
    }

2. Register kernels from ``torch::autograd::Function`` to the dispatcher by ``AutogradPrivateUse1``. If it is necessary for the
   new backend to override the ``PyTorch Autograd layer``, the dispatcher and autograd system will automatically call the forward and
   backward implementations of these operators.

.. code-block:: cpp

    class CustomSeluFunction : public torch::autograd::Function<CustomSeluFunction> {
      // Implementation of selu kernel in new backend
    };

    at::Tensor wrapper_AutogradCustom__selu(const at::Tensor & self) {
      return CustomSeluFunction::apply(self);
    }

    TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
      ...
      m.impl("selu", TORCH_FN(wrapper_AutogradCustom__selu));
      ...
    }

3. Register kernels which want to support `automatic mixed precision (AMP) <https://pytorch.org/docs/stable/amp.html>`_ and the
   fallback mechanism to the dispatcher by ``AutocastPrivateUse1``; the autocast system will automatically call these kernels when needed.

.. code-block:: cpp

    TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
      ...
      KERNEL_PRIVATEUSEONE(<operator>, <policy>)
      ...
    }

    TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
      m.fallback(torch::CppFunction::makeFallthrough());
    }

What needs to be added is that if you want to support AMP in a new backend, you need to register a new ``BackendModule`` by
``torch._register_device_module("backend_name", BackendModule)``, and the ``BackendModule`` needs to have the following APIs
(a minimal sketch of such a module follows the list):

* ``get_amp_supported_dtype() -> List[torch.dtype]``
  get the supported dtypes on the new backend in AMP, which might support more than one ``dtype``.
* ``is_autocast_enabled() -> bool``
  check whether AMP is enabled on the new backend.
* ``get_autocast_dtype() -> torch.dtype``
  get the supported ``dtype`` on the new backend in AMP, which is set by ``set_autocast_dtype`` or is the
  default ``dtype``, and the default ``dtype`` is ``torch.float16``.
* ``set_autocast_enabled(bool) -> None``
  enable or disable AMP on the new backend.
* ``set_autocast_dtype(dtype) -> None``
  set the supported ``dtype`` on the new backend in AMP; the ``dtype`` must be contained in the ``dtypes`` obtained
  from ``get_amp_supported_dtype``.

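The following is a minimal, illustrative sketch of such a ``BackendModule``, written in Python. It assumes a hypothetical
backend that has already been renamed to ``foo`` (renaming is covered later in this tutorial) and that only supports
``torch.float16`` under AMP; a real backend would usually forward these calls to its device runtime instead of keeping
the state in plain Python variables.

.. code-block:: python

    import types
    from typing import List

    import torch

    # "foo" and this module are placeholders for a real backend implementation.
    foo = types.ModuleType("torch_foo")

    # Autocast state kept in Python purely for this sketch.
    foo._autocast_enabled = False
    foo._autocast_dtype = torch.float16

    def get_amp_supported_dtype() -> List[torch.dtype]:
        # dtypes that AMP is allowed to cast to on this backend
        return [torch.float16]

    def is_autocast_enabled() -> bool:
        return foo._autocast_enabled

    def get_autocast_dtype() -> torch.dtype:
        return foo._autocast_dtype

    def set_autocast_enabled(enable: bool) -> None:
        foo._autocast_enabled = enable

    def set_autocast_dtype(dtype: torch.dtype) -> None:
        assert dtype in get_amp_supported_dtype()
        foo._autocast_dtype = dtype

    foo.get_amp_supported_dtype = get_amp_supported_dtype
    foo.is_autocast_enabled = is_autocast_enabled
    foo.get_autocast_dtype = get_autocast_dtype
    foo.set_autocast_enabled = set_autocast_enabled
    foo.set_autocast_dtype = set_autocast_dtype

    # Assumes the PrivateUse1 backend has already been renamed to "foo",
    # so that "foo" is a valid device type for torch.
    torch._register_device_module("foo", foo)
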
Register generator for the new backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

It is necessary to support generators corresponding to new devices. Currently, ``PrivateUse1`` can dynamically
register custom generators, which are mainly divided into the following steps.

1. Inherit the ``GeneratorImpl`` class to implement the generator class corresponding to the new backend,
   and implement various general methods.
2. Define a new backend ``builder`` with a single parameter: ``device index``.
3. Call ``REGISTER_GENERATOR_PRIVATEUSE1`` macro to complete dynamic registration.

.. code-block:: cpp

    struct CustomGeneratorImpl : public c10::GeneratorImpl {
      // Implementation of generator in new backend
    };

    at::Generator make_custom_generator(c10::DeviceIndex device_index) {
      return at::make_generator<CustomGeneratorImpl>(device_index);
    }

    REGISTER_GENERATOR_PRIVATEUSE1(make_custom_generator)

Register device guard for the new backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch provides functionalities related to device, stream, and event switching via ``DeviceGuard``.
This function is also applicable to the ``PrivateUse1`` Key.

1. Inherit the ``DeviceGuardImplInterface`` class to implement the various general methods corresponding to the new backend.
2. Call ``C10_REGISTER_GUARD_IMPL`` macro to complete dynamic registration.

.. code-block:: cpp

    struct CustomGuardImpl final : public c10::impl::DeviceGuardImplInterface {
      // Implementation of guard in new backend
    };

    C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomGuardImpl);

Register serialization and deserialization functions for new backend metadata
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch is currently able to dynamically register serialization/deserialization functions to support the serialization and deserialization
of new backend additional metadata named ``backend_meta_`` in class ``TensorImpl.ExtraMeta``. You can refer to the following steps:

1. Inherit the ``BackendMeta`` class to implement ``CustomBackendMetadata`` corresponding to the new backend;
   various fields of the new backend can be customized in the class.
2. Implement the serialization and deserialization functions of the new backend; the function signatures are
   ``void(const at::Tensor&, std::unordered_map<std::string, bool>&)``.
3. Call the ``TensorBackendMetaRegistry`` macro to complete dynamic registration.

.. code-block:: cpp

    struct CustomBackendMetadata : public c10::BackendMeta {
      // Implementation of backend metadata in new backend
    };

    void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
      // Implementation of serialization
    }

    void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
      // Implementation of deserialization
    }

    TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1, &for_serialization, &for_deserialization);

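As a rough illustration of when these hooks fire, the usual ``torch.save``/``torch.load`` flow can be exercised from
Python once the backend is functional. The device name ``foo`` below is a placeholder and assumes the backend has
been renamed and its kernels registered; this is a sketch, not part of the registration itself.

.. code-block:: python

    import torch

    # "foo" is a placeholder device name for a renamed PrivateUse1 backend.
    t = torch.empty(4, 4, device="foo")
    torch.save(t, "tensor.pt")       # backend metadata is captured via the registered serialization hook
    t2 = torch.load("tensor.pt")     # and restored via the registered deserialization hook
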
Other Modules
^^^^^^^^^^^^^

In addition to the above-mentioned parts, there are some other modules that can be expanded through ``PrivateUse1``,
such as ``distributed collective communication``, ``benchmark timer``, and others, which will be added in the future.
One example of ``PrivateUse1`` integration is `Ascend NPU <https://github.com/ascend/pytorch>`_.


How to Improve User Experience with PrivateUse1
-----------------------------------------------

The primary goal of integrating new devices through ``PrivateUse1`` is to meet the basic functional requirements,
and the next thing to do is to improve usability, which mainly involves the following aspects.

1. Register new backend module to PyTorch.
2. Rename PrivateUse1 to a custom name for the new backend.
3. Generate methods and properties related to the new backend.

Register new backend module to PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some CUDA-related interfaces in PyTorch can be called through the following form: ``torch.cuda.xxx``. Therefore, in order to
comply with user habits, the new backend implemented through the ``PrivateUse1`` mechanism should also provide similar interfaces.

For example, using ``Ascend NPU``:

.. code-block:: python

    torch._register_device_module('npu', torch_npu.npu)

After doing the above operations, users can call some exclusive APIs of ``Ascend NPU`` through ``torch.npu.xxx``.

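As an illustration, calls on the registered module might then look like the following. The attribute names used here
(``is_available``, ``set_device``) are assumptions about what a backend module could expose; the actual API surface is
defined by each backend.

.. code-block:: python

    import torch
    import torch_npu  # out-of-tree package that provides the torch_npu.npu module

    # Hypothetical helpers exposed by the registered backend module.
    if torch.npu.is_available():
        torch.npu.set_device(0)
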
Rename PrivateUse1 to a custom name for the new backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The ``PrivateUse1`` Key is the internal mechanism of the new backend integrated into PyTorch. For users, compared with ``PrivateUse1``,
a custom name strongly related to the new backend is friendlier.

Taking the ``Ascend NPU`` as an example, the first usage below is more user-friendly.

.. code-block:: python

    torch.rand((2,2),device='npu:0')
    torch.rand((2,2),device='privateuse1:0')

Now, PyTorch provides a new C++/Python API for renaming the ``PrivateUse1`` backend, which is very simple to use.

.. tab-set-code::

    .. code-block:: python

        torch.rename_privateuse1_backend("npu")

    .. code-block:: C++

        c10::register_privateuse1_backend("npu")

Generate methods and properties related to the new backend
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

After renaming ``PrivateUse1`` to a custom name, you can automatically generate properties and methods related to the new backend name
in the ``Tensor``, ``nn``, and ``Storage`` modules for the new backend.

Here is an example for ``Ascend NPU``:

.. code-block:: python

    torch.rename_privateuse1_backend("npu")
    unsupported_dtype = [torch.quint8]
    torch.utils.generate_methods_for_privateuse1_backend(for_tensor=True, for_module=True, for_storage=True, unsupported_dtype=unsupported_dtype)

Then, you can use the following methods and properties:

.. code-block:: python

    torch.Tensor.npu()
    torch.Tensor.is_npu
    torch.Storage.npu()
    torch.Storage.is_npu
    ...

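For example, after the generation call above, moving a tensor to the device could look like the following small sketch,
which assumes the backend's kernels and device module are already set up:

.. code-block:: python

    x = torch.randn(2, 2)   # created on the CPU
    x_npu = x.npu()         # generated Tensor method: copy to the renamed backend
    print(x_npu.is_npu)     # generated property: True
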
Future Work
-----------

The improvement of the ``PrivateUse1`` mechanism is still in progress, so the ``PrivateUse1`` integration method
for each new module will be added in turn. Here are a few items that we are actively working on:

* Add the integration method of ``distributed collective communication``.
* Add the integration method of ``benchmark timer``.

Conclusion
----------

This tutorial walked you through the process of integrating new backends into PyTorch via ``PrivateUse1``, including but not limited to
operator registration, generator registration, device guard registration, and so on. At the same time, some methods for improving
the user experience are introduced.

index.rst

+8
@@ -423,6 +423,13 @@ What's new in PyTorch tutorials?
    :link: advanced/extend_dispatcher.html
    :tags: Extending-PyTorch,Frontend-APIs,C++

+.. customcarditem::
+   :header: Facilitating New Backend Integration by PrivateUse1
+   :card_description: Learn how to integrate a new backend living outside of the pytorch/pytorch repo and maintain it to keep in sync with the native PyTorch backend.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: advanced/privateuseone.html
+   :tags: Extending-PyTorch,Frontend-APIs,C++
+
 .. customcarditem::
    :header: Custom Function Tutorial: Double Backward
    :card_description: Learn how to write a custom autograd Function that supports double backward.

@@ -962,6 +969,7 @@ Additional Resources
    advanced/torch_script_custom_classes
    advanced/dispatcher
    advanced/extend_dispatcher
+   advanced/privateuseone

 .. toctree::
    :maxdepth: 2
