forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_higher_order_ops.py
6217 lines (4715 loc) · 222 KB
/
test_higher_order_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: dynamo"]
import enum
import functools
import pprint
import re
import sys
import unittest
import warnings
import functorch.experimental.control_flow as control_flow
import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.config
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import (
CompileCounter,
CompileCounterWithBackend,
EagerAndRecordGraphs,
normalize_gm,
)
from torch._dynamo.utils import counters, ifdynstaticdefault
from torch._higher_order_ops.wrap import wrap
from torch.testing._internal.common_utils import (
munge_exc,
TEST_WITH_TORCHDYNAMO,
xfailIfTorchDynamo,
)
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
def check_dynamic_shape_capture():
# This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
if not config.assume_static_by_default:
return True
return False
def count_ops(gm, args, freq, op):
    """Backend-style helper: assert `op` is the target of exactly `freq` nodes.

    `args` is accepted only to fit the (graph_module, example_inputs) backend
    signature and is not used. Returns `gm` unchanged.
    """
    # Identity-then-equality comparison, the same test list.count() applies.
    occurrences = sum(
        1 for node in gm.graph.nodes if node.target is op or node.target == op
    )
    assert occurrences == freq
    return gm
class Obj:
    """Empty attribute bag; tests set, read and delete attributes on instances."""
class MyModule(nn.Module):
    """Module holding one 0-d parameter `existing` (initialised to 1) that scales its input."""

    def __init__(self):
        super().__init__()
        initial = torch.ones([])
        self.existing = nn.Parameter(initial)

    def forward(self, x):
        scale = self.existing
        return scale * x
# Module-level state read by the tests below; several side-effect tests
# reassign these names in their `setup` callbacks before each run.
global_obj, global_module = Obj(), MyModule()
global_var = torch.randn(3)
global_num = 3.14
global_list = list()
def find_first_node(gm, func):
    """Return the first node of `gm.graph` whose target is `func` (by identity), or None."""
    matches = (node for node in gm.graph.nodes if node.target is func)
    return next(matches, None)
def op_count(gm):
    """Return how many nodes in `gm.graph` have an `op` string containing "call"."""
    return sum("call" in node.op for node in gm.graph.nodes)
def assert_dict_matches_regex(self, dct, dct_with_regex_keys):
    """Assert `dct` equals `dct_with_regex_keys` once regex keys are resolved.

    Each key of `dct_with_regex_keys` is a regular expression that must
    `re.match` exactly one key of `dct`. After replacing every regex key with
    the key it matched, the result must compare equal to `dct` via
    `self.assertEqual`.

    Raises:
        AssertionError: a regex matches several keys, a regex matches no key,
            or the resolved dicts differ.
    """
    resolved = {}
    for pattern in dct_with_regex_keys:
        for candidate in dct:
            if not re.match(pattern, candidate):
                continue
            if pattern in resolved:
                raise AssertionError(
                    f"Single key regex mapped to multiple keys. Please improve your "
                    f"regex. Got: regex='{pattern}' "
                    f"keys='{resolved[pattern]}',"
                    f"'{candidate}'"
                )
            resolved[pattern] = candidate

    expected = {}
    for pattern in dct_with_regex_keys:
        if pattern not in resolved:
            raise AssertionError(
                f"Got regex '{pattern}' but could not match any key in dict with "
                f"keys {dct.keys()}"
            )
        expected[resolved[pattern]] = dct_with_regex_keys[pattern]
    self.assertEqual(dct, expected)
def default_args_generator(seed_value):
    """Yield three perturbed copies of the pytree `seed_value`.

    For run ``step`` in 0, 1, 2, tensor and float leaves are shifted by
    ``0.1 * step`` and int leaves by ``step``. Enum leaves are passed through
    unchanged, so run 0 equals the seed. Any other leaf type raises
    ``AssertionError("unexpected arg type")`` when that run is produced.
    """
    leaves, spec = pytree.tree_flatten(seed_value)

    def shifted(leaf, step):
        # Order matters: bool and IntEnum leaves hit the int branch first.
        if isinstance(leaf, torch.Tensor):
            return leaf + 0.1 * step
        if isinstance(leaf, int):
            return leaf + 1 * step
        if isinstance(leaf, float):
            return leaf + 0.1 * step
        if isinstance(leaf, enum.Enum):
            return leaf
        raise AssertionError("unexpected arg type")

    for step in range(3):
        yield pytree.tree_unflatten([shifted(leaf, step) for leaf in leaves], spec)
