forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_decorators.py
472 lines (361 loc) · 14.1 KB
/
test_decorators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
# Owner(s): ["module: dynamo"]
import functools
import os
import unittest.mock as mock
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.exc import IncorrectUsage
def my_custom_function(x):
    """Return ``x + 1``.

    A plain (non-torch) Python function; ``test_allow_in_graph`` registers it
    with ``torch._dynamo.allow_in_graph`` and calls it from a compiled function.
    """
    result = x + 1
    return result
class DecoratorTests(torch._dynamo.test_case.TestCase):
    """Tests for Dynamo's decorator-style and marker APIs.

    Covers ``torch._dynamo.disable``, ``allow_in_graph`` / ``disallow_in_graph``,
    ``graph_break``, ``optimize(..., disable=True)``, ``mark_static_address``
    and compiling calls to class/static/regular methods.
    """

    def test_disallow_in_graph(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.add(x, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            x = torch.add(x, 1)
            return x

        # disallow_in_graph mutates global Dynamo state; restore it even if the
        # compiled call fails so that other tests are not affected.
        torch._dynamo.disallow_in_graph(torch.sub)
        try:
            fn(torch.randn(10))
        finally:
            torch._dynamo.allow_in_graph(torch.sub)

        # check for graph break on sub: the four adds are split across two frames
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

    def test_disable_for_custom_op(self):
        import torch.library
        from torch.library import Library

        foo = Library("foo", "DEF")  # noqa: TOR901
        foo.define("custom(Tensor self) -> Tensor")

        # Dynamic-shape, data-dependent operator. For static-shape compilation
        # Dynamo should graph break on it, but its meta kernel is not
        # implemented properly.
        @torch.library.impl(foo, "custom", "CPU")
        def foo_cpu(x):
            return x.nonzero()

        # disallow_in_graph does not work here because of the extra Python
        # frames added by the torch.library Python API, so disable the op instead.
        torch.ops.foo.custom = torch._dynamo.disable(torch.ops.foo.custom)

        def fn(x):
            a = torch.nn.functional.relu(x)
            b = torch.ops.foo.custom(a)
            c = torch.cos(b)
            return c

        x = torch.randint(2, (100,))
        ref = fn(x)

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        # The disabled custom op graph-breaks, so fn is compiled as two frames.
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(ref, res)

    def test_disable_ignores_outer_wraps(self):
        def orig_inner():
            pass

        def inner():
            pass

        inner._torchdynamo_orig_callable = orig_inner

        # functools.wraps copies inner's __dict__ (including
        # _torchdynamo_orig_callable) onto wrapper.
        @functools.wraps(inner)
        def wrapper():
            raise AssertionError("wrapper called")

        # This behavior is not ideal, but supporting it would add overhead
        # to callsites of eval_frame.innermost_fn. A warning would also be very noisy.
        # The test only checks that disable() accepts the wrapper without raising.
        w = torch._dynamo.disable(fn=wrapper, recursive=True)

    def test_disable_nn_modules_forward_hook(self):
        class SimpleLinear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer0 = torch.nn.Linear(4, 4)

            def forward(self, inp):
                return self.layer0(torch.sigmoid(inp))

        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer0 = SimpleLinear()
                self.layer1 = torch.nn.Linear(4, 4)

            def forward(self, inp):
                z = self.layer0(torch.sin(inp))
                return self.layer1(z)

        def hook(module, args):
            inp = args[0].sigmoid()
            return (inp,)

        model = SimpleModel()
        model.layer0.register_forward_pre_hook(hook)

        # Disable by monkeypatching
        model.layer0 = torch._dynamo.disable(model.layer0)

        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")
        opt_model = torch.compile(model, backend=cnts)
        opt_model(torch.randn(4))

        # The disabled layer0 graph-breaks, giving one graph before it and one after.
        self.assertEqual(cnts.frame_count, 2)

        gm0 = cnts.graphs[0]
        # Check that the first graph has sin node, and no sigmoid
        self.assertTrue(any(node.target is torch.sin for node in gm0.graph.nodes))
        self.assertTrue(
            all(node.target is not torch.sigmoid for node in gm0.graph.nodes)
        )

        gm1 = cnts.graphs[1]
        # Check that the second graph does not have sigmoid. sigmoid is used in
        # both hook and disabled module.
        self.assertTrue(
            all(node.target is not torch.sigmoid for node in gm1.graph.nodes)
        )

    def test_disable_nn_module_with_class_decorator(self):
        cnts = torch._dynamo.testing.CompileCounterWithBackend("eager")

        @torch._dynamo.disable
        class SimpleLinear(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer0 = torch.nn.Linear(4, 4)

            def forward(self, inp):
                return self.layer0(torch.sigmoid(inp))

        @torch.compile(backend=cnts)
        class SimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.layer0 = SimpleLinear()
                self.layer1 = torch.nn.Linear(4, 4)

            def forward(self, inp):
                z = self.layer0(torch.sin(inp))
                return self.layer1(z)

        def hook(module, args):
            inp = args[0].sigmoid()
            return (inp,)

        model = SimpleModel()
        model.layer0.register_forward_pre_hook(hook)

        model(torch.randn(4))

        # The disabled layer0 graph-breaks, giving one graph before it and one after.
        self.assertEqual(cnts.frame_count, 2)

        gm0 = cnts.graphs[0]
        # Check that the first graph has sin node, and no sigmoid
        self.assertTrue(any(node.target is torch.sin for node in gm0.graph.nodes))
        self.assertTrue(
            all(node.target is not torch.sigmoid for node in gm0.graph.nodes)
        )

        gm1 = cnts.graphs[1]
        # Check that the second graph does not have sigmoid. sigmoid is used in
        # both hook and disabled module.
        self.assertTrue(
            all(node.target is not torch.sigmoid for node in gm1.graph.nodes)
        )

    def test_allow_in_graph(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.add(x, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            x = torch.add(x, 1)
            return x

        # allow_in_graph mutates global Dynamo state; undo it even if the
        # compiled call fails so that other tests are not affected.
        torch._dynamo.allow_in_graph(my_custom_function)
        try:
            fn(torch.randn(10))
        finally:
            torch._dynamo.disallow_in_graph(my_custom_function)

        # check for no graph break: the custom function is one node in the graph
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 5)

    def test_incorrect_usage_disallow_in_graph(self):
        with self.assertRaises(IncorrectUsage):

            @torch._dynamo.disallow_in_graph
            def fn1(x):
                return x.cos()

    def test_graph_break(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnts)
        def fn(x):
            x = torch.cos(x)
            x = torch.cos(x)
            torch._dynamo.graph_break()
            x = torch.cos(x)
            x = torch.cos(x)
            torch._dynamo.graph_break()
            x = torch.cos(x)
            x = torch.cos(x)
            return x

        fn(torch.randn(4, 5))
        # Two explicit graph breaks -> three frames of two cos ops each.
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 6)

    def test_skip(self):
        def fn2(x):
            return x.sin()

        # recursive=False: only fn1's own frame is skipped; fn2, called from it,
        # may still be compiled.
        @torch._dynamo.disable(recursive=False)
        def fn1(x):
            x = x.sigmoid()
            return fn2(x.cos())

        def fn(x):
            return fn1(x.tan())

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        opt_fn(torch.randn(4))
        self.assertEqual(cnts.frame_count, 2)

    @patch.object(torch._dynamo.config, "suppress_errors", True)
    def test_nested_disable_decorator(self):
        cnts = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.disable()
        def fn1(x):
            return torch.sin(x) * 10

        @torch._dynamo.optimize(cnts)
        def fn2(x):
            x = x + 1
            x = x + 1
            x = fn1(x)  # graph break
            x = x + 1
            x = x + 1
            return x

        @torch._dynamo.optimize(cnts, nopython=True)
        def fn3(x):
            return fn2(x)

        fn2(torch.randn(4, 5))
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

        # Under nopython=True the graph break on the disabled fn1 must surface
        # as Unsupported, even with suppress_errors patched on for this test.
        with self.assertRaises(torch._dynamo.exc.Unsupported) as cm:
            fn3(torch.randn(4, 5))
        self.assertIn(
            "call torch._dynamo.disable() wrapped function", str(cm.exception)
        )

    def test_disable_optimize(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt, disable=True)
        def f1(x):
            return x + 1

        f1(torch.ones(6))
        self.assertEqual(cnt.frame_count, 0)

        @torch._dynamo.optimize(cnt, disable=True)
        def f2(x):
            return x + 1

        f2(torch.ones(6))
        self.assertEqual(cnt.frame_count, 0)

        # Disabling through the environment variable must also prevent compilation.
        with patch.dict(os.environ, {"TORCHDYNAMO_DISABLE": "1"}):

            @torch._dynamo.optimize(cnt)
            def f3(x):
                return x + 1

            f3(torch.ones(6))
        self.assertEqual(cnt.frame_count, 0)

    def test_torch_guards_stack_frame_register_inlining_disable(self):
        y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
        x = torch.tensor([0.5, 0.5])

        class encoder(torch.nn.Module):
            def __init__(self, y):
                super().__init__()
                self.register_parameter("param", y)

            @torch._dynamo.disable
            def helper(self, x, y):
                return x * y

            def forward(self, a, *args):
                x = a + a
                return self.helper(x, self.param)

        e = encoder(y)

        seen_frames = []
        import contextlib

        # Stand-in for TracingContext.current_frame that records every
        # non-None frame summary pushed while tracing.
        @contextlib.contextmanager
        def global_context_capture_fn(frame_summary):
            if frame_summary is not None:
                seen_frames.append(frame_summary)
            yield

        with mock.patch(
            "torch._guards.TracingContext.current_frame",
            side_effect=global_context_capture_fn,
        ):
            torch._dynamo.optimize("eager")(e)(x)

        # helper is disabled, so no inlined frame should have been registered.
        self.assertEqual(len(seen_frames), 0)

    def test_torch_guards_stack_frame_register_inlining_partially_disable(self):
        y = torch.nn.Parameter(torch.tensor([0.25, 0.25]))
        x = torch.tensor([0.5, 0.5])

        class encoder(torch.nn.Module):
            def __init__(self, y):
                super().__init__()
                self.register_parameter("param", y)

            @torch._dynamo.disable
            def helper_disabled(self, x, y):
                return x.sin() * y.cos()

            def helper(self, x, y):
                return x * y

            def forward(self, a, *args):
                x = a + a
                return self.helper(x, self.param) + self.helper_disabled(x, self.param)

        e = encoder(y)

        cnt = torch._dynamo.testing.CompileCounter()
        torch.compile(e, backend=cnt)(x)

        # first frame is before disable, second frame is after disable
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 3)

    def _test_mark_static_address(self, guarded):
        """Check compile counts after ``torch._dynamo.mark_static_address``.

        The first input is marked static (with ``guard=guarded``) and the
        resulting graph is expected to carry buffers. A second, unmarked input
        of the same shape must trigger a recompile only when ``guarded`` is True,
        and that recompile must not produce a graph with buffers.
        """
        compiles_with_buffers = 0
        compiles = 0

        def debug_compiler(gm, _):
            nonlocal compiles_with_buffers
            nonlocal compiles
            compiles_with_buffers += len(gm._buffers) > 0
            compiles += 1
            return gm

        @torch._dynamo.optimize(backend=debug_compiler)
        def fn(x):
            return x + 1

        inp = torch.ones(2)
        torch._dynamo.mark_static_address(inp, guard=guarded)

        fn(inp)
        self.assertEqual(compiles_with_buffers, 1)

        inp2 = torch.ones(2)

        # if guarded, should trigger another recompile
        # since it was not marked static, compiles with buffers
        # should not be incremented
        fn(inp2)
        self.assertEqual(compiles_with_buffers, 1)

        self.assertEqual(compiles, 2 if guarded else 1)

    def test_mark_static_address_guarded(self):
        self._test_mark_static_address(guarded=True)

    def test_mark_static_address_unguarded(self):
        self._test_mark_static_address(guarded=False)

    def test_class_methods(self):
        class A:
            @classmethod
            def my_class_method(cls, arg1):
                return cls, arg1

            @staticmethod
            def my_static_method(arg1):
                return None, arg1

            def my_regular_method(self, arg1):
                return self, arg1

        class B(A):
            def my_class_method(self, arg1):
                return super().my_class_method(arg1)

            def my_static_method(self, arg1):
                return super().my_static_method(arg1)

        class C(A):
            @classmethod
            def my_class_method(cls, arg1):
                return super().my_class_method(arg1)

        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt)
        def fn(a, b, c):
            # We want a function that does not graph break but
            # does generate custom bytecode
            v1 = a.my_class_method(1)
            v2 = A.my_class_method(2)
            v3 = a.my_static_method(3)
            v4 = A.my_static_method(4)
            v5 = a.my_regular_method(5)
            v6 = b.my_class_method(6)
            v7 = b.my_static_method(7)
            v8 = c.my_class_method(8)
            v9 = C.my_class_method(9)
            torch.rand(2)
            return v1, v2, v3, v4, v5, v6, v7, v8, v9

        a, b, c = A(), B(), C()
        v1, v2, v3, v4, v5, v6, v7, v8, v9 = fn(a, b, c)

        self.assertEqual(v1, (A, 1))
        self.assertEqual(v2, (A, 2))
        self.assertEqual(v3, (None, 3))
        self.assertEqual(v4, (None, 4))
        self.assertEqual(v5, (a, 5))
        # TODO fix me: we do not resolve classmethods properly
        # from a regular method
        # self.assertEqual(v6, (B, 6))
        self.assertEqual(v7, (None, 7))
        self.assertEqual(v8, (C, 8))
        self.assertEqual(v9, (C, 9))

        self.assertEqual(cnt.frame_count, 1)
if __name__ == "__main__":
    # Run this module's tests through the Dynamo test-case runner
    # (torch._dynamo.test_case is already imported at module level).
    torch._dynamo.test_case.run_tests()