forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_compiled_autograd.py
1995 lines (1680 loc) · 74.9 KB
/
test_compiled_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: inductor"]
import functools
import logging
import re
import sys
import unittest
from importlib.machinery import SourceFileLoader
from pathlib import Path
from unittest import mock
import torch
import torch.nn as nn
from torch import _inductor as inductor
from torch._dynamo import compiled_autograd, config
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
from torch.testing._internal.logging_utils import logs_to_string
# NOTE: these tests are not run on Windows because of inductor_utils.HAS_CPU.
def make_compiler_fn(fullgraph=True, dynamic=True):
    """Return a compiler callback suitable for ``compiled_autograd.enable``.

    The callback behaves like ``torch.compile`` with the inductor backend,
    except that every graph handed to the backend increments
    ``counters["compiled_autograd"]["compiles"]`` so tests can assert how many
    compilations happened.

    Args:
        fullgraph: forwarded to ``torch.compile``.
        dynamic: forwarded to ``torch.compile``.
    """

    def _compiler_fn(gm):
        # A fresh counting backend closure is created for every graph.
        def _counting_backend(graph_module, example_inputs):
            counters["compiled_autograd"]["compiles"] += 1
            return inductor.compile(graph_module, example_inputs)

        return torch.compile(
            gm,
            backend=_counting_backend,
            fullgraph=fullgraph,
            dynamic=dynamic,
        )

    return _compiler_fn


# Default callback used by most tests (fullgraph=True, dynamic=True).
compiler_fn = make_compiler_fn()
# TODO(jansel): hooks as lambdas creates recompiles in dynamo, we should fix that
def hook1(grad):
    """Tensor gradient hook: return the incoming gradient scaled by two."""
    doubled = grad * 2
    return doubled
def hook2(grads):
    """Pre-hook for a grad_fn node: add one to the first incoming gradient.

    Returns a 1-tuple holding the adjusted gradient.
    """
    bumped = grads[0] + 1
    return (bumped,)
def hook3(gI, gO):
    """Post-hook for a grad_fn node.

    Returns a 1-tuple containing ``sin(gI[0]) + gO[0]``, which replaces the
    node's first grad input.
    """
    replaced = torch.sin(gI[0]) + gO[0]
    return (replaced,)
class TestCompiledAutograd(TestCase):
def setUp(self) -> None:
    """Run the base setup, then start from a clean compiled-autograd state."""
    super().setUp()
    self._reset_compiled_autograd_state()

def tearDown(self) -> None:
    """Run the base teardown, then undo compiled-autograd state a test left."""
    super().tearDown()
    self._reset_compiled_autograd_state()

@staticmethod
def _reset_compiled_autograd_state() -> None:
    """Silence verbose logs, clear the config flag and reset compiled autograd."""
    torch._logging.set_logs(compiled_autograd_verbose=False)
    config.compiled_autograd = False
    compiled_autograd.reset()
def check_output_and_recompiles(
    self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
):
    """Run ``fn`` eagerly and under compiled autograd and compare the results.

    Both runs are seeded with 123 so random tensors match, and both are
    materialized with ``list(...)``.

    Args:
        fn: zero-argument callable returning an iterable of values to compare.
        count: expected number of compiled-autograd captures and compiles;
            a single value used for both, or a ``[captures, compiles]`` list.
        compiler_fn: compiler callback passed to ``compiled_autograd.enable``.
        compile_fn: if True, ``fn`` is also wrapped in ``torch.compile`` for
            the compiled-autograd run.
    """
    captures, compiles = count if isinstance(count, list) else (count, count)
    with torch.autograd.set_multithreading_enabled(False):
        torch._dynamo.reset()
        counters["compiled_autograd"].clear()

        torch.manual_seed(123)
        expected = list(fn())

        torch.manual_seed(123)
        with compiled_autograd.enable(compiler_fn):
            runner = torch.compile(fn) if compile_fn else fn
            actual = list(runner())

        stats = counters["compiled_autograd"]
        self.assertEqual(expected, actual)
        self.assertEqual(stats["captures"], captures)
        self.assertEqual(stats["compiles"], compiles)
