forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_aotdispatch.py
5825 lines (5047 loc) · 225 KB
/
test_aotdispatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["oncall: pt2"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import itertools
import unittest
import warnings
from contextlib import nullcontext
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from unittest.mock import patch
from common_utils import decorate, decorateForModules, skip, skipOps, xfail
import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
import torch.utils._pytree as pytree
from functorch import grad, jacrev, make_fx, vjp, vmap
from functorch.compile import (
aot_function,
aot_module,
compiled_function,
compiled_module,
default_decompositions,
default_partition,
get_aot_compilation_context,
make_boxed_compiler,
memory_efficient_fusion,
min_cut_rematerialization_partition,
nnc_jit,
nop,
)
from functorch.experimental import control_flow
from torch._decomp import decomposition_table
from torch._functorch.aot_autograd import (
aot_export_joint_simple,
aot_export_module,
aot_module_simplified,
)
from torch._higher_order_ops.out_dtype import out_dtype
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
from torch.nn.utils.rnn import PackedSequence
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
tol,
toleranceOverride,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
compare_equal_outs_and_grads,
instantiate_parametrized_tests,
IS_ARM64,
IS_MACOS,
IS_WINDOWS,
IS_X86,
outs_and_grads,
parametrize,
run_tests,
skipIfRocm,
skipIfTorchDynamo,
TestCase,
xfailIfTorchDynamo,
)
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.optests import (
_test_aot_autograd_forwards_backwards_helper,
aot_autograd_check,
)
from torch.testing._internal.two_tensor import TwoTensor, TwoTensorMode
# torchvision is optional: tests that need it check USE_TORCHVISION and skip.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn(
        "Couldn't import torchvision. Some of our tests use it, try "
        "to install it with commands from pytorch.org, post-fixed with "
        "`--no-deps` to avoid overwriting the pytorch installation",
        UserWarning,
    )
else:
    USE_TORCHVISION = True
# networkx is optional; USE_NETWORKX records whether it could be imported.
try:
    import networkx  # noqa: F401
except ImportError:
    USE_NETWORKX = False
    warnings.warn("Some tests use networkx but it was not installed", UserWarning)
else:
    USE_NETWORKX = True
