
Commit 62e4c44

[Inductor] [Doc] Add debugging document for inductor cpu backend

1 parent 1068abe
File tree

2 files changed: +324 -0

index.rst (+8)
@@ -564,6 +564,13 @@ What's new in PyTorch tutorials?
     :link: intermediate/torch_compile_tutorial.html
     :tags: Model-Optimization

+.. customcarditem::
+   :header: Inductor CPU Backend Debugging and Profiling
+   :card_description: Learn the usage, debugging and performance profiling for ``torch.compile`` with Inductor CPU backend.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/inductor_debug_cpu.html
+   :tags: Model-Optimization
+
 .. customcarditem::
    :header: (beta) Implementing High-Performance Transformers with SCALED DOT PRODUCT ATTENTION
    :card_description: This tutorial explores the new torch.nn.functional.scaled_dot_product_attention and how it can be used to construct Transformer components.

@@ -962,6 +969,7 @@ Additional Resources
    intermediate/nvfuser_intro_tutorial
    intermediate/ax_multiobjective_nas_tutorial
    intermediate/torch_compile_tutorial
+   intermediate/inductor_debug_cpu
    intermediate/scaled_dot_product_attention_tutorial

 .. toctree::
intermediate/inductor_debug_cpu.rst (+316)

@@ -0,0 +1,316 @@
Inductor CPU backend debugging and profiling
==============================================

**Author**: `Liao Xuan <https://github.com/Valentine233>`_, `Zhu Haozhe <https://github.com/zhuhaozhe>`_

Usage
--------------

Start with an example
^^^^^^^^^^^^^^^^^^^^^

Here is a simple example of running ``torch.compile`` with the Inductor backend:

.. code-block:: python

    import torch

    def fn(x):
        return torch.neg(x)

    x = torch.randn((2, 4, 28))
    compiled_fn = torch.compile(fn)  # backend="inductor" is the default
    result = compiled_fn(x)

Get more logging information
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The above code would not print any debugging info by itself. One way to get more useful logging is to set an environment variable:

.. code:: shell

    TORCH_COMPILE_DEBUG=1 python xx.py

This shows the time taken in each compilation step, visualizes the graph, and prints the output code. A temporary debug tracing directory like the following can be found in the log:

.. code:: shell

    torch._inductor.debug: [WARNING] model___20 debug trace: /tmp/torchinductor_root/rx/crxfi2ybd7yp5sbj2pnhw33wfhtdw7wumvrobyp5sjvdui5ktjc2.debug

The directory saves several files for debugging:

+-------------------------+----------------------------------------------------------+
| fx_graph_readable.py    | Readable FX graph, post decomps                          |
+-------------------------+----------------------------------------------------------+
| fx_graph_runnable.py    | Executable FX graph, post decomps, pre pattern match     |
+-------------------------+----------------------------------------------------------+
| fx_graph_transformed.py | Transformed FX graph, post pattern match                 |
+-------------------------+----------------------------------------------------------+
| ir_pre_fusion.txt       | Inductor IR before fusion                                |
+-------------------------+----------------------------------------------------------+
| ir_post_fusion.txt      | Inductor IR after fusion                                 |
+-------------------------+----------------------------------------------------------+
| output_code.py          | Generated Python code for graph, with cpp/triton kernels |
+-------------------------+----------------------------------------------------------+

``fx_graph_runnable.py`` and ``output_code.py`` are both runnable and editable, which makes debugging easier.
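
For instance, the saved FX graph can be run on its own to check whether a failure reproduces outside of the generated kernels; a sketch, reusing the example trace directory reported above:

.. code:: shell

    cd /tmp/torchinductor_root/rx/crxfi2ybd7yp5sbj2pnhw33wfhtdw7wumvrobyp5sjvdui5ktjc2.debug
    python fx_graph_runnable.py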

Here is another way to get logging information for Inductor:

.. code:: shell

    TORCH_LOGS="+inductor,output_code,schedule" python xx.py

+--------------+-------------------------------------------------------------+
| +inductor    | Set the logging level of Inductor to DEBUG, default is INFO |
+--------------+-------------------------------------------------------------+
| output_code  | Print output code with cpp/triton kernels                   |
+--------------+-------------------------------------------------------------+
| schedule     | Print reasons for not doing vectorization in cpp kernels    |
+--------------+-------------------------------------------------------------+
Configs for deeper analysis
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Moreover, there are several config parameters that help with the analysis:

+--------------------------------------------------+---------------------------------------------------------------------+
| torch._inductor.config.max_fusion_size           | Set the maximum number of nodes allowed in one fusion                |
+--------------------------------------------------+---------------------------------------------------------------------+
| torch._inductor.config.cpp.simdlen               | Specify the bit width for cpp vectorization                          |
+--------------------------------------------------+---------------------------------------------------------------------+
| torch._inductor.config.cpp.min_chunk_size        | Set the minimum number of workloads one thread should take           |
+--------------------------------------------------+---------------------------------------------------------------------+
| torch._inductor.config.cpp.enable_kernel_profile | Allow cpp kernel performance profiling via the profiler              |
+--------------------------------------------------+---------------------------------------------------------------------+
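
As a quick illustration, these parameters can be set in Python before compiling. Below is a minimal sketch; the values are arbitrary examples, not recommendations:

.. code-block:: python

    import torch
    import torch._inductor.config as config

    config.max_fusion_size = 32              # cap the number of nodes fused into one kernel
    config.cpp.simdlen = 256                 # request 256-bit vectorization
    config.cpp.min_chunk_size = 4096         # minimum workload per thread
    config.cpp.enable_kernel_profile = True  # expose cpp kernels to the profiler

    def fn(x):
        return torch.neg(x)

    compiled_fn = torch.compile(fn)
    result = compiled_fn(torch.randn((2, 4, 28)))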

Debugging
--------------

Determine the component where the error occurs
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When encountering errors or accuracy problems, a straightforward way to find the bug is to narrow down the problem. The first thing to do is to determine the component in which the error occurs. Luckily, this can be achieved simply by changing the backend of ``torch.compile``.

+----------------------------------------+-----------------------------------------+
| torch.compile(fn, backend="eager")     | Enable Dynamo                           |
+----------------------------------------+-----------------------------------------+
| torch.compile(fn, backend="aot_eager") | Enable Dynamo + AOT autograd            |
+----------------------------------------+-----------------------------------------+
| torch.compile(fn, backend="inductor")  | Enable Dynamo + AOT autograd + Inductor |
+----------------------------------------+-----------------------------------------+

If the model runs successfully with the ``eager`` or ``aot_eager`` backend but fails with ``inductor``, we can narrow the failure down to Inductor.
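
As a sketch, the three backends can be tried in sequence to locate the failing component; ``torch._dynamo.reset()`` clears the compilation cache between runs:

.. code-block:: python

    import torch
    import torch._dynamo

    def fn(x):
        return torch.neg(x)

    x = torch.randn(8)
    for backend in ("eager", "aot_eager", "inductor"):
        torch._dynamo.reset()  # recompile from scratch for each backend
        try:
            torch.compile(fn, backend=backend)(x)
            print(f"{backend}: OK")
        except Exception as e:
            print(f"{backend}: failed with {type(e).__name__}")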

Example
^^^^^^^

Here is an example for the subsequent debugging:

.. code-block:: python

    import torch
    from torch._dynamo.utils import same

    def foo(x1, x2):
        a = torch.neg(x1)
        b = torch.maximum(x2, a)
        y = torch.cat([b], dim=0)
        return y

    x1 = torch.randint(256, (1,), dtype=torch.uint8)
    x2 = torch.randint(256, (8390,), dtype=torch.uint8)

    expected_result = foo(x1, x2)

    compiled_foo = torch.compile(foo)
    actual_result = compiled_foo(x1, x2)

    assert same(expected_result, actual_result) == True

The implementation of ``neg`` in the cpp codegen is as follows:

.. code-block:: python

    def neg(x):
        return f"decltype({x})(-{x})"

In order to demonstrate debugging, we will later modify this function to an incorrect one.
Error debugging
^^^^^^^^^^^^^^^

If a compile error occurs, the root cause is usually shown in the traceback log.

For example, the ``neg`` function is modified like this:

.. code-block:: python

    def neg(x):
        return f"-{x}"

The logging gives the following compile error with a rather clear reason: in this case, the root cause is that the data types of ``maximum``'s inputs are inconsistent.

.. code:: shell

    torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
    CppCompileError: C++ compile error

    /tmp/torchinductor_root/2x/c2xgxsooklulr4u54etfnnha7dsu6xzbwdscttvs7dkpba3uwkem.cpp: In function ‘void kernel(const unsigned char*, const unsigned char*, unsigned char*)’:
    /tmp/torchinductor_root/2x/c2xgxsooklulr4u54etfnnha7dsu6xzbwdscttvs7dkpba3uwkem.cpp:14:53: error: no matching function for call to ‘max_propagate_nan(unsigned char&, int&)’
       14 |         auto tmp3 = max_propagate_nan(tmp0, tmp2);
          |                                                 ^
    In file included from /tmp/torchinductor_root/2x/c2xgxsooklulr4u54etfnnha7dsu6xzbwdscttvs7dkpba3uwkem.cpp:2:
    /tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h:27:17: note: candidate: ‘template<class scalar_t> scalar_t max_propagate_nan(scalar_t, scalar_t)’
       27 | inline scalar_t max_propagate_nan(scalar_t a, scalar_t b) {
          |                 ^~~~~~~~~~~~~~~~~
    /tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h:27:17: note: template argument deduction/substitution failed:
    /tmp/torchinductor_root/2x/c2xgxsooklulr4u54etfnnha7dsu6xzbwdscttvs7dkpba3uwkem.cpp:14:53: note: deduced conflicting types for parameter ‘scalar_t’ (‘unsigned char’ and ‘int’)
       14 |         auto tmp3 = max_propagate_nan(tmp0, tmp2);
          |                                                 ^

Otherwise, if the model fails with a runtime error, we can reduce the model code until we find the minimal code snippet that still fails. In this way, the target operators and kernels are located.

Accuracy debugging
^^^^^^^^^^^^^^^^^^

An accuracy problem refers to the case where the outputs of the ``eager`` and ``inductor`` backends differ. Since the FX graph is generated before Inductor and the output code is generated after Inductor, we can narrow down the problem by comparing their outputs.

If a model consists of several graphs, the first step is to compare the final outputs of the FX graph and the output code for each graph, given the same input. The goal is to find the first graph that raises an error or produces different outputs. Binary search is recommended for efficiency.

When a model has only one graph, or the problematic graph has been found with the above step, compare the intermediate outputs of the FX graph and the output code within that graph, again given the same input. The idea is to continuously narrow down the problem.
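
As a minimal sketch of this comparison, suppose ``graph_module`` is the FX graph loaded from ``fx_graph_runnable.py`` in the debug trace directory and ``compiled_graph`` runs the corresponding output code; these names are illustrative, not a fixed API:

.. code-block:: python

    from torch._dynamo.utils import same

    # `graph_module` and `compiled_graph` are hypothetical handles (see above);
    # both are fed the same example inputs.
    ref = graph_module(*example_inputs)    # FX graph result (before Inductor)
    res = compiled_graph(*example_inputs)  # output code result (after Inductor)
    assert same(ref, res), "this is the first graph whose outputs diverge"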

For example, we modify the ``neg`` function like this:

.. code-block:: python

    def neg(x):
        return f"decltype({x})(2 * {x})"

An accuracy problem would be raised as follows:

.. code:: shell

    torch._dynamo.utils: [ERROR] Accuracy failed: allclose not within tol=0.0001
    Traceback (most recent call last):
      File "test_script.py", line 18, in <module>
        assert same(expected_result, actual_result) == True
    AssertionError

By comparing the intermediate outputs of the FX graph and the output code, we would find that the outputs already differ after ``torch.neg``.

Specifically, we modify the FX graph and the output code to return the intermediate result right after ``neg``; the changes are attached below.
*Change of FX graph*

.. code-block:: python

    # Before
    class Repro(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, arg0_1, arg1_1):
            neg = torch.ops.aten.neg.default(arg0_1);  arg0_1 = None
            maximum = torch.ops.aten.maximum.default(arg1_1, neg);  arg1_1 = neg = None
            clone = torch.ops.aten.clone.default(maximum);  maximum = None
            return (clone,)

    # After
    class Repro(torch.nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, arg0_1, arg1_1):
            neg = torch.ops.aten.neg.default(arg0_1);  arg0_1 = None
            return (neg,)

*Change of output code*

.. code-block:: python

    # Before
    cpp_fused_cat_maximum_neg_0 = async_compile.cpp('''
    #include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
    extern "C" void kernel(const long* in_ptr0,
                           const long* in_ptr1,
                           long* out_ptr0)
    {
        {
            #pragma GCC ivdep
            for(long i0=static_cast<long>(0L); i0<static_cast<long>(8390L); i0+=static_cast<long>(1L))
            {
                auto tmp0 = in_ptr0[static_cast<long>(i0)];
                auto tmp1 = in_ptr1[static_cast<long>(0L)];
                auto tmp2 = decltype(tmp1)(2 * tmp1);
                auto tmp3 = max_propagate_nan(tmp0, tmp2);
                out_ptr0[static_cast<long>(i0)] = tmp3;
            }
        }
    }
    ''')

    def call(args):
        arg0_1, arg1_1 = args
        args.clear()
        buf0 = empty_strided((8390, ), (1, ), device='cpu', dtype=torch.int64)
        cpp_fused_cat_maximum_neg_0(c_void_p(arg1_1.data_ptr()), c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
        del arg0_1
        del arg1_1
        return (buf0, )

    # After
    cpp_fused_cat_maximum_neg_0 = async_compile.cpp('''
    #include "/tmp/torchinductor_root/gv/cgv6n5aotqjo5w4vknjibhengeycuattfto532hkxpozszcgxr3x.h"
    extern "C" void kernel(const long* in_ptr0,
                           const long* in_ptr1,
                           long* out_ptr0)
    {
        {
            auto tmp1 = in_ptr1[static_cast<long>(0L)];
            auto tmp2 = decltype(tmp1)(2 * tmp1);
            out_ptr0[static_cast<long>(0L)] = tmp2;
        }
    }
    ''')

    def call(args):
        arg0_1, arg1_1 = args
        args.clear()
        buf0 = empty_strided((1, ), (1, ), device='cpu', dtype=torch.int64)
        cpp_fused_cat_maximum_neg_0(c_void_p(arg1_1.data_ptr()), c_void_p(arg0_1.data_ptr()), c_void_p(buf0.data_ptr()))
        del arg0_1
        del arg1_1
        return (buf0, )

Note that PyTorch also provides a debugging tool called `Minifier <https://pytorch.org/docs/stable/dynamo/troubleshooting.html>`_, which helps us automatically generate a minified problematic graph.
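
As described in the linked troubleshooting documentation, the Minifier is driven by environment variables; for example:

.. code:: shell

    # Minify failures occurring after AOT autograd; level 4 also covers
    # accuracy failures
    TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 python xx.py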

Performance profiling
---------------------

TODO: Haozhe

Future work
--------------

Implement and upstream the following debugging tools:

1. **Cosim**: Merge the graphs of a model into a single large graph, so that graphs can be compared quickly across different versions of PyTorch. `#102958 <https://github.com/pytorch/pytorch/pull/102958>`_
2. **Graph matching**: In order to know what each kernel does, this tool matches the cpp kernel with FX graph operators and adds the corresponding operators before each kernel in the cpp output code. `#102958 <https://github.com/pytorch/pytorch/pull/102958>`_
3. **Save inputs and outputs**: To rapidly reproduce failures of large models, it is necessary to serialize the inputs and outputs passed between graphs as well as the intermediate outputs within graphs.
4. **Test case generation**: When a user has found operators that are inefficient with cpp kernels, a tool is needed to automatically write a test case. Specifically, one test case can be generated for each kernel, with the corresponding small FX graph and input.
5. **Minifier optimization**: Keep refining the Minifier and adapt it to more scenarios.