def test_dynamo_flaky_segfault(self):
    """Run compiled autograd in a fresh interpreter and fail on a fatal signal.

    The child script backprops through a linear layer under compiled
    autograd, resets dynamo, and repeats. The script is launched three times.
    Only termination by a signal (negative return code) fails the test;
    ordinary non-zero exits are tolerated.
    """
    import os
    import subprocess

    # The child source must keep its indentation: a flattened script is an
    # IndentationError, which exits with a positive code that the handler
    # below ignores, so the test would pass without exercising anything.
    script = """
import torch

def main():
    def compiler_fn(gm):
        return torch.compile(gm, backend="eager")

    def inner():
        x = torch.randn(1000, 3000)
        w = torch.randn(1000, 3000, requires_grad=True)

        def model(i):
            return torch.nn.functional.linear(i, w)

        out = model(x)
        loss = out.sum()
        with torch._dynamo.compiled_autograd.enable(compiler_fn):
            loss.backward()
        assert(w.grad is not None)

    inner()
    torch._dynamo.reset()
    inner()

main()
"""
    # Run it three times to catch bad dynamo state resets
    for _ in range(3):
        try:
            subprocess.check_output(
                [sys.executable, "-c", script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes
                # `import torch` fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
            )
        except subprocess.CalledProcessError as e:
            # Negative return codes mean the child was killed by a signal.
            if e.returncode < 0:
                self.fail("Subprocess exited with a fatal signal")
def test_basic(self):
    """Grads of a Linear/ReLU/Linear/ReLU stack match eager under compiled autograd."""

    def fn():
        layers = [
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
        ]
        model = torch.nn.Sequential(*layers)
        inp = torch.randn([2, 4])
        model(inp).sum().backward()
        for linear in (model[0], model[2]):
            yield linear.weight.grad
            yield linear.bias.grad

    self.check_output_and_recompiles(fn)
def test_cache_hit(self):
    """Three structurally identical backward passes expect a single compile."""

    def fn():
        for _ in range(3):
            layers = [
                torch.nn.Linear(4, 4),
                torch.nn.ReLU(),
                torch.nn.Linear(4, 4),
                torch.nn.ReLU(),
            ]
            model = torch.nn.Sequential(*layers)
            inp = torch.randn([2, 4])
            model(inp).sum().backward()
            for linear in (model[0], model[2]):
                yield linear.weight.grad
                yield linear.bias.grad

    self.check_output_and_recompiles(fn)
def test_tensor_grad_hook1(self):
    """A tensor hook (hook1) registered on a weight is applied under compiled autograd."""

    def fn():
        for _ in range(3):
            model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
            inp = torch.randn([2, 4])
            linear = model[0]
            linear.weight.register_hook(hook1)
            loss = model(inp).sum()
            loss.backward()
            yield linear.weight.grad
            yield linear.bias.grad

    self.check_output_and_recompiles(fn)
def test_tensor_grad_hook2(self):
    """A pre-hook (hook2) on the loss's grad_fn is applied under compiled autograd."""

    def fn():
        for _ in range(3):
            model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
            inp = torch.randn([1, 4])
            loss = model(inp).sum()
            loss.grad_fn.register_prehook(hook2)
            loss.backward()
            linear = model[0]
            yield from (linear.weight.grad, linear.bias.grad)

    self.check_output_and_recompiles(fn)
def test_tensor_grad_hook3(self):
    """A post-hook (hook3) on the loss's grad_fn is applied under compiled autograd."""

    def fn():
        for _ in range(3):
            model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
            inp = torch.randn([1, 4])
            loss = model(inp).sum()
            loss.grad_fn.register_hook(hook3)
            loss.backward()
            linear = model[0]
            yield from (linear.weight.grad, linear.bias.grad)

    self.check_output_and_recompiles(fn)
def test_torch_compile(self):
    """A torch.compile(fullgraph=True) module's grads match eager under compiled autograd."""

    def fn():
        model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Sigmoid())
        compiled_model = torch.compile(model, fullgraph=True)
        linear = model[0]
        for _ in range(3):
            inp = torch.randn([1, 4])
            compiled_model(inp).sum().backward()
            yield linear.weight.grad
            yield linear.bias.grad
            model.zero_grad()

    self.check_output_and_recompiles(fn)