class HigherOrderOpTests(torch._dynamo.test_case.TestCase):
def _assert_wrap_fallback(self, func, args, setup=lambda: None):
    """Check that compiling `func` graph-breaks instead of capturing `wrap`.

    `setup` runs before the eager call and again before the compiled call, so
    both start from the same state. Asserts that at least one graph break was
    counted, that no recorded graph contains a node targeting `wrap`, and that
    compiled and eager results are equal.
    """
    counters.clear()
    recorder = EagerAndRecordGraphs()
    counting_backend = CompileCounterWithBackend(recorder)

    setup()
    eager_out = func(*args)

    setup()
    compiled_fn = torch.compile(func, backend=counting_backend, fullgraph=False)
    compiled_out = compiled_fn(*args)

    self.assertGreater(len(counters["graph_break"].keys()), 0)
    for graph_module in recorder.graphs:
        for node in graph_module.graph.nodes:
            self.assertFalse(node.target is wrap)
    self.assertEqual(compiled_out, eager_out)
def _test_wrap_simple(
    self,
    func,
    args_generator,
    expected_num_wrap_args,
    expected_opcount=2,
    return_graph=False,
):
    """Compile `func` (which makes a single `wrap` call) once per args tuple.

    Every run must compile with fullgraph=True into exactly one frame and one
    graph, and must match eager. Only the first run also checks
    `expected_opcount` and the arg count of the first `wrap` node, because
    automatic_dynamic_shapes may compile a dynamic-shape graph on later runs.
    When `return_graph` is true, return the normalized readable source of the
    first run's graph.
    """
    first_graph = None
    for run_index, args in enumerate(args_generator):
        recorder = EagerAndRecordGraphs()
        counter = CompileCounterWithBackend(recorder)
        eager_out = func(*args)
        compiled_out = torch.compile(func, fullgraph=True, backend=counter)(*args)

        # Correctness and "no graph breaks" are required on every run.
        self.assertEqual(compiled_out, eager_out)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(len(recorder.graphs), 1)

        if run_index != 0:
            continue
        # Compilation metrics are only checked on the first run.
        self.assertEqual(counter.op_count, expected_opcount)
        first_graph = recorder.graphs[0]
        wrap_node = find_first_node(first_graph, wrap)
        self.assertEqual(len(wrap_node.args), expected_num_wrap_args)

    if return_graph:
        return normalize_gm(first_graph.print_readable(print_output=False))
def test_error_message_sane(self):
    """Mutating an outer-scope list inside a `wrap` body under fullgraph=True
    must raise Unsupported with a readable side-effect message."""
    foo = []

    def inner(x):
        # `foo` lives outside the wrap body, so this append is the side effect
        # the error message below refers to.
        foo.append(x)
        return x.clone()

    @torch.compile(backend="eager", fullgraph=True)
    def f(x):
        return wrap(inner, x)

    x = torch.randn(3)
    with self.assertRaisesRegex(
        torch._dynamo.exc.Unsupported,
        r"HigherOrderOperator: Mutating a variable not in the current scope \(SideEffects\)",
    ):
        f(x)
def test_no_freevars(self):
    """A wrap body that closes over nothing: expects 2 wrap args (body and x)."""
    def f(x):
        return wrap(lambda x: torch.sin(x), x)

    x = torch.randn(3)
    self._test_wrap_simple(f, default_args_generator((x,)), 2)
def test_enum_arg(self):
    """An enum argument passed through `wrap` selects a branch in the body.

    Only 2 wrap args are expected. Presumably these are the body and `x`, with
    the enum specialized away during tracing (TODO confirm).
    default_args_generator keeps enum values unchanged across runs.
    """
    class SomeEnum(enum.Enum):
        A = 0
        B = 1

    def g(x, val):
        if val == SomeEnum.A:
            return torch.sin(x)
        return torch.cos(x)

    def f(x, val):
        return wrap(g, x, val)

    x = torch.randn(3)
    self._test_wrap_simple(f, default_args_generator((x, SomeEnum.A)), 2)