# NB: numpy is a testing dependency!
class AOTTestCase(TestCase):
pass
class TestPythonKey(AOTTestCase):
    """Tests for ``make_fx`` tracing (alone and composed with functorch
    transforms) and for the ``nnc_jit`` compiler.

    Every test takes a ``device`` argument; most of them build CPU tensors
    regardless (presumably the class is instantiated per device elsewhere in
    the file -- TODO confirm).
    """

    def test_make_fx(self, device):
        def f(x):
            return torch.sin(x)

        inp = torch.randn(3)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_grad(self, device):
        def f(x):
            return torch.sin(x).sum()

        inp = torch.randn(3)
        f = grad(f)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_scalar_device(self, device):
        def f(a, b):
            return a + b

        # A tensor on the test device combined with a (CPU) 0-dim tensor.
        inps = [torch.randn(3, device=device), torch.tensor(5)]
        fx_f = make_fx(f)(*inps)
        self.assertEqual(fx_f(*inps), f(*inps))

    def test_make_fx_vmap(self, device):
        def f(x):
            return torch.sin(x)

        inp = torch.randn(5, 3)
        f = vmap(f)
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(5, 3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_jacrev(self, device):
        def f(x):
            return x.sin().sum()

        inp = torch.randn(3)
        f = jacrev(jacrev(f))
        fx_f = make_fx(f)(inp)
        new_inp = torch.randn(3)
        self.assertEqual(fx_f(new_inp), f(new_inp))

    def test_make_fx_vjp(self, device):
        def f(x):
            return torch.sin(x).sum()

        primals = torch.randn(3)
        _, vjp_fn = vjp(f, primals)
        cotangent = torch.randn(())
        # NOTE(review): the two trailing ``True`` arguments are passed straight
        # through to vjp_fn when tracing (presumably retain_graph/create_graph
        # -- TODO confirm); the eager comparison calls vjp_fn without them.
        fx_f = make_fx(vjp_fn)(cotangent, True, True)
        new_cotangent = torch.randn(())
        self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))

    def test_make_fx_functionalize(self, device):
        from functorch.experimental import functionalize

        def fn(a):
            a = a * 2
            a.relu_()
            return a

        a = torch.randn(3, device=device)
        # Plain FX symbolic tracing keeps the in-place method call "relu_".
        symbolic_gm = torch.fx.symbolic_trace(fn)
        includes_method_relu_ = any(
            str(n.target) == "relu_" for n in symbolic_gm.graph.nodes
        )
        self.assertTrue(includes_method_relu_)
        # functionalize + make_fx must turn it into the out-of-place aten.relu.
        # Also verifies fix for https://github.com/pytorch/pytorch/issues/84570
        gm = make_fx(functionalize(symbolic_gm))(a)
        includes_aten_relu = any(
            n.target == torch.ops.aten.relu.default for n in gm.graph.nodes
        )
        self.assertTrue(includes_aten_relu)

    # FIXME: re-enable once the recursion error is resolved.
    @unittest.skip("error: maximum recursion reached")
    def test_make_fx_no_decompose(self, device):
        def f(x):
            return torch.tanh(x).sum()

        # Named `targets` (not `ops`) so the module-level `ops` decorator
        # import is not shadowed.
        # NOTE(review): node targets are usually OpOverloads (e.g.
        # `aten.relu.default`, as in test_make_fx_functionalize), whereas this
        # compares against the `tanh_backward` overload packet -- confirm
        # before re-enabling.
        fx_f = make_fx(grad(f))(torch.randn(5))
        targets = {node.target for node in fx_f.graph.nodes}
        self.assertEqual(torch.ops.aten.tanh_backward in targets, True)

        fx_f = make_fx(grad(f), decomposition_table)(torch.randn(5))
        targets = {node.target for node in fx_f.graph.nodes}
        self.assertEqual(torch.ops.aten.tanh_backward in targets, False)

    def test_nnc_jit(self, device):
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)
        inp = torch.randn(3)
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_scalar(self, device):
        def f(x):
            return torch.sin(x)

        jit_f = nnc_jit(f)
        inp = torch.randn(())
        self.assertEqual(jit_f(inp), f(inp))

    def test_nnc_pytrees(self, device):
        def f(x):
            return [torch.sin(x[0])]

        jit_f = nnc_jit(f)
        inp = [torch.randn(3)]
        self.assertEqual(jit_f(inp), f(inp))

    def test_external_calls(self, device):
        def f(a, b):
            return torch.mv(a, b)

        jit_f = nnc_jit(f)
        inp = [torch.randn(3, 3), torch.randn(3)]
        self.assertEqual(jit_f(*inp), f(*inp))

    def test_nnc_passthrough(self, device):
        # An input returned unchanged as an output.
        def f(x, y):
            return x + y, y

        inp = (torch.randn(3), torch.randn(3))
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

        # A dict input whose "a" entry is reassigned and returned.
        def f(x):
            x["a"] = x["a"] * 2
            return x

        inp = ({"a": torch.randn(3), "b": torch.randn(3)},)
        jit_f = nnc_jit(f)
        self.assertEqual(jit_f(*inp), f(*inp))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self, device):
        mod = torchvision.models.resnet18()

        # Runs forward + backward and returns the parameter grads.
        # NOTE(review): f is called directly (not traced with make_fx), so this
        # only compares two eager runs -- confirm this is intended.
        def f(x):
            out = mod(x)
            out.sum().backward()
            return [a.grad for a in mod.parameters()]

        inp = torch.randn(3, 3, 250, 250, requires_grad=True)
        grads = f(inp)

        mod.zero_grad()
        mod(inp).sum().backward()
        grads2 = [a.grad for a in mod.parameters()]
        self.assertEqual(grads, grads2)
def get_base(t):
    """Return ``t._base`` when ``t`` is a view, otherwise ``t`` itself."""
    if t._is_view():
        return t._base
    return t
def is_in_base(t, maybe_tensors):
    """Return True if ``t`` and some tensor in ``maybe_tensors`` share a base.

    "Base" is as computed by ``get_base``. Entries that are not tensors are
    ignored.
    """
    t_base = get_base(t)
    return any(
        isinstance(candidate, torch.Tensor) and get_base(candidate) is t_base
        for candidate in maybe_tensors
    )