def test_torch_compile_api_inductor(self):
    """Compiled autograd enabled through ``config`` under ``torch.compile``.

    ``fn`` is run once eagerly as a reference, then compiled with the default
    backend while ``config.compiled_autograd`` is patched to True. Results must
    match, and exactly one compiled-autograd capture is expected.
    """
    def fn():
        # Seed inside fn so the eager and compiled runs build identical models.
        torch.manual_seed(123)
        model = torch.nn.Sequential(
            torch.nn.Linear(4, 4),
            torch.nn.Sigmoid(),
        )
        res = []
        for _ in range(3):
            x = torch.randn([1, 4])
            result = model(x).sum()
            result.backward()
            res.append(model[0].weight.grad)
            res.append(model[0].bias.grad)
            model.zero_grad()
        return res
    expected = fn()
    with config.patch(compiled_autograd=True):
        compiled_fn = torch.compile(fn)
        actual = compiled_fn()
    self.assertEqual(expected, actual)
    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
def test_torch_compile_api_aot_eager(self):
    """Same as the inductor variant, but ``fn`` is compiled with ``backend="aot_eager"``.

    Eager and compiled results must match, and exactly one compiled-autograd
    capture is expected.
    """
    def fn():
        # Seed inside fn so the eager and compiled runs build identical models.
        torch.manual_seed(123)
        model = torch.nn.Sequential(
            torch.nn.Linear(4, 4),
            torch.nn.Sigmoid(),
        )
        res = []
        for _ in range(3):
            x = torch.randn([1, 4])
            result = model(x).sum()
            result.backward()
            res.append(model[0].weight.grad)
            res.append(model[0].bias.grad)
            model.zero_grad()
        return res
    expected = fn()
    with config.patch(compiled_autograd=True):
        compiled_fn = torch.compile(fn, backend="aot_eager")
        actual = compiled_fn()
    self.assertEqual(expected, actual)
    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
def test_torch_compile_api_eager(self):
    """Same as the inductor variant, but ``fn`` is compiled with ``backend="eager"``.

    Eager and compiled results must match, and exactly one compiled-autograd
    capture is expected.
    """
    def fn():
        # Seed inside fn so the eager and compiled runs build identical models.
        torch.manual_seed(123)
        model = torch.nn.Sequential(
            torch.nn.Linear(4, 4),
            torch.nn.Sigmoid(),
        )
        res = []
        for _ in range(3):
            x = torch.randn([1, 4])
            result = model(x).sum()
            result.backward()
            res.append(model[0].weight.grad)
            res.append(model[0].bias.grad)
            model.zero_grad()
        return res
    expected = fn()
    with config.patch(compiled_autograd=True):
        compiled_fn = torch.compile(fn, backend="eager")
        actual = compiled_fn()
    self.assertEqual(expected, actual)
    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
def test_multiple_torch_compile(self):
    """Only functions compiled while the config flag is on use compiled autograd.

    ``fn`` is compiled before and after the ``config.patch`` region and must
    record zero captures each time; ``fn2``, compiled inside it, must record
    exactly one.
    """
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.Sigmoid(),
    )
    x = torch.randn([1, 4])
    def fn():
        result = model(x).sum()
        result.backward()
    model2 = torch.nn.Linear(4, 4)
    x2 = torch.randn([1, 4])
    def fn2():
        result = model2(x2).sum()
        result.backward()
    no_ca1 = torch.compile(fn)
    no_ca1()
    self.assertEqual(counters["compiled_autograd"]["captures"], 0)
    counters.clear()
    with config.patch(compiled_autograd=True):
        with_ca = torch.compile(fn2)
        with_ca()
        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
        counters.clear()
    no_ca2 = torch.compile(fn)
    no_ca2()
    self.assertEqual(counters["compiled_autograd"]["captures"], 0)
def test_torch_compile_graph_break(self):
    """Backward inside a ``torch._dynamo.disable``'d function is still captured.

    Dynamo is told to skip ``fn`` entirely, yet the backward it runs is
    expected to produce exactly one compiled-autograd capture.
    """
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.Sigmoid(),
    )
    x = torch.randn([1, 4])
    @torch._dynamo.disable()
    def fn():
        result = model(x).sum()
        result.backward()
    with config.patch(compiled_autograd=True):
        opt_fn = torch.compile(fn)
        opt_fn()
    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