def test_return_captured_var(self):
    """The wrap body just returns a captured free tensor."""
    freevar = torch.randn(3)

    def test(x):
        return freevar

    def fn(x):
        return wrap(test, x)

    x = torch.randn(3)
    # Since `x` is unused, it is not lifted to be an input; the 2 expected
    # wrap args are presumably the body and the lifted `freevar`.
    self._test_wrap_simple(fn, default_args_generator((x,)), 2)
def test_return_captured_vars(self):
    """The wrap body returns several captured tensors, one of them twice."""
    freevar1 = torch.randn(3)
    freevar2 = torch.randn(3)

    def test(x):
        return freevar1, freevar2, freevar1

    def fn(x):
        return wrap(test, x)

    x = torch.randn(3)
    # Since `x` is unused, it is not lifted to be an input. Expect 3 wrap args
    # (presumably body, freevar1, freevar2: the repeated capture is lifted
    # once) and op count 4 (presumably the wrap plus one getitem per output).
    self._test_wrap_simple(fn, default_args_generator((x,)), 3, 4)
def test_return_captured_var_used_multiple_times(self):
    """A captured tensor is both used in computation and returned directly."""
    freevar = torch.randn(3)

    def test(x):
        y = x + freevar
        return y, freevar

    def fn(x):
        return wrap(test, x)

    x = torch.randn(3)
    # Expect 3 wrap args (presumably body, x, freevar) and op count 3.
    self._test_wrap_simple(fn, default_args_generator((x,)), 3, 3)
def test_capture_untracked_global(self):
    """The wrap body reads the module-level tensor `global_var`; expects 3 wrap
    args (presumably body, x and the lifted global)."""
    def f(x):
        return wrap(lambda x: x + global_var, x)

    x = torch.randn(3)
    self._test_wrap_simple(f, default_args_generator((x,)), 3)
def test_symint_input(self):
    """A size value passed into `wrap` alongside the tensor.

    Expected counts depend on the dynamic-shape config via ifdynstaticdefault:
    2 when shapes are static, 3 when dynamic. Presumably the dynamic case adds
    the SymInt size as an extra wrap arg and an extra `size` op, as in the
    dynamic graph of test_wrap_pytree_args_with_symint_constant.
    """
    def f(x):
        i = x.size(0)
        return wrap(lambda x, i: x.view(i), x, i)

    x = torch.randn(3, 1)
    self._test_wrap_simple(
        f,
        default_args_generator((x,)),
        ifdynstaticdefault(2, 3),
        expected_opcount=ifdynstaticdefault(2, 3),
    )
def test_wrap_pytree_args_nested(self):
    """A nested dict/tuple/list pytree passed to `wrap` is flattened so that
    only the leaves the body reads become graph inputs and wrap args."""
    def f(x, y, z):
        def fn(d):
            return d["x"].sin() + d["y"][0].cos() - d["y"][1][2].sin()

        # NOTE(review): `d` is the dict closed over from the enclosing test,
        # not built from f's parameters. The x/y/z values produced by
        # my_args_generator are therefore ignored (hence the L_d_* placeholders
        # in the expected graph). Fixing this would change the expected graph.
        return wrap(fn, d)

    x = torch.tensor(1.5)
    y = torch.tensor(2.0)
    z = torch.tensor(3.0)
    d = {"x": x, "y": (y, [x, y, z])}

    def my_args_generator(t):
        yield t
        yield t[0] + 0.1, t[1], t[2]
        yield t[0], t[1] + 0.1, t[2]

    actual_graph = self._test_wrap_simple(
        f,
        my_args_generator((x, y, z)),
        4,
        return_graph=True,
    )
    self.assertExpectedInline(
        actual_graph,
        """\
class GraphModule(torch.nn.Module):
def forward(self, L_d_x_: "f32[]", L_d_y_0_: "f32[]", L_d_y_1_2_: "f32[]"):
l_d_x_ = L_d_x_
l_d_y_0_ = L_d_y_0_
l_d_y_1_2_ = L_d_y_1_2_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_d_x_, l_d_y_0_, l_d_y_1_2_); wrap_body_0 = l_d_x_ = l_d_y_0_ = l_d_y_1_2_ = None
getitem: "f32[]" = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, l_d_x_: "f32[]", l_d_y_0_: "f32[]", l_d_y_1_2_: "f32[]"):
sin: "f32[]" = l_d_x_.sin(); l_d_x_ = None
cos: "f32[]" = l_d_y_0_.cos(); l_d_y_0_ = None
add: "f32[]" = sin + cos; sin = cos = None
sin_1: "f32[]" = l_d_y_1_2_.sin(); l_d_y_1_2_ = None
sub: "f32[]" = add - sin_1; add = sin_1 = None
return (sub,)
""", # NOQA: B950
    )