class TestAOTAutograd(AOTTestCase):
@patch("functorch.compile.config.debug_assert", True)
def verify_aot_autograd(
    self,
    f,
    inp_: Union[Callable, List[Any]],
    *,
    test_mutation: bool = False,
    keep_inp_mutations: bool = False,
    decompositions: Optional[Dict] = None,
    dynamic: bool = False,
    # Only active when inp_ is Callable.
    # TODO: probably consolidate all tests to make inp a Callable.
    make_inputs_subclasses: bool = False,
):
    """Compile ``f`` with AOTAutograd and compare it against eager execution.

    The check runs once per ``keep_inference_input_mutations`` setting: only
    ``True`` when ``keep_inp_mutations`` is set, otherwise ``True`` then
    ``False``. Each run:

    - builds two independent copies of the inputs, one for eager and one for
      the compiled run;
    - compiles ``f`` with ``aot_module`` (for an ``nn.Module``) or with
      ``aot_function``. The forward graph is captured via ``extract_graph``
      and the backward compiler is ``nop``;
    - compares the gradients, each tensor output's ``requires_grad``,
      ``is_leaf``, aliasing and value, and finally the inputs themselves.

    When ``test_mutation`` is True:

    - tensor graph inputs are made non-leaves (``x.add(1)``) so the graph may
      mutate them in place; non-tensor inputs are passed through unchanged;
    - every tensor output is mutated in place afterwards, to check that its
      autograd metadata was set properly.

    Args:
        f: function or ``nn.Module`` under test.
        inp_: either a list of inputs, or a zero-argument callable returning
            ``(inputs, graph_inputs)``.
            - A list is cloned per run; duplicates and ``requires_grad`` are
              preserved.
            - A callable is called twice, once for each copy.
        test_mutation: see above.
        keep_inp_mutations: only test ``keep_inference_input_mutations=True``.
        decompositions: forwarded to ``aot_function`` / ``aot_module``.
        dynamic: forwarded to ``aot_function`` / ``aot_module``.
        make_inputs_subclasses: when ``inp_`` is callable, build the inputs
            under ``TwoTensorMode``.

    Returns:
        The forward graph captured during the last run.
    """
    for keep_input_mutations in [True] if keep_inp_mutations else [True, False]:
        # Some tests pass in a callable for inp, to generate the inputs
        # (useful if we want to generate complicated aliasing inputs)
        if callable(inp_):
            inp_callable = inp_
            # The callable should return a tuple of f_inputs, f_graph_inputs
            # (The idea is that we might want to compile a function with the graph inputs,
            # but test autograd backprop all the way through the actual inputs)
            with TwoTensorMode() if make_inputs_subclasses else nullcontext():
                inp_copy, graph_inps_copy = inp_callable()
                inp, graph_inps = inp_callable()
        else:
            inp_copy = []
            inp = []
            # Our input clones need to mimic when inputs are duplicates of one another
            dupes_map = {}
            for i, x in enumerate(inp_):
                if x in dupes_map:
                    x_dupe_idx = dupes_map[x]
                    inp_copy.append(inp_copy[x_dupe_idx])
                    inp.append(inp[x_dupe_idx])
                else:
                    dupes_map[x] = i
                    if not isinstance(x, torch.Tensor):
                        x_copy = x
                        x_copy2 = x
                    else:
                        x_copy = x.clone().detach().requires_grad_(x.requires_grad)
                        x_copy2 = x.clone().detach().requires_grad_(x.requires_grad)
                        # Keep non-leaf inputs non-leaf in both copies.
                        if x.requires_grad and not x.is_leaf:
                            x_copy = x_copy.clone()
                            x_copy2 = x_copy2.clone()
                    inp_copy.append(x_copy)
                    inp.append(x_copy2)
            if test_mutation:
                # For graphs where we mutate inputs, need our test to make sure inputs aren't leaves.
                # Non-tensor inputs (ints, None, ...) have no .add and are passed through.
                graph_inps = [
                    x.add(1) if isinstance(x, torch.Tensor) else x for x in inp
                ]
                graph_inps_copy = [
                    x.add(1) if isinstance(x, torch.Tensor) else x for x in inp_copy
                ]
            else:
                graph_inps = inp
                graph_inps_copy = inp_copy
        fw_graph_cell = [None]
        if isinstance(f, nn.Module):
            compiled_f = aot_module(
                f,
                fw_compiler=make_boxed_compiler(
                    partial(extract_graph, graph_cell=fw_graph_cell)
                ),
                bw_compiler=nop,
                decompositions=decompositions,
                keep_inference_input_mutations=keep_input_mutations,
                dynamic=dynamic,
            )
        else:
            compiled_f = aot_function(
                f,
                fw_compiler=make_boxed_compiler(
                    partial(extract_graph, graph_cell=fw_graph_cell)
                ),
                bw_compiler=nop,
                decompositions=decompositions,
                keep_inference_input_mutations=keep_input_mutations,
                dynamic=dynamic,
            )
        ref_out, ref_grad = outs_and_grads(f, graph_inps, inp)
        test_out, test_grad = outs_and_grads(compiled_f, graph_inps_copy, inp_copy)
        self.assertEqual(ref_grad, test_grad)
        if isinstance(ref_out, torch.Tensor):
            self.assertIsInstance(test_out, torch.Tensor)
            ref_out, test_out = [ref_out], [test_out]
        for ref_o, test_o in zip(ref_out, test_out):
            if isinstance(ref_o, torch.Tensor):
                self.assertEqual(ref_o.requires_grad, test_o.requires_grad)
                self.assertEqual(ref_o.is_leaf, test_o.is_leaf)
                # NOTE(review): ref_o is itself an element of ref_out (and
                # test_o of test_out), so the second is_in_base() call is
                # always True for tensor outputs. This comparison therefore
                # cannot fail as written -- consider excluding the output
                # itself.
                ref_is_view_of_non_interm = is_in_base(
                    ref_o, graph_inps
                ) or is_in_base(ref_o, ref_out)
                test_is_view_of_non_interm = is_in_base(
                    test_o, graph_inps_copy
                ) or is_in_base(test_o, test_out)
                self.assertEqual(
                    ref_is_view_of_non_interm, test_is_view_of_non_interm
                )
                self.assertEqual(ref_o, test_o)
                if test_mutation:
                    # This tests that autograd meta is set properly on the output we can
                    # mutate it.
                    ref_o.mul_(2)
                    test_o.mul_(2)
                    self.assertEqual(ref_o, test_o)
        for ref_i, test_i in zip(inp, inp_copy):
            if isinstance(ref_i, torch.Tensor):
                self.assertEqual(ref_i.requires_grad, test_i.requires_grad)
            self.assertEqual(ref_i, test_i)
    return fw_graph_cell[0]