def test_torch_compile_graph_break2(self):
    """Backward called from a disabled helper inside a compiled function is captured.

    The forward runs in the compiled ``fn``; ``loss.backward()`` happens in
    ``inner_fn``, which dynamo is told to skip. One compiled-autograd capture
    is expected.
    """
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.Sigmoid(),
    )
    x = torch.randn([1, 4])
    @torch._dynamo.disable()
    def inner_fn(loss):
        loss.backward()
    def fn():
        result = model(x).sum()
        inner_fn(result)
    with config.patch(compiled_autograd=True):
        opt_fn = torch.compile(fn)
        opt_fn()
    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
def test_torch_compile_only_backward_call(self):
    """Compiling a lambda that only calls ``backward()`` triggers one capture.

    The forward runs eagerly outside ``torch.compile``; only the backward
    call is wrapped.
    """
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.Sigmoid(),
    )
    x = torch.randn([1, 4])
    result = model(x).sum()
    with config.patch(compiled_autograd=True):
        opt_bwd = torch.compile(lambda: result.backward())
        opt_bwd()
    self.assertEqual(counters["compiled_autograd"]["captures"], 1)
def test_dynamo_boxed(self):
    """Check the placeholder counts a dynamo backend sees with compiled autograd.

    ``inner_compiler`` asserts that a backward graph compiled via compiled
    autograd exposes exactly one placeholder (boxed inputs). Any graph
    compiled for the outer ``fn`` must have more than one. All three input
    gradients must end up populated.
    """

    def get_placeholders(gm_):
        return [node for node in gm_.graph.nodes if node.op == "placeholder"]

    def eager_with_check(gm, is_bwd):
        def inner_compiler(gm_, example_inputs_):
            placeholders = get_placeholders(gm_)
            if is_bwd:
                # should be boxed inputs
                assert len(placeholders) == 1
            else:
                assert len(placeholders) > 1
            return gm_

        return torch.compile(gm, backend=inner_compiler)

    bwd_compiler_fn = functools.partial(eager_with_check, is_bwd=True)

    def fn(inputs):
        args_0, args_1, args_2 = inputs
        out = torch.mm(args_0, args_1)
        out = torch.mm(out, args_2)
        loss = out.sum()
        with compiled_autograd.enable(bwd_compiler_fn):
            loss.backward()
        yield args_0.grad
        yield args_1.grad
        yield args_2.grad

    inputs = [
        torch.randn([1, 2], requires_grad=True),
        torch.randn([2, 3], requires_grad=True),
        torch.randn([3, 4], requires_grad=True),
    ]
    compiled_fn = eager_with_check(fn, is_bwd=False)
    grads = list(compiled_fn(inputs))
    self.assertEqual(len(grads), 3)
    for grad in grads:
        self.assertIsNotNone(grad)
def test_inputs_aliasing_bytecode_attr_mutations(self):
    """Dynamo bytecode for a hand-captured compiled-autograd graph aliases inputs.

    A graph is captured manually (``accumulate_grad_`` into ``param``) and run
    under a bytecode hook that inspects dynamo's output code:
    - before the first call it must store ``inputs_ref_0`` exactly once;
    - after that call it must not touch ``inputs``;
    - after that call it must load ``inputs_ref_0`` exactly once.
    """
    # Freeze compiled autograd graph
    compiler = torch._dynamo.compiled_autograd.AutogradCompilerInstance(compiler_fn)
    param = torch.ones(100)
    activ = torch.ones(100) * 2
    inputs = [param, activ]
    proxies, _ = compiler.begin_capture(inputs=inputs, sizes=[])
    param_proxy, activ_proxy = proxies
    buf = activ_proxy * 2
    torch.ops.inductor.accumulate_grad_.default(param_proxy, buf)
    compiled_fn = compiler.end_capture(buf)

    def bytecode_hook(code, out_code):
        import dis

        if sys.version_info < (3, 11):
            call_op = "CALL_FUNCTION"
        else:
            call_op = "CALL"
        insts = list(dis.get_instructions(out_code))
        # The first call instruction is treated as the compiled-graph call.
        call_graph_idx = next(
            i for i, inst in enumerate(insts) if inst.opname == call_op
        )
        # pre-graph should alias: inputs_ref_0 = inputs[0]
        matches = [
            inst
            for inst in insts[:call_graph_idx]
            if inst.opname == "STORE_FAST" and inst.argval == "inputs_ref_0"
        ]
        self.assertEqual(len(matches), 1)
        # post-graph should access inputs_ref_0 instead of inputs
        matches = [
            inst for inst in insts[call_graph_idx:] if inst.argval == "inputs"
        ]
        self.assertEqual(len(matches), 0)
        matches = [
            inst
            for inst in insts[call_graph_idx:]
            if inst.opname == "LOAD_FAST" and inst.argval == "inputs_ref_0"
        ]
        self.assertEqual(len(matches), 1)

    torch._dynamo.reset()
    handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
    try:
        compiled_fn(inputs=[param, activ], sizes=(), hooks=())
    finally:
        handle.remove()