def test_wrap_pytree_args_with_symint_constant(self):
    """A tuple pytree mixing a tensor, a float and a size value is passed to `wrap`.

    With static shapes the size and the float are baked into the body as
    constants (`view(3)`, `+ 0.5`). With dynamic shapes the size becomes a
    SymInt wrap arg. Each config has its own expected graph.
    """
    def f(x, y):
        i = x.size(0)
        return wrap(lambda t: t[0].view(t[2]) + t[1], (x, y, i))

    x = torch.randn(3, 1)
    y = 0.5
    actual_graph = self._test_wrap_simple(
        f,
        default_args_generator((x, y)),
        ifdynstaticdefault(2, 3),
        expected_opcount=ifdynstaticdefault(2, 3),
        return_graph=True,
    )
    if torch._dynamo.config.assume_static_by_default:
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 1]"):
l_x_ = L_x_
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[3, 1]"):
view: "f32[3]" = l_x_.view(3); l_x_ = None
add: "f32[3]" = view + 0.5; view = None
return (add,)
""",
        )
    else:
        self.assertExpectedInline(
            actual_graph,
            """\
class GraphModule(torch.nn.Module):
def forward(self, s0: "Sym(s0)", L_x_: "f32[s0, 1]"):
l_x_ = L_x_
size: "Sym(s0)" = l_x_.size(0)
wrap_body_0 = self.wrap_body_0
wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_, size); wrap_body_0 = l_x_ = size = None
getitem: "f32[s0]" = wrap[0]; wrap = None
return (getitem,)
class GraphModule(torch.nn.Module):
def forward(self, l_x_: "f32[s0, 1]", size: "Sym(s0)"):
view: "f32[s0]" = l_x_.view(size); l_x_ = size = None
add: "f32[s0]" = view + 0.5; view = None
return (add,)
""",
        )
def test_wrap_pytree_kwargs(self):
    """Keyword-only arguments to `wrap`, one of them a tuple pytree `z`, are
    traced without graph breaks; 3 wrap args are expected."""
    def f(x, y, z):
        def fn(*, x, y, z):
            z1, z2 = z
            return (x * 2) + y + z1

        return wrap(fn, x=x, y=y, z=z)

    x = torch.randn(3)
    y = torch.randn(3, 3)

    def my_args_generator(t):
        yield t
        x1 = t[0] + 0.1
        y1 = t[1] + 0.1
        yield (x1, y1, (x1, y1))
        x2 = t[0] + 0.2
        # Derive y from t[1] (the (3, 3) tensor), exactly like y1 above.
        y2 = t[1] + 0.2
        yield (x2, y2, (x2, y2))

    self._test_wrap_simple(f, my_args_generator((x, y, (x, y))), 3)
def test_wrap_pytree_args_not_const_symint_tensor(self):
    """A tuple arg holds a tensor and a plain user object whose `.val` is a
    tensor that the body reads. Expects 3 wrap args; a single run is driven
    from a one-element list."""
    class MyClass:
        def __init__(self, x):
            self.val = x

    def f(x, y):
        return wrap(lambda z: z[0].sin() * z[1].val.cos(), (x, y))

    x = torch.tensor(1.2)
    y = MyClass(torch.tensor(3.4))
    self._test_wrap_simple(f, [(x, y)], 3)
def test_capture_constants(self):
    """A Python float and None pass through `wrap` alongside a tensor; the
    compiled result must match eager, including after the float changes."""
    x = torch.randn(3, 3)

    def fn(x, y, z):
        if z:
            return x + y
        return x * y

    def f(x, y, z):
        return wrap(fn, x, y, z)

    args = (x, 4.0, None)
    opt_f = torch.compile(f, fullgraph=True, backend=CompileCounter())
    expected = f(*args)
    result = opt_f(*args)
    self.assertEqual(result, expected)

    # Call again with a different float constant. The intent is to exercise
    # re-compilation, but only correctness is asserted here.
    args = (x, 5.0, None)
    expected = f(*args)
    result = opt_f(*args)
    self.assertEqual(result, expected)
def test_capture_untracked_global_nested(self):
    """A module-level tensor used inside a nested wrap is threaded through both
    wrap calls, which are captured in one graph without breaks."""
    backend = EagerAndRecordGraphs()
    cnt = CompileCounterWithBackend(backend)

    @torch.compile(backend=cnt, fullgraph=True)
    def f(x):
        return wrap(lambda x: wrap(lambda x: x + global_var, x), x)

    x = torch.randn(3)
    result = f(x)

    self.assertEqual(result, x + global_var)
    self.assertEqual(cnt.frame_count, 1)
    self.assertEqual(cnt.op_count, 2)
    self.assertEqual(len(backend.graphs), 1)

    # Outer wrap: 3 args (the body, x and global_var). assertEqual, not
    # assertTrue, so the count is actually compared.
    wrap_node = find_first_node(backend.graphs[0], wrap)
    self.assertEqual(len(wrap_node.args), 3)

    # Outer body contains the inner wrap (+ its getitem).
    body_function = getattr(backend.graphs[0], wrap_node.args[0].name)
    self.assertEqual(op_count(body_function), 2)
    inner_wrap_node = find_first_node(body_function, wrap)
    self.assertEqual(len(inner_wrap_node.args), 3)
def test_capture_untracked_nonlocal(self):
    """A closure variable of the compiled function (not one of its inputs) is
    used inside the wrap body. Expects 3 wrap args, presumably body, x and y."""
    x = torch.randn(3, 3)
    y = torch.randn(3, 3)

    def f(x, y):
        # `g` takes only `x`; `y` reaches the wrap body through g's closure
        # over f's parameter, so it is not an input of the compiled `g`.
        def g(x):
            return wrap(lambda x: x + y, x)

        self._test_wrap_simple(g, default_args_generator((x,)), 3)
        # NOTE(review): this eager result is returned but discarded by the caller.
        return g(x)

    f(x, y)
def test_capture_tracked(self):
    """The wrap body closes over `y`, which is itself an input of the compiled
    function; expects 3 wrap args (presumably body, x and y)."""
    x = torch.randn(3, 3)
    y = torch.randn(3, 3)

    def f(x, y):
        return wrap(lambda x: x + y, x)

    self._test_wrap_simple(f, default_args_generator((x, y)), 3)
def test_capture_tracked_nested(self):
    """Like test_capture_tracked but the closure over `y` sits in a nested wrap;
    the outer wrap still expects 3 args."""
    x = torch.randn(3, 3)
    y = torch.randn(3, 3)

    def f(x, y):
        return wrap(lambda x: wrap(lambda x: x + y, x), x)

    self._test_wrap_simple(f, default_args_generator((x, y)), 3)
def test_inlined_functions(self):
    """The wrap body calls a helper (`g`) that uses a captured input; the call
    is traced into the body without graph breaks (3 wrap args expected)."""
    def g(x, y):
        return x + y

    def f(x, y):
        return wrap(lambda x: g(x, y), x)

    x = torch.randn(3, 3)
    y = torch.randn(3, 3)
    self._test_wrap_simple(f, default_args_generator((x, y)), 3)
def test_same_freevar_twice(self):
    """The same free tensor is read twice in the wrap body."""
    free = torch.randn(3)

    def g(x):
        y = free.sin()
        z = free.cos()
        return y, z

    def f(x):
        return wrap(g, x)

    x = torch.randn(3)
    # Since `x` is unused, it is not lifted to be an input. Expect 2 wrap args
    # (presumably body and `free`, lifted once) and op count 3.
    self._test_wrap_simple(f, default_args_generator((x,)), 2, 3)
def test_capture_value_created_in_subgraph(self):
    """A tensor `z` computed inside a wrap body is captured by doubly nested
    wraps and must be lifted as an arg at each nesting level."""
    backend = EagerAndRecordGraphs()
    cnt = CompileCounterWithBackend(backend)

    x = torch.randn(3, 3)
    y = torch.randn(3, 3)

    def inner(x, y):
        z = x + y
        return wrap(lambda x: wrap(lambda x: x + z, x), x)

    @torch.compile(backend=cnt, fullgraph=True)
    def f(x, y):
        return wrap(inner, x, y)

    result = f(x, y)

    self.assertEqual(result, x + y + x)
    self.assertEqual(cnt.frame_count, 1)
    self.assertEqual(cnt.op_count, 2)
    self.assertEqual(len(backend.graphs), 1)

    # No changes to args of outer wrap: body, x, y.
    gm = backend.graphs[0]
    wrap_node = find_first_node(gm, wrap)
    self.assertEqual(len(wrap_node.args), 3)

    # Body of the outer wrap (`inner`): addition + wrap + getitem.
    body_function = getattr(gm, wrap_node.args[0].name)
    self.assertEqual(op_count(body_function), 3)
    # z was lifted to an arg of the middle wrap: body, x, z.
    inner_wrap_node = find_first_node(body_function, wrap)
    self.assertEqual(len(inner_wrap_node.args), 3)

    # Body of the middle wrap: the innermost wrap + getitem.
    body_function = getattr(body_function, inner_wrap_node.args[0].name)
    self.assertEqual(op_count(body_function), 2)
    # z was lifted again, to an arg of the innermost wrap: body, x, z.
    inner_wrap_node = find_first_node(body_function, wrap)
    self.assertEqual(len(inner_wrap_node.args), 3)
def test_side_effect_set_new_attr_global_obj(self):
    """Adding a new attribute to a global object inside a wrap body must make
    compilation fall back (graph break, no wrap node) and still match eager."""
    def setup():
        # Fresh global object before both the eager and the compiled run.
        global global_obj
        global_obj = Obj()

    def f(x):
        def h(x):
            def g(x):
                # Side effect: new attribute on a global object.
                global_obj.foo = x + 1
                return x.clone()

            y = wrap(g, x)
            return y + global_obj.foo

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_set_existing_attr_global_obj(self):
    """Overwriting an existing attribute of a global object inside a wrap body
    must make compilation fall back and still match eager."""
    def setup():
        # Fresh global object with a pre-existing `foo` for each run.
        global global_obj
        global_obj = Obj()
        global_obj.foo = nn.Parameter(torch.tensor(4.0))

    def f(x):
        def h(x):
            def g(x):
                # Side effect: rebinding an existing attribute.
                global_obj.foo = x + 1
                return x.clone()

            y = wrap(g, x)
            return y + global_obj.foo

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_del_existing_attr_global_obj(self):
    """Deleting an attribute of a global object inside a wrap body must make
    compilation fall back and still match eager."""
    def setup():
        # Recreate `foo` so each run has something to delete.
        global global_obj
        global_obj = Obj()
        global_obj.foo = torch.tensor(4.0)

    def f(x):
        def h(x):
            def g(x):
                # Side effect: attribute deletion on a global object.
                del global_obj.foo
                return x.clone()

            y = wrap(g, x)
            return y

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_set_new_attr_global_module(self):
    """Registering a new Parameter on a global nn.Module inside a wrap body must
    make compilation fall back and still match eager."""
    def setup():
        # Fresh module (without `foo`) before each run.
        global global_module
        global_module = MyModule()

    def h(x):
        def g(x):
            # Side effect: new Parameter attribute on a global module.
            global_module.foo = nn.Parameter(x + 1)
            return x.clone()

        y = wrap(g, x)
        return y + global_module.foo

    x = torch.zeros([])
    self._assert_wrap_fallback(h, (x,), setup=setup)
def test_side_effect_set_existing_attr_global_module(self):
    """Replacing a global module's existing Parameter inside a wrap body, then
    calling the module, must make compilation fall back and still match eager."""
    def setup():
        # Fresh module with its original `existing` parameter before each run.
        global global_module
        global_module = MyModule()

    def h(x):
        def g(x):
            # Side effect: rebinding `existing`, which forward() then reads.
            global_module.existing = nn.Parameter(torch.tensor(4.0))
            return global_module(x)

        y = wrap(g, x)
        return y

    x = torch.zeros([])
    self._assert_wrap_fallback(h, (x,), setup=setup)
def test_side_effect_del_existing_attr_global_module(self):
    """Deleting a global module's Parameter inside a wrap body must make
    compilation fall back and still match eager."""
    def setup():
        # Fresh module so `existing` exists again before each run.
        global global_module
        global_module = MyModule()

    def h(x):
        def g(x):
            # Side effect: deleting a registered Parameter.
            del global_module.existing
            return x.clone()

        y = wrap(g, x)
        return y

    x = torch.zeros([])
    self._assert_wrap_fallback(h, (x,), setup=setup)
def test_side_effect_mutate_global_num(self):
    """Rebinding a global float (`x = x + 1` form) inside a wrap body must make
    compilation fall back and still match eager."""
    def setup():
        # Reset the global number so eager and compiled runs see the same value.
        global global_num
        global_num = 3.14

    def f(x):
        def g(x):
            global global_num
            # Side effect: rebinding a module-level name.
            global_num = global_num + 1
            return x + global_num

        y = wrap(g, x)
        return y + global_num

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_mutate_global_num_builtin(self):
    """Same as test_side_effect_mutate_global_num but using augmented
    assignment (`+=`)."""
    def setup():
        # Reset the global number so eager and compiled runs see the same value.
        global global_num
        global_num = 3.14

    def f(x):
        def g(x):
            global global_num
            # Side effect: in-place-style rebinding of a module-level float.
            global_num += 1
            return x + global_num

        y = wrap(g, x)
        return y + global_num

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_mutate_global_tensor(self):
    """Rebinding a global tensor (`x = x + 1` form) inside a wrap body must make
    compilation fall back and still match eager."""
    def setup():
        # Reset the global tensor before each run.
        global global_var
        global_var = torch.ones(3)

    def f(x):
        def g(x):
            global global_var
            # Side effect: rebinding a module-level tensor name.
            global_var = global_var + 1
            return x + global_var

        y = wrap(g, x)
        return y + global_var

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_mutate_global_tensor_builtin(self):
    """Same as test_side_effect_mutate_global_tensor but using `+=`.

    For a tensor this is presumably an in-place update of the global (TODO confirm).
    """
    def setup():
        # Reset the global tensor before each run.
        global global_var
        global_var = torch.ones(3)

    def f(x):
        def g(x):
            global global_var
            global_var += 1
            return x + global_var

        y = wrap(g, x)
        return y + global_var

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_mutate_global_list(self):
    """Appending to a global list inside a wrap body must make compilation fall
    back and still match eager."""
    def setup():
        # Empty the global list before each run.
        global global_list
        global_list = []

    def f(x):
        def g(x):
            val = x + 1
            # Side effect: mutation of a module-level list.
            global_list.append(val)
            return global_list[-1]

        y = wrap(g, x)
        z = y + global_list[-1]
        return z

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,), setup=setup)
def test_side_effect_mutate_nonlocal_num(self):
    """Rebinding a nonlocal int inside a wrap body must make compilation fall
    back and still match eager."""
    def f(x):
        def h(x):
            val = 1

            def g(x):
                nonlocal val
                # Side effect: rebinding a variable of the enclosing scope.
                val = val + 1
                return x + val

            y = wrap(g, x)
            z = y + val
            return z

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,))
def test_side_effect_set_new_attr_nonlocal_obj(self):
    """Adding an attribute to an enclosing-scope object inside a wrap body must
    make compilation fall back and still match eager."""
    def f(x):
        def h(x):
            obj = Obj()

            def g(x):
                # Side effect: new attribute on an object from h's scope.
                obj.val = x.dim()
                return x.clone()

            y = wrap(g, x)
            z = y + obj.val
            return z

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,))
def test_side_effect_set_existing_attr_nonlocal_obj(self):
    """Overwriting an existing attribute of an enclosing-scope object inside a
    wrap body must make compilation fall back and still match eager."""
    def f(x):
        def h(x):
            obj = Obj()
            obj.val = 3

            def g(x):
                # Side effect: rebinding an existing attribute.
                obj.val = x.dim()
                return x.clone()

            y = wrap(g, x)
            z = y + obj.val
            return z

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,))
def test_side_effect_del_existing_attr_nonlocal_obj(self):
    """Deleting an attribute of an enclosing-scope object inside a wrap body
    must make compilation fall back and still match eager."""
    def f(x):
        def h(x):
            obj = Obj()
            obj.val = 3

            def g(x):
                # Side effect: attribute deletion.
                del obj.val
                return x.clone()

            y = wrap(g, x)
            return y

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,))
def test_side_effect_set_new_attr_nonlocal_module(self):
    """Adding a plain attribute to an enclosing-scope nn.Module inside a wrap
    body must make compilation fall back and still match eager."""
    def h(x):
        obj = MyModule()

        def g(x):
            # Side effect: new (non-Parameter) attribute on a module.
            obj.val = x.dim()
            return x.clone()

        y = wrap(g, x)
        z = y + obj.val
        return z

    x = torch.zeros([])
    self._assert_wrap_fallback(h, (x,))
def test_side_effect_set_existing_attr_nonlocal_module(self):
    """Replacing an enclosing-scope module's Parameter and then calling the
    module inside a wrap body must make compilation fall back and still match eager."""
    def h(x):
        obj = MyModule()

        def g(x):
            # Side effect: rebinding `existing`, which forward() then reads.
            obj.existing = nn.Parameter(torch.tensor(3.14))
            return obj(x)

        y = wrap(g, x)
        return y

    x = torch.zeros([])
    self._assert_wrap_fallback(h, (x,))
def test_side_effect_del_existing_attr_nonlocal_module(self):
    """Deleting an enclosing-scope module's Parameter inside a wrap body must
    make compilation fall back and still match eager."""
    def h(x):
        obj = MyModule()

        def g(x):
            # Side effect: deleting a registered Parameter.
            del obj.existing
            return x.clone()

        y = wrap(g, x)
        return y

    x = torch.zeros([])
    self._assert_wrap_fallback(h, (x,))
def test_side_effect_mutate_nonlocal_tensor(self):
    """Rebinding a nonlocal tensor (`x = x + 1` form) inside a wrap body must
    make compilation fall back and still match eager."""
    def f(x):
        def h(x):
            val = torch.tensor(1.0)

            def g(x):
                nonlocal val
                # Side effect: rebinding a tensor from the enclosing scope.
                val = val + 1
                return x + val

            y = wrap(g, x)
            z = y + val
            return z

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,))
def test_side_effect_mutate_nonlocal_num_builtin(self):
    """Same as test_side_effect_mutate_nonlocal_num but using `+=`."""
    def f(x):
        def h(x):
            val = 1

            def g(x):
                nonlocal val
                val += 1
                return x + val

            y = wrap(g, x)
            z = y + val
            return z

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,))
def test_side_effect_mutate_nonlocal_tensor_builtin(self):
    """Same as test_side_effect_mutate_nonlocal_tensor but using `+=`."""
    def f(x):
        def h(x):
            val = torch.tensor(1.0)

            def g(x):
                nonlocal val
                val += 1
                return x + val

            y = wrap(g, x)
            z = y + val
            return z

        return h(x)

    x = torch.zeros([])
    self._assert_wrap_fallback(f, (x,))
def test_side_effect_nonlocal_list_append_graph_break(self):
    """Appending to a list from the compiled function's own scope inside a wrap
    body must make compilation fall back and still match eager."""
    def g(x):
        y = []

        def f(k):
            m = k + 1
            # Side effect: the list lives outside the wrap body.
            y.append(m)
            return k

        # The wrap output is discarded; the result comes from the side effect.
        wrap(f, x)
        return y[0]

    x = torch.randn(3, 3)
    self._assert_wrap_fallback(g, (x,))
def test_side_effect_nested_nonlocal_list_append_graph_break(self):
    """Like test_side_effect_nonlocal_list_append_graph_break, with the list
    owned by a function nested inside the compiled one."""
    def g(x):
        def h(x):
            y = []

            def f(k):
                m = k + 1
                # Side effect: the list lives outside the wrap body.
                y.append(m)
                return k

            wrap(f, x)
            return y[0]

        return h(x)

    x = torch.randn(3, 3)
    self._assert_wrap_fallback(g, (x,))
def test_side_effect_local_list_append_no_graph_break(self):
    """Mutating a list created inside the wrap body itself is not an outer side
    effect, so the wrap is captured without graph breaks (2 wrap args)."""
    def g(x):
        def f(k):
            # `y` is local to the body, unlike the *_graph_break tests above.
            y = []
            y.append(k + 1)
            return y[0]

        return wrap(f, x)

    x = torch.randn(3, 3)
    self._test_wrap_simple(g, default_args_generator((x,)), 2)
def test_wrap_kwarg(self):
def f(x, y):
return wrap(lambda x, y: x + y, x, y=y)