def test_non_tensor_and_none_inputs(self):
    """Positional int, None and Tensor arguments, with and without requires_grad."""

    # int, None, Tensor
    def f(a, b, c):
        return a * c

    for needs_grad in (True, False):
        args = [
            2,
            None,
            torch.ones(3, 3, dtype=torch.float32, requires_grad=needs_grad),
        ]
        self.verify_aot_autograd(f, args)
def test_single_output(self):
    """A single tensor output; requires_grad on the first input is toggled."""

    def f(a, b):
        return a + b

    for needs_grad in (True, False):
        self.verify_aot_autograd(
            f, [torch.randn(3, 3, requires_grad=needs_grad), torch.randn(3, 3)]
        )
def test_multi_output(self):
    """Two outputs returned as a tuple; requires_grad on the first input is toggled."""

    def f(a, b):
        return a + b, a - b

    for needs_grad in (True, False):
        self.verify_aot_autograd(
            f, [torch.randn(3, 3, requires_grad=needs_grad), torch.randn(3, 3)]
        )
def test_multi_output_list(self):
    """Two outputs returned as a list; requires_grad on the first input is toggled."""

    def f(a, b):
        return [a + b, a - b]

    for needs_grad in (True, False):
        self.verify_aot_autograd(
            f, [torch.randn(3, 3, requires_grad=needs_grad), torch.randn(3, 3)]
        )
# Test for bug occurring at the intersection of fake tensors & functionalization.
def test_squeeze_mutation(self):
    """In-place add_ on a squeezed clone, compiled with dynamic=True."""

    def f(a):
        b = a.clone().squeeze(-1)
        b.add_(1.0)
        return a + b

    for needs_grad in (True, False):
        self.verify_aot_autograd(
            f, [torch.randn(3, 1, requires_grad=needs_grad)], dynamic=True
        )
def test_complex_linear(self):
    # https://github.com/pytorch/pytorch/issues/93424
    # The input is created before the module so the RNG draws stay in the same order.
    inputs = [torch.randn(1, 10, 10, dtype=torch.complex64)]

    class ComplexLinearAbs(torch.nn.Module):
        """complex64 Linear layer whose summed output is reduced with abs()."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(10, 10, dtype=torch.complex64)

        def forward(self, x):
            return self.linear(x).sum().abs()

    self.verify_aot_autograd(ComplexLinearAbs(), inputs)
def test_embedding_bag_view_dynamic(self):
    # Backwards pass tries to wrap a sparse tensor in a FunctionalTensorWrapper;
    # test that this works even though the sparse tensor has no storage.
    class EmbeddingBagView(torch.nn.Module):
        """Sparse EmbeddingBag whose output is flattened with view(-1)."""

        def __init__(self):
            super().__init__()
            self.emb = torch.nn.EmbeddingBag(100, 8, sparse=True)

        def forward(self, x, y):
            return self.emb(x, y).view(-1)

    indices = torch.arange(3)
    offsets = torch.arange(3)
    # A fresh module per mode, static first and then dynamic.
    for is_dynamic in (False, True):
        self.verify_aot_autograd(
            EmbeddingBagView(), [indices, offsets], dynamic=is_dynamic
        )
def test_input_mutation_simple(self):
    """In-place ``mul_`` on a graph input; also checks the compiled forward graph."""

    def f(a):
        a.mul_(2)
        return a * 3

    inp = [torch.ones(3, 3, requires_grad=True)]
    # The graph from the requires_grad=True run is the one checked below.
    fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
    inp = [torch.ones(3, 3, requires_grad=False)]
    self.verify_aot_autograd(f, inp, test_mutation=True)
    # Things to note:
    # - the extra clone is because we need to pass the pre-mutated input to grad(),
    #   but autograd operates above functionalization so we need to manually clone.
    #   Hopefully backends can optimize this easily.
    # - The extra return arg is because the compiled forward returns (mutated inputs + outputs)
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph.code.strip(),
        """\