def test_inputs_aliasing_bytecode_stack_restore(self):
    """Stolen ``inputs`` are aliased in dynamo bytecode and the list is emptied.

    A traced graph marked with ``set_locals_to_steal(gm, ["inputs"])`` mixes a
    plain tensor with a ``LoggingTensor``. The bytecode hook checks that:
    - ``inputs_ref_0`` is stored once before the first call;
    - ``inputs`` is not referenced after that call;
    - ``inputs_ref_0`` is loaded once after that call.
    After the run, the caller's ``inputs`` list must be empty.
    """
    # Set the root logger to WARNING before LoggingTensor is used.
    logging.getLogger().setLevel(logging.WARNING)
    from torch.testing._internal.logging_tensor import LoggingTensor

    # Create a graph that allows inputs stealing
    def forward(inputs):
        add = inputs[0] + 1
        add_1 = add + inputs[1]  # handled in suffix for tensor subclass
        out = add_1.cpu()
        return (out,)

    gm = torch.fx.symbolic_trace(forward)
    torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
    compiled_fn = torch.compile(gm)
    inputs = [
        torch.ones(1000000, dtype=torch.float32),
        LoggingTensor(torch.ones(1)),
    ]

    def bytecode_hook(code, out_code):
        import dis

        if sys.version_info < (3, 11):
            call_op = "CALL_FUNCTION"
        else:
            call_op = "CALL"
        insts = list(dis.get_instructions(out_code))
        # The first call instruction is treated as the compiled-graph call.
        call_graph_idx = next(
            i for i, inst in enumerate(insts) if inst.opname == call_op
        )
        # pre-graph should alias: inputs_ref_0 = inputs[0]
        matches = [
            inst
            for inst in insts[:call_graph_idx]
            if inst.opname == "STORE_FAST" and inst.argval == "inputs_ref_0"
        ]
        self.assertEqual(len(matches), 1)
        # post-graph should access inputs_ref_0 instead of inputs
        matches = [
            inst for inst in insts[call_graph_idx:] if inst.argval == "inputs"
        ]
        self.assertEqual(len(matches), 0)
        matches = [
            inst
            for inst in insts[call_graph_idx:]
            if inst.opname == "LOAD_FAST" and inst.argval == "inputs_ref_0"
        ]
        self.assertEqual(len(matches), 1)

    torch._dynamo.reset()
    handle = torch._dynamo.convert_frame.register_bytecode_hook(bytecode_hook)
    try:
        compiled_fn(inputs)
        # The compiled graph steals (clears) the caller's list.
        self.assertEqual(len(inputs), 0)
    finally:
        handle.remove()
def test_implicit_add(self):
    """A leaf used several times accumulates its grads identically when compiled."""

    def fn():
        y = torch.randn(1, 4, requires_grad=True)

        def model(x):
            # y feeds three terms, so its gradient contributions are summed
            return torch.sigmoid(x * y + torch.sin(y) + torch.cos(y))

        for _ in range(3):
            loss = model(torch.randn([1, 4])).sum()
            loss.backward()
            yield loss
            yield y.grad
            y.grad = None

    self.check_output_and_recompiles(fn)
