forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_multi_kernel.py
285 lines (230 loc) · 9.67 KB
/
test_multi_kernel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# Owner(s): ["module: inductor"]
import os
import re
import unittest
import torch
from torch import nn
from torch._dynamo.testing import reset_rng_state
from torch._inductor import config, test_operators
from torch._inductor.codegen.multi_kernel import MultiKernelCall
from torch._inductor.test_case import TestCase
from torch._inductor.utils import run_and_get_code
from torch.nn import functional as F
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
class TransformerSnippet(nn.Module):
    """
    Tiny module mixing dropout with two ``LayerNorm(64)`` layers.

    ``forward`` computes ``ln2(dropout(x1) + dropout(ln1(x2)))`` with a
    dropout probability of 0.1 for both dropout calls.
    """

    def __init__(self):
        super().__init__()
        self.ln1 = nn.LayerNorm(64)
        self.ln2 = nn.LayerNorm(64)

    def forward(self, x1, x2):
        """Return the layer-normalized sum of the two dropped-out branches."""
        # Dropout on x1 is issued before the x2 branch so that random numbers
        # are consumed in the same order on every call.
        dropped_x1 = F.dropout(x1, 0.1)
        normed_x2 = self.ln1(x2)
        dropped_x2 = F.dropout(normed_x2, 0.1)
        return self.ln2(dropped_x1 + dropped_x2)

    def example_inputs(self):
        """Return a pair of random ``(2, 64)`` tensors moved to CUDA."""
        first = torch.randn(2, 64).cuda()
        second = torch.randn(2, 64).cuda()
        return (first, second)