def forward(self, primals_1):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
mul_1 = torch.ops.aten.mul.Tensor(mul, 3)
return [mul, mul_1]""",
    )
def test_input_mutation_set__input_mutation(self):
    """``set_()`` an input to a fresh tensor under no_grad, then use both."""

    def f(a):
        b = torch.arange(9, dtype=a.dtype).reshape(3, 3)
        with torch.no_grad():
            a.set_(b)
        return a * b

    for needs_grad in (True, False):
        self.verify_aot_autograd(
            f,
            [torch.ones(3, 3, requires_grad=needs_grad)],
            test_mutation=True,
            keep_inp_mutations=True,
        )
def test_set__steals_view_chain(self):
    """After ``a_.set_(b_slice)``, ``a_`` aliases ``b_slice``, which is a view of ``b_``."""

    def f(a, b):
        a_ = a.mul(2)
        b_ = b.mul(2)
        b_slice = b_[1].view(3, 3)
        # a_ should inherit the view chain from b_slice
        a_.set_(b_slice)
        # Also mutates b_, since a_ now shares b_slice's storage
        a_.view(-1).mul_(2)
        return a_ * b_slice

    inp = [
        torch.ones(3, 3, requires_grad=False),
        torch.zeros(3, 9, requires_grad=False),
    ]
    self.verify_aot_autograd(f, inp, keep_inp_mutations=True)
def test_set__and_data_mutation_good(self):
    """``a.set_(b)`` followed by a data mutation of ``b``; a supported ordering."""

    def f(a, b):
        # The data mutation happens *after* the set_(). This is ok (see the graph below)
        with torch.no_grad():
            a.set_(b)
        b.mul_(2)
        return a + b

    inp = [
        torch.ones(3, 3, requires_grad=True),
        torch.ones(3, 3, requires_grad=True),
    ]
    fw_graph = self.verify_aot_autograd(
        f, inp, test_mutation=True, keep_inp_mutations=True
    )
    inp = [
        torch.ones(3, 3, requires_grad=False),
        torch.zeros(3, 3, requires_grad=False),
    ]
    self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
    # Important things to note (read off the expected graph below):
    # - After a.set_(b), a and b share data, and b's updated value is `mul`,
    #   so `a + b` becomes add(mul, mul).
    # - Both a and b are recorded as mutated. Both mutations are applied inside
    #   the graph (presumably because keep_inp_mutations=True):
    #   - a gets its storage swapped via "set_(primals_1, mul)";
    #   - b gets a data mutation via "copy_(primals_2, mul)".
    # - Only the user output `add` is returned; the mutated inputs are not
    #   extra graph outputs.
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph.code.strip(),
        """\
def forward(self, primals_1, primals_2):
mul = torch.ops.aten.mul.Tensor(primals_2, 2)
add = torch.ops.aten.add.Tensor(mul, mul)
set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul); primals_1 = None
copy_ = torch.ops.aten.copy_.default(primals_2, mul); primals_2 = mul = None
return [add]""",
    )
# This is a (hopefully) extremely rare case that is difficult to handle,
# so we ban it.
# https://github.com/pytorch/pytorch/issues/126236
# https://github.com/pytorch/pytorch/pull/126113
@xfailIfTorchDynamo
def test_set__and_data_mutation_bad(self):
    """Mutating a view of an input after ``set_()`` detached it must raise."""

    def f(a):
        a_view = a.view(-1)
        tmp = torch.ones(3, 3, requires_grad=True)
        # After a.set_(tmp), a and tmp share storage, so mutations through
        # either one will be tracked as graph input mutations.
        # NOTE(review): nesting under no_grad was reconstructed from a
        # whitespace-stripped copy; confirm that a_view.mul_ belongs inside it.
        with torch.no_grad():
            a.set_(tmp)
            # BAD: a_view is now detached from every graph input,
            # so we won't recognize that this caused an input mutation!
            a_view.mul_(2)
        return a + tmp

    inp = [torch.ones(3, 3, requires_grad=True)]
    with self.assertRaisesRegex(
        RuntimeError, "cannot mutate tensors with frozen storage"
    ):
        self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
def test_set__not_allowed(self):
    """An input hit by ``set_()`` and then by a grad-tracked mutation is rejected."""

    def f(a, b):
        with torch.no_grad():
            a.set_(b)
        # Mutating a will change a's grad_fn, which requires us to replay the mutation outside of the graph.
        # We currently ban this today, when the input also received a set_() input mutation.
        a.mul_(2)
        return a + b

    inp = [
        torch.ones(3, 3, requires_grad=True),
        torch.ones(3, 3, requires_grad=True),
    ]
    with self.assertRaisesRegex(
        AssertionError, "but the input has other mutations that we cannot"
    ):
        # Only the raised error matters here, so the returned graph is not bound.
        self.verify_aot_autograd(
            f, inp, test_mutation=True, keep_inp_mutations=True
        )
def test_input_mutation_set__nop(self):
    """Two ``set_()`` calls that cancel out must leave no input mutation."""

    def f(a):
        b = torch.arange(9, dtype=a.dtype)
        # Keep an alias of the original `a` so it can be restored below.
        a_old = torch.ops.aten.alias.default(a)
        with torch.no_grad():
            a.set_(b)
            a.set_(a_old)
        return a + b.reshape(3, 3)

    inp = [torch.ones(3, 3, requires_grad=True)]
    fw_graph = self.verify_aot_autograd(
        f, inp, test_mutation=True, keep_inp_mutations=True
    )
    inp = [torch.ones(3, 3, requires_grad=False)]
    self.verify_aot_autograd(f, inp, test_mutation=True, keep_inp_mutations=True)
    # Things to note:
    # - There are no set_() calls in the graph (we functionalize a.set_(b) into "b")
    # - There is only **1** graph output. We properly realized that the two set_() calls
    #   undo each other, and so effectively no inputs are mutated.
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph.code.strip(),
        """\