def test_output_nodes(self):
    """torch.autograd.grad returns grads without touching ``.grad`` under compiled autograd."""

    def fn():
        y = torch.randn(1, 4, requires_grad=True)
        z = torch.randn(1, 4, requires_grad=True)

        def model(x):
            return torch.sigmoid(x * z + torch.sin(y) + torch.cos(y))

        for _ in range(3):
            loss = model(torch.randn([1, 4])).sum()
            grad_y, grad_z = torch.autograd.grad(loss, [y, z])
            assert y.grad is None
            assert z.grad is None
            yield from (grad_y, grad_z)

    self.check_output_and_recompiles(fn)
def test_dynamic_shapes(self):
    """Varying batch sizes through a dynamic torch.compile'd MLP; two compiles expected."""

    def fn():
        layers = [
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
        ]
        model = torch.nn.Sequential(*layers)
        compiled_model = torch.compile(model, dynamic=True)
        for batch in range(10, 100, 10):
            compiled_model(torch.randn([batch, 4])).sum().backward()
            for idx in (0, 2):
                yield model[idx].weight.grad
                yield model[idx].bias.grad
            model.zero_grad()

    # TODO(jansel): we should be able to get this count to 1
    self.check_output_and_recompiles(fn, count=2)
def test_accumulate_without_zero(self):
    """Grads accumulate across ten steps without zero_grad; snapshots are compared."""

    def fn():
        layers = [
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 4),
            torch.nn.ReLU(),
        ]
        model = torch.nn.Sequential(*layers)
        compiled_model = torch.compile(model, dynamic=True)
        for _ in range(10):
            compiled_model(torch.randn([10, 4])).sum().backward()
            # clone so later accumulation does not alter the yielded values
            for idx in (0, 2):
                yield model[idx].weight.grad.clone()
                yield model[idx].bias.grad.clone()

    self.check_output_and_recompiles(fn, count=2)