def _contains_multi_kernel_code(wrapper_code: str):
return (
re.search(r"multi_kernel_[^ ]* = async_compile.multi_kernel[(]", wrapper_code)
is not None
)
def make_cpp_wrapper_test(orig_test, **extra_args):
"""
Wrap an existing test into a new test with cpp-wrapper enabled.
Make this as a free function rather than staticmethod in MultiKernelTest.
Otherwise we get 'TypeError: 'staticmethod' object is not callable'
error in py3.8. (py3.10 works)
"""
@config.patch("cpp_wrapper", True)
def fn(self):
# The same kernel may have been compiled by previous tests with
# cpp_wrapper disabled. Clear the cache so we go ahead to re-compile
# the kernel with cpp_wrapper enabled.
from torch._inductor import codecache
codecache.PyCodeCache.cache_clear()
return orig_test(self, **extra_args)
return fn
@config.patch(
    {
        "triton.multi_kernel": int(os.environ.get("TORCHINDUCTOR_MULTI_KERNEL", "1")),
        "benchmark_kernel": True,
    }
)
@instantiate_parametrized_tests
class MultiKernelTest(TestCase):
    """
    Tests for inductor's multi-kernel code generation.

    Every test runs with ``triton.multi_kernel`` taken from the
    TORCHINDUCTOR_MULTI_KERNEL environment variable (default 1) and with
    ``benchmark_kernel`` enabled. All tests allocate CUDA tensors.
    """

    def test_softmax(self, expect_multi_kernel=True):
        """
        Compile softmax, compare with eager, and inspect the wrapper code.

        Args:
            expect_multi_kernel: whether the final wrapper code is expected
                to contain a multi-kernel definition.
        """
        x = torch.rand(2, 1024).cuda()
        ref = torch.softmax(x, -1)
        compiled_fn = torch.compile(torch.softmax)
        act, wrapper_code = run_and_get_code(compiled_fn, x, -1)

        # wrapper_code will contain 2 entries if cpp_wrapper=True.
        # One for the first pass and one for the second pass.
        # We mainly care about the wrapper for the final pass here.
        wrapper_code = wrapper_code[-1]
        self.assertTrue(torch.allclose(ref, act))
        if expect_multi_kernel:
            self.assertTrue(_contains_multi_kernel_code(wrapper_code))
        else:
            # Skip verifying the wrapper_code in fbcode since we may fail
            # compiling the cpp wrapper cuda code due to lacking proper setup of
            # cuda compiler in fbcode environment. In that case, the last
            # collected wrapper_code will correspond to the first pass
            # cpp-wrapper codegen which contains the multi-kernel.
            if not config.is_fbcode():
                self.assertFalse(_contains_multi_kernel_code(wrapper_code))

    @parametrize("force_kernel", (0, 1))
    def test_softmax_force_non_persistent_reduction(self, force_kernel):
        """
        Force a specific sub-kernel to be picked by mocking the benchmark result.

        Args:
            force_kernel: index (0 or 1) of the sub-kernel that is given the
                lowest mocked latency and is therefore expected to be picked.
        """
        # Import the submodule explicitly: a plain ``import unittest`` does
        # not guarantee that ``unittest.mock`` has been loaded.
        from unittest import mock

        x = torch.rand(2, 1024).cuda()

        mock_latency = [0.2, 0.2]
        mock_latency[force_kernel] = 0.1  # this makes sure force_kernel will be picked

        def f(x):
            # NOTE(review): adding force_kernel presumably keeps the two
            # parametrizations from compiling the same graph -- confirm.
            return torch.softmax(x, -1) + force_kernel

        orig_run = MultiKernelCall.run_with_argless_kernels
        picked_kernel = None

        def mock_run(self, kernel_calls):
            # Delegate to the real implementation, then record which
            # sub-kernel the MultiKernelCall instance ended up picking.
            nonlocal picked_kernel
            out = orig_run(self, kernel_calls)
            picked_kernel = self.picked_kernel
            return out

        with mock.patch.dict(
            os.environ, {"TORCHINDUCTOR_DISABLE_MULTI_KERNEL_CACHE": "1"}
        ), mock.patch.object(
            MultiKernelCall, "run_with_argless_kernels", mock_run
        ), mock.patch.object(
            MultiKernelCall, "benchmark_sub_kernels", lambda *args: mock_latency
        ):
            torch.compile(f)(x)
        self.assertEqual(picked_kernel, force_kernel)

    @config.patch("warn_mix_layout", True)
    def test_softmax_warn_mixed_layout(self):
        """Run test_softmax with the ``warn_mix_layout`` config enabled."""
        self.test_softmax()

    # Same as test_softmax but with cpp-wrapper enabled; the final wrapper
    # code is then expected to contain no multi-kernel definition.
    test_softmax_cpp_wrapper = skipIfRocm(
        make_cpp_wrapper_test(test_softmax, expect_multi_kernel=False)
    )

    def test_layernorm(self):
        """Compiled LayerNorm(1024) matches eager within 1e-4 tolerances."""
        ln = nn.LayerNorm(1024).cuda()
        x = torch.rand(2, 1024).cuda()
        ref = ln(x)
        act = torch.compile(ln)(x)
        self.assertTrue(
            torch.allclose(ref, act, atol=1e-4, rtol=1e-4), f"ref:\n{ref}\nact:\n{act}"
        )

    def test_inplace_update(self):
        """
        Inductor generates an inplace kernel for mul.
        """

        def f(x, y):
            return x.sum(dim=-1, keepdims=True) * (y @ y)

        x = torch.rand(1024, 1024).cuda()
        y = torch.rand(1024, 1024).cuda()
        ref = f(x, y)
        act = torch.compile(f)(x, y)
        self.assertTrue(torch.allclose(ref, act))

    def test_transformer_snippet(self):
        """
        Run TransformerSnippet eagerly and compiled from the same RNG state.

        The outputs are only compared when ``config.fallback_random`` is set.
        """
        model = TransformerSnippet().cuda()
        x = model.example_inputs()

        def f(*x):
            y = model(*x)
            return y

        reset_rng_state()
        ref = f(*x)

        opt_f = torch.compile(f)
        reset_rng_state()
        act = opt_f(*x)

        # don't compare tensors if using the inductor random number generator.
        # inductor's random number implementation is different to eager.
        # We should fall back to eager if we want to test accuracy.
        if config.fallback_random:
            self.assertTrue(
                torch.allclose(ref, act, atol=1e-4, rtol=1e-4),
                f"ref:\n{ref}\nact:\n{act}",
            )

    def test_transformer_snippet_with_fallback_random(self):
        """
        Same as test_transformer_snippet but fall back the random number
        generator to eager so we can check accuracy.
        """
        with config.patch("fallback_random", True):
            self.test_transformer_snippet()

    def test_batchnorm_training(self):
        """
        For training, batchnorm will track running mean/variance during the forward pass.
        The kernel generated by inductor currently will pass in those tensors twice as arguments:
        once for input and once for output. They are ruled out as in-out arguments because
        they are considered as graph inputs.

        Multi-kernel previously assumed that we never pass the same argument multiple times
        for a kernel. No matter if we change inductor behavior to assure that, it's better
        to make multi-kernel able to handle those cases.
        """
        bn = nn.BatchNorm2d(3).to("cuda")

        @torch.compile
        def f(x):
            bn(x).sum().backward()

        # Two wrapper sources are unpacked; only the first one is checked.
        _, (wrapper_code, _) = run_and_get_code(
            f, torch.randn(2, 3, 8, 8, device="cuda")
        )
        self.assertTrue(_contains_multi_kernel_code(wrapper_code))

    def test_pass_same_arg_multi_times(self):
        """
        A super simple example that simulates how BatchNorm updates the running
        stats.

        Inductor currently passes the same tensor multiple times for the generated
        kernel: once for input and once for output.

        Here is a paste of the generated kernel (without multi-kernel enabled):
        https://gist.github.com/shunting314/f0b446b4b9a28f4940e31dcd3e809cf9
        """

        def f(x, y):
            x = x.sum(dim=1, keepdim=False)
            y.copy_(y * 0.9 + x * 0.1)

        x = torch.randn(8, 16, device="cuda")
        y = torch.randn(8, device="cuda")
        y_ref = y.clone()

        # f returns None and updates its second argument in place, so the
        # eager and compiled results are compared through y_ref and y.
        f(x, y_ref)
        torch.compile(f)(x, y)
        self.assertTrue(torch.allclose(y_ref, y))

    def test_reduction_scratch_buffer(self, force_multi_kernel=1):
        """
        The explicitly realized buffer in the test function will be passed in
        as a scratch buffer for the non-persistent reduction kernel but
        can be skipped for the persistent reduction kernel.

        This causes different argument lists for the non-persistent reduction kernel and
        the persistent reduction kernel.

        Check documentation around torch._inductor.config.triton.multi_kernel about
        how to interpret the force_multi_kernel argument.
        """

        def f(x):
            x = x.sum(dim=-1, keepdim=True) + x
            x = test_operators.realize(x)
            x = x.sum(dim=-1, keepdim=True) + x
            return x

        x = torch.rand(16, 16, device="cuda")
        ref = f(x)
        with config.patch("triton.multi_kernel", force_multi_kernel):
            act = torch.compile(f)(x)
        self.assertTrue(torch.allclose(ref, act))

    # Use benchmarking to pick the faster kernel
    test_reduction_scratch_buffer_cpp_wrapper = make_cpp_wrapper_test(
        test_reduction_scratch_buffer, force_multi_kernel=1
    )
    # force pick persistent reduction. This can be a good test since this persistent
    # reduction uses fewer call arguments than the corresponding non-persistent
    # reduction.
    test_reduction_scratch_buffer_cpp_wrapper_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=2)
    )
    # force pick non-persistent reduction
    test_reduction_scratch_buffer_cpp_wrapper_non_persistent_reduction = (
        make_cpp_wrapper_test(test_reduction_scratch_buffer, force_multi_kernel=3)
    )
if __name__ == "__main__":
    # Script entry point: the runner is imported only when this file is
    # executed directly.
    from torch._inductor.test_case import run_tests

    # The tests in this file allocate CUDA tensors, so nothing is run when
    # HAS_CUDA is falsy.
    if HAS_CUDA:
        run_tests()