def forward(self, primals_1):
arange = torch.ops.aten.arange.default(9, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
alias = torch.ops.aten.alias.default(primals_1); primals_1 = None
view = torch.ops.aten.view.default(arange, [3, 3]); arange = None
add = torch.ops.aten.add.Tensor(alias, view); alias = view = None
return [add]""",
    )
def test_input_mutation_simple_with_none_and_nontensor(self):
    """Tensor / None / int positional arguments through ``aot_function``.

    Despite the test name, ``f`` does not mutate its inputs.
    """

    # Tensor, None, int
    def f(a, b, c):
        return a * c

    compiled = aot_function(f, nop)
    for needs_grad in (True, False):
        args = [torch.ones(3, 3, requires_grad=needs_grad), None, 3]
        expected = f(*args)
        actual = compiled(*args)
        self.assertEqual(expected, actual)
# https://github.com/pytorch/pytorch/issues/93363
def test_mutates_input_noncontiguous(self):
    """An in-place ``add_`` on a strided (non-contiguous) view must reach its base."""

    def f(a):
        a.add_(1)
        return ()

    f_compiled = aot_function(f, nop)
    # `+ 0` makes the bases non-leaf tensors, so in-place ops on them are allowed.
    ref = torch.ones(4, requires_grad=True) + 0
    ref_view = ref[0::2]
    test = torch.ones(4, requires_grad=True) + 0
    test_view = test[0::2]
    # Only the side effect on the base matters; f returns an empty tuple.
    f(ref_view)
    f_compiled(test_view)
    self.assertEqual(ref, test)
def test_input_mutation_modifies_autograd_meta_of_aliases(self):
    """A compiled in-place input mutation must also update the input's pre-existing views."""

    def f(a):
        a.mul_(2)
        out = a + 1
        return out.detach()

    base_ref = torch.ones(3, 3, requires_grad=True).clone()
    alias_ref = base_ref.view(3, 3)
    base_test = torch.ones(3, 3, requires_grad=True).clone()
    alias_test = base_test.view(3, 3)

    f_compiled = aot_function(f, nop, keep_inference_input_mutations=True)
    f(base_ref)
    f_compiled(base_test)

    # The views, created before the call, must see the mutation exactly as in
    # eager mode: same values, same version counter, same kind of grad_fn.
    self.assertEqual(alias_ref, alias_test)
    self.assertEqual(alias_ref._version, alias_test._version)
    self.assertEqual(alias_ref.grad_fn.__class__, alias_test.grad_fn.__class__)

    # Backprop through base * alias on both sides and compare .grad.
    # NOTE(review): base_* and alias_* are non-leaf tensors (outputs of
    # clone()/view()), so their .grad is presumably None on both sides unless
    # retain_grad() is called -- these two checks may be vacuous.
    (base_ref * alias_ref).sum().backward()
    (base_test * alias_test).sum().backward()
    self.assertEqual(base_ref.grad, base_test.grad)
    self.assertEqual(alias_ref.grad, alias_test.grad)
def test_outputs_are_aliased(self):
    """Two outputs where the second is a view of the first must stay aliased."""

    def f(a):
        b = a.mul(2)
        c = b.view(-1)
        return b, c

    compiled = aot_function(f, nop)
    for needs_grad in (True, False):
        x = torch.ones(3, requires_grad=needs_grad)
        expected = f(x)
        actual = compiled(x)
        self.assertEqual(expected[0], actual[0])
        self.assertEqual(expected[1], actual[1])

        # Mutate the base output in place; the view output must follow, which
        # shows that the aliasing relationship was preserved.
        expected[0].mul_(3)
        actual[0].mul_(3)
        self.assertEqual(expected[0], actual[0])
        self.assertEqual(expected[1], actual[1])
def test_input_mutation_is_output(self):
    """A mutated input that is also returned; the same value serves as both outputs."""

    def f(a):
        a.mul_(2)
        return a

    inp = [torch.ones(3, 3, requires_grad=True)]
    fw_graph = self.verify_aot_autograd(f, inp, test_mutation=True)
    inp = [torch.ones(3, 3, requires_grad=False)]
    self.verify_aot_autograd(f, inp, test_mutation=True)
    # The graph returns `mul` twice: once as the mutated input's new value and
    # once as the user-visible output.
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph.code.strip(),
        """\
def forward(self, primals_1):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
return [mul, mul]""",
    )