def test_inplace_grad_update(self):
    """Pre-existing ``.grad`` tensors are updated in place, keeping their identity."""

    def fn():
        model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
        compiled_model = torch.compile(model, dynamic=True)
        linear = model[0]
        for _ in range(10):
            w_grad = torch.rand_like(linear.weight)
            b_grad = torch.rand_like(linear.bias)
            linear.weight.grad = w_grad
            linear.bias.grad = b_grad
            compiled_model(torch.randn([10, 4])).sum().backward()
            assert linear.weight.grad is w_grad
            assert linear.bias.grad is b_grad
            yield w_grad.clone()
            yield b_grad.clone()

    self.check_output_and_recompiles(fn, count=1)
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_issue106555(self):
    """Regression test for issue 106555 (CUDA only).

    Submodules call a ``torch.compile``'d helper while the whole forward runs
    under reentrant ``torch.utils.checkpoint``. Ten forward/backward
    iterations must complete without raising; nothing else is asserted.
    """
    device = torch.device("cuda:0")
    num_features = 256

    def bias_sigmoid_mul(x1, x2, bias):
        gate = torch.sigmoid(x2 + bias)
        return x1 * gate

    bias_sigmoid_mul_jit = torch.compile(bias_sigmoid_mul)

    class ModuleWithJit(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear_1 = nn.Linear(num_features, num_features, bias=True)
            self.linear_2 = nn.Linear(num_features, num_features, bias=False)
            self.linear_2_bias = nn.Parameter(torch.zeros(num_features))

        def forward(self, input_tensor):
            return bias_sigmoid_mul_jit(
                self.linear_1(input_tensor),
                self.linear_2(input_tensor),
                self.linear_2_bias,
            )

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.module_with_jit_1 = ModuleWithJit()
            self.module_with_jit_2 = ModuleWithJit()

        def forward(self, x, gradient_checkpointing: bool):
            if not gradient_checkpointing:
                return self._forward(x)
            return torch.utils.checkpoint.checkpoint(
                self._forward, x, use_reentrant=True
            )

        def _forward(self, x):
            x = x + self.module_with_jit_1(x)
            swapped = x.transpose(-2, -3)
            x = x + self.module_with_jit_2(swapped).transpose(-2, -3)
            return x

    torch.cuda.set_device(device=device)
    torch.manual_seed(1234567890)
    model = Model()
    model.train()
    model.to(device=device)
    params = list(model.parameters())

    torch.manual_seed(1234567890)
    input_tensor = torch.randn(1, 128, 256, num_features).to(device=device)
    input_tensor.requires_grad = True
    target_tensor = torch.randn(1, 128, 256, num_features).to(
        dtype=input_tensor.dtype, device=device
    )

    for _ in range(10):
        for param in params:
            param.grad = None
        output_tensor = model(x=input_tensor.clone(), gradient_checkpointing=True)
        loss = torch.mean(torch.abs(target_tensor - output_tensor))
        loss.backward()
def test_keep_graph_simple(self):
    """Backprop repeatedly through a retained graph, eagerly and compiled."""
    x = torch.tensor([2.0], requires_grad=True)
    y = x**2
    expected_grad = torch.Tensor([4])  # dy/dx = 2x = 4 at x = 2

    # First backward pass runs eagerly and keeps the graph for later passes.
    y.backward(retain_graph=True)
    self.assertEqual(x.grad, expected_grad)

    # check_output_and_recompiles runs this both eagerly and compiled.
    def fn():
        x.grad = torch.tensor([0.0])
        y.backward(retain_graph=True)
        self.assertEqual(x.grad, expected_grad)
        return x.grad

    self.check_output_and_recompiles(fn, count=1)
def test_keep_graph_usage_after_compiled(self):
    """A retained graph stays usable eagerly after compiled-autograd backward passes."""
    x = torch.tensor([2.0], requires_grad=True)
    y = x**2

    def eager_check():
        # Backprop through the retained graph, expect dy/dx = 4 at x = 2,
        # then reset x.grad for the next pass.
        y.backward(retain_graph=True)
        self.assertEqual(x.grad, torch.Tensor([4]))
        x.grad = torch.tensor([0.0])

    eager_check()
    for _ in range(5):
        with compiled_autograd.enable(compiler_fn):
            eager_check()
        eager_check()
def test_custom_fn_saved_tensors(self):
    """Custom Function whose backward reads a tensor saved via save_for_backward."""

    def fn():
        class MySin(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return torch.sin(x)

            @staticmethod
            def backward(ctx, gO):
                (saved_x,) = ctx.saved_tensors
                return gO * torch.cos(saved_x)

        for length in [10, 100, 10, 15, 20, 25]:
            inp = torch.arange(0.0, length, requires_grad=True)
            MySin.apply(inp).sum().backward()
            yield inp.grad

    # six runs over varying lengths are expected to need two compiles
    self.check_output_and_recompiles(fn, count=2)
def test_custom_fn_saved_multiple_tensors(self):
    """Custom Function that saves two tensors and returns two outputs."""

    def fn():
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(x, y)
                return torch.sin(x), torch.sin(y)

            @staticmethod
            def backward(ctx, gO_x, gO_y):
                saved_x, saved_y = ctx.saved_tensors
                return gO_x * torch.cos(saved_x), gO_y * torch.cos(saved_y)

        for length in [10, 100, 10, 15, 20, 25]:
            x = torch.arange(0.0, length, requires_grad=True)
            y = torch.arange(0.0, length, requires_grad=True)
            sin_x, sin_y = MyFn.apply(x, y)
            (sin_x * sin_y).sum().backward()
            yield x.grad

    self.check_output_and_recompiles(fn, count=2)
def test_custom_fn_saved_multiple_tensors_dedup(self):
    """Custom Function that saves the same tensor twice."""

    def fn():
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x, x)
                return torch.sin(x)

            @staticmethod
            def backward(ctx, gO):
                first, second = ctx.saved_tensors
                return gO * torch.cos(first) * torch.cos(second)

        for length in [10, 100, 10, 15, 20, 25]:
            inp = torch.arange(0.0, length, requires_grad=True)
            MyFn.apply(inp).sum().backward()
            yield inp.grad

    self.check_output_and_recompiles(fn, count=2)
def test_custom_fn_saved_shape_tensor(self):
    """Custom Function whose backward uses the shape of a saved tensor."""

    def fn():
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (saved_x,) = ctx.saved_tensors
                return gO * saved_x.shape[0]

        for length in [10, 100, 10, 15, 20, 25]:
            inp = torch.arange(0.0, length, requires_grad=True)
            MyFn.apply(inp).sum().backward()
            yield inp.grad

    self.check_output_and_recompiles(fn, count=2)
def test_custom_fn_saved_attr(self):
    """Custom Function that stashes a shape on ``ctx`` as a plain attribute.

    Uses a non-fullgraph compiler callback.
    """

    def fn():
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.shape = x.shape
                return x

            @staticmethod
            def backward(ctx, gO):
                leading_dim = ctx.shape[0]
                return gO * leading_dim

        for length in [10, 100, 10, 15, 20, 25]:
            inp = torch.arange(0.0, length, requires_grad=True)
            MyFn.apply(inp).sum().backward()
            yield inp.grad

    non_fullgraph_compiler = make_compiler_fn(fullgraph=False)
    self.check_output_and_recompiles(
        fn, count=2, compiler_fn=non_fullgraph_compiler
    )
def test_custom_fn_multiple_grads(self):
    """Custom Function with two inputs and two outputs passing grads straight through."""

    def fn():
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x + y, y

            @staticmethod
            def backward(ctx, gO_1, gO_2):
                return gO_1, gO_2

        for length in [10, 100, 10, 15, 20, 25]:
            x = torch.arange(0.0, length, requires_grad=True)
            y = torch.arange(0.0, length, requires_grad=True)
            total, passthrough = MyFn.apply(x, y)
            (total + passthrough).sum().backward()
            yield from (x.grad, y.grad)

    self.check_output_and_recompiles(fn, count=2)
def test_custom_fn_non_variable_input(self):
    """Custom Function that takes a plain Python int between two tensors.

    The inputs themselves are yielded as before. The gradients of the two
    tensor inputs are yielded as well, so the eager and compiled-autograd
    runs are compared on values that backward actually computes.
    """

    def fn():
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y, z):
                return x * 2, y * 3, z * 4

            @staticmethod
            def backward(ctx, gO_1, gO_2, gO_3):
                return gO_1, gO_2, gO_3

        for i in [10, 100, 10, 15, 20, 25]:
            x = torch.arange(0.0, i, requires_grad=True)
            y = 1
            z = torch.arange(0.0, i, requires_grad=True)
            out1, out2, out3 = MyFn.apply(x, y, z)
            loss = (out1 + out2 + out3).sum()
            loss.backward()
            yield x
            yield y
            yield z
            # x and z are deterministic aranges, so comparing them alone says
            # nothing about backward; compare their gradients too.
            yield x.grad
            yield z.grad

    self.check_output_and_recompiles(fn, count=2)
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_logging_tensor_flaky(self) -> None:
    """Reproduce an ordering-dependent failure involving LoggingTensor.

    When a test using triton ran first and
    test_inputs_aliasing_bytecode_stack_restore ran afterwards, the result was:
      - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'`
      - python: `TypeError: not all arguments converted during string formatting`
    This test runs both steps back to back and only checks that they complete.
    """

    # 1. some triton involving test
    def fn():
        def _fn(x):
            return x

        x = torch.arange(
            1, 10, requires_grad=True, dtype=torch.float16, device="cuda"
        )
        out = _fn(x)
        loss = out.sum()
        loss.backward()

    with compiled_autograd.enable(compiler_fn):
        fn()

    # triton setup overwrote it to INFO
    logging.getLogger().setLevel(logging.WARNING)

    # 2. test_inputs_aliasing_bytecode_stack_restore
    from torch.testing._internal.logging_tensor import LoggingTensor

    def forward(inputs):
        add = inputs[0] + 1
        add_1 = add + inputs[1]
        out = add_1.cpu()
        return (out,)

    gm = torch.fx.symbolic_trace(forward)
    torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
    compiled_fn = torch.compile(gm)
    inputs = [
        torch.ones(1000000, dtype=torch.float32),
        LoggingTensor(torch.ones(1)),
    ]
    compiled_fn(inputs)
@unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_custom_fn_output_metadata(self):
def my_compiler_fn(gm):
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
assert (
node.target._name != "aten::_to_copy"
), "there should be no implicit copies (e.g. dtype casting)"
def inner_compiler(gm_, example_inputs_):
counters["compiled_autograd"]["compiles"] += 1
return inductor.compile(gm_, example_inputs_)
return torch.compile(
gm, backend=inner_compiler, fullgraph=True, dynamic=True
)
def fn():
class MyFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x