def test_input_mutation_multiple(self):
    """Two of three inputs (``a`` and ``c``) are mutated in place; ``b`` is only read."""

    def f(a, b, c):
        a.mul_(2)
        c.mul_(2)
        return a + b + c

    def create_inp(req_grad):
        # Three fresh 3x3 tensors of ones sharing the given requires_grad flag.
        return [
            torch.ones(3, 3, requires_grad=req_grad),
            torch.ones(3, 3, requires_grad=req_grad),
            torch.ones(3, 3, requires_grad=req_grad),
        ]

    self.verify_aot_autograd(f, create_inp(False), test_mutation=True)
    # The graph from the requires_grad=True run is the one checked below.
    fw_graph = self.verify_aot_autograd(f, create_inp(True), test_mutation=True)
    # The graph returns the new values of the mutated inputs (mul, mul_1)
    # ahead of the user output (add_1).
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph.code.strip(),
        """\
def forward(self, primals_1, primals_2, primals_3):
clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1 = torch.ops.aten.clone.default(primals_3); primals_3 = None
mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None
add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None
add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None
return [mul, mul_1, add_1]""",
    )
def test_input_mutation_return(self):
    """An ``out=`` op: ``b`` is written in place and returned."""

    def f(a, b):
        return torch.sin(a, out=b)

    # Neither input requires grad, so (presumably) the inference graph is
    # captured, which would explain the arg*_1 names and the tuple return.
    inp = [torch.randn(3, 3), torch.ones(3, 3)]
    fw_graph = self.verify_aot_autograd(
        f, inp, test_mutation=True, keep_inp_mutations=True
    )
    # The write into b is kept in the graph as copy_, and copy_ is returned.
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph.code.strip(),
        """\
def forward(self, arg0_1, arg1_1):
sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
copy_ = torch.ops.aten.copy_.default(arg1_1, sin); arg1_1 = sin = None
return (copy_,)""",
    )
def test_input_mutation_metadata(self):
    """Metadata-only input mutation (in-place ``transpose_``) followed by a use."""

    def f(a, b):
        a.transpose_(1, 0)
        return a + b

    for needs_grad in (True, False):
        pair = [torch.ones(3, 3, requires_grad=needs_grad) for _ in range(2)]
        self.verify_aot_autograd(f, pair, test_mutation=True)
def test_input_mutation_storage_resize_up(self):
    """An input with zero-byte storage is resized up inside the graph, then filled."""

    def f(a):
        torch.ops.inductor.resize_storage_bytes_(a, 32)
        # float32, 4 bytes per element, 32 bytes == 8 elements
        with torch.no_grad():
            a.copy_(torch.ones(8))
        return a + 1

    inp = torch.zeros(8, requires_grad=True)
    # Input starts with zero-size-storage
    inp.untyped_storage().resize_(0)
    fw_graph_cell = [None]
    compiled_f = aot_function(
        f,
        fw_compiler=make_boxed_compiler(
            partial(extract_graph, graph_cell=fw_graph_cell)
        ),
        bw_compiler=nop,
        decompositions={},
        keep_inference_input_mutations=True,
        dynamic=False,
    )
    out = compiled_f(inp)
    # Final functionalized graph has two mutation ops:
    # (1) a resize_() to resize input tensor up
    # (2) a copy_() to fill in the resized input with valid data
    # In the graph these are resize_storage_bytes_ and copy_, and both come
    # after the functional copy/add computation.
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph_cell[0].code.strip(),
        """\
def forward(self, primals_1):
ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
copy = torch.ops.aten.copy.default(primals_1, ones); ones = None
add = torch.ops.aten.add.Tensor(copy, 1)
resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32)
copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = None
return [add]""",
    )
def test_input_mutation_storage_resize_down(self):
    """An input is read and then has its storage resized down to zero bytes."""

    def f(a):
        out = a.sin()
        torch.ops.inductor.resize_storage_bytes_(a, 0)
        return out

    inp = torch.zeros(8, requires_grad=True)
    fw_graph_cell = [None]
    compiled_f = aot_function(
        f,
        fw_compiler=make_boxed_compiler(
            partial(extract_graph, graph_cell=fw_graph_cell)
        ),
        bw_compiler=nop,
        decompositions={},
        keep_inference_input_mutations=True,
        dynamic=False,
    )
    out = compiled_f(inp)
    # Final functionalized graph has one mutation op:
    # (1) a resize_() to resize input tensor down
    # f performs no data mutation on the input, so no copy_ appears.
    # primals_1 is also a graph output (presumably saved for sin's backward --
    # TODO confirm).
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph_cell[0].code.strip(),
        """\
def forward(self, primals_1):
sin = torch.ops.aten.sin.default(primals_1)
resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0)
return [sin, primals_1]""",
    )
def test_input_mutation_storage_resize_up_down(self):
    """Resize an empty input up, fill it, read it, then resize it back to zero bytes."""

    def f(a):
        torch.ops.inductor.resize_storage_bytes_(a, 32)
        # float32, 4 bytes per element, 32 bytes == 8 elements
        with torch.no_grad():
            a.copy_(torch.ones(8))
        out = a.sin()
        torch.ops.inductor.resize_storage_bytes_(a, 0)
        return out

    inp = torch.zeros(8, requires_grad=True)
    # Input starts with zero-size-storage
    inp.untyped_storage().resize_(0)
    fw_graph_cell = [None]
    compiled_f = aot_function(
        f,
        fw_compiler=make_boxed_compiler(
            partial(extract_graph, graph_cell=fw_graph_cell)
        ),
        bw_compiler=nop,
        decompositions={},
        keep_inference_input_mutations=True,
        dynamic=False,
    )
    out = compiled_f(inp)
    # Final graph has two interesting properties:
    # (1) no resizes in the functional graph, since the two resizes cancel out
    #     and the final size is zero
    # (2) no copy_ in the functional graph, even though we copied data into the input,
    #     because the input has no storage at the end of graph execution (so no data to copy)
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph_cell[0].code.strip(),
        """\
def forward(self, primals_1):
ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
copy = torch.ops.aten.copy.default(primals_1, ones); primals_1 = ones = None
sin = torch.ops.aten.sin.default(copy)
return [sin, copy]""",
    )
def test_input_mutation_storage_resize_down_and_set_(self):
    """``set_()`` an empty parameter to gathered data, use it, then free its storage."""

    # Meant to mimic ppFSDP (presumably per-parameter-sharding FSDP -- TODO confirm)
    class TracableCreateParameter(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor, placeholder):
            assert not tensor.requires_grad
            return placeholder.set_(tensor)

        @staticmethod
        def backward(ctx, grad):
            return None, grad  # grad flows to placeholder

    def f(dummy_param, param_shard):
        # simulate allgather
        with torch.no_grad():
            allgather_param = torch.cat([param_shard, param_shard])
        # simulate propagating grad state through dummy param, using data of allgather param
        dummy_param_with_grad_state = TracableCreateParameter.apply(
            allgather_param, dummy_param
        )
        out = dummy_param.sin()
        # Resize our dummy param, which now has the allgather data
        torch.ops.inductor.resize_storage_bytes_(dummy_param, 0)
        return out

    # Simulates the local shard of our param
    param_shard = torch.zeros(8, requires_grad=True)
    # The dummy, zero-sized allgathered param that autograd will actually compute gradients on
    dummy_param = torch.zeros(16, requires_grad=True)
    dummy_param.untyped_storage().resize_(0)
    fw_graph_cell = [None]
    compiled_f = aot_function(
        f,
        fw_compiler=make_boxed_compiler(
            partial(extract_graph, graph_cell=fw_graph_cell)
        ),
        bw_compiler=nop,
        decompositions={},
        keep_inference_input_mutations=True,
        dynamic=False,
    )
    out = compiled_f(dummy_param, param_shard)
    # Important stuff to point out:
    # (1) We save cat for backward (input to the sin()).
    #     While the original code was dummy_param.sin(),
    #     dummy_param actually contains the `cat` tensor due to the set_() call
    # (2) We emit resize_storage_bytes_(set_, 0) in the graph.
    #     After the set_(), cat is the actual data of dummy_param, which is what the resize applies to
    # NOTE(review): the body lines of the expected graph below are not indented
    # under `def forward`; this looks like whitespace lost in copying. If the
    # assertion fails, re-record it with EXPECTTEST_ACCEPT=1.
    self.assertExpectedInline(
        fw_graph_cell[0].code.strip(),
        """\
def forward(self, primals_1, primals_2):
cat = torch.ops.aten.cat.default([primals_2, primals_2]); primals_2 = None
sin = torch.ops.aten.sin.default(cat)
set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = None
resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(set_, 0); set_ = None
return [sin, cat]""",
    )
def test_input_mutation_storage_resize_before_set__not_supported(self):
def f(a):
with torch.no_grad():
torch.ops.inductor.resize_storage_bytes_(a, 0)
a.set_(torch.ones(2))
inp = torch.zeros(8, requires_grad=True)
# See Note [Ordering of resize_() and set_()]
with self.assertRaisesRegex(RuntimeError, "not supported today"):
compiled_f = aot_function(
f,
fw_compiler=nop,
bw_compiler=nop,
decompositions={},
keep_inference_input_mutations=True,
dynamic=False,