forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_cudagraph_trees.py
1784 lines (1388 loc) · 60.9 KB
/
test_cudagraph_trees.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: inductor"]
import contextlib
import functools
import gc
import importlib
import sys
import unittest
import warnings
import torch
import torch._dynamo.config as dynamo_config
import torch.nn as nn
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.cudagraph_trees import cudagraphify_impl as tree_cudagraphify_impl
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_CI,
IS_LINUX,
IS_WINDOWS,
parametrize,
skipIfRocm,
TEST_CUDA_GRAPH,
TEST_WITH_ASAN,
)
from torch.utils._python_dispatch import TorchDispatchMode
# On Windows CI the dependencies this suite needs are not available: report it
# on stderr, exit cleanly when run as a script, and otherwise raise SkipTest so
# the test runner skips the whole module.
if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")
# Import-time probes: import_module raises if either package is missing, so the
# module fails to load instead of failing later inside individual tests.
importlib.import_module("functorch")
importlib.import_module("filelock")
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
# Shorthand for ATen operator overloads (used as ``aten.<op>`` in the tests).
aten = torch.ops.aten
# Decorator: skip the decorated test when CUDA is unavailable.
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
# Decorator *factory*: apply it as ``@requires_multigpu()`` (note the call) to
# skip the test unless multiple CUDA devices are present.
requires_multigpu = functools.partial(
    unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices"
)
from io import StringIO
def get_compile_fn(backend):
    """
    Build a ``torch.compile`` partial for the requested test backend.

    Args:
        backend: ``"cudagraphs"`` selects the dynamo cudagraphs backend; any
            other value compiles with inductor in ``reduce-overhead`` mode.

    Returns:
        A ``functools.partial`` wrapping ``torch.compile``.
    """
    if backend != "cudagraphs":
        compile_kwargs = {"mode": "reduce-overhead"}
    else:
        compile_kwargs = {"backend": "cudagraphs"}
    return functools.partial(torch.compile, **compile_kwargs)
class capture_stderr(list):
    """
    Context manager that temporarily swaps ``sys.stderr`` for an in-memory buffer.

    On exit the buffered text is appended to this list as one string and the
    previous ``sys.stderr`` is restored, so callers read the captured output of
    the first ``with`` block as ``captured[0]``.
    """

    def __enter__(self):
        buffer = StringIO()
        # Keep the real stream so __exit__ can put it back.
        self.sys_stderr = sys.stderr
        self.stringio = buffer
        sys.stderr = buffer
        return self

    def __exit__(self, *args):
        captured_text = self.stringio.getvalue()
        self.append(str(captured_text))
        del self.stringio
        sys.stderr = self.sys_stderr
def cdata(t):
    """
    Return the ``_cdata`` handle of ``t``'s untyped storage.

    Tensors that view the same storage return the same value, which the tests
    use to check storage aliasing.
    """
    storage = t.untyped_storage()
    return storage._cdata
class TestCase(InductorTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._stack = contextlib.ExitStack()
cls._stack.enter_context(
config.patch(
{
"debug": True,
"cpp.min_chunk_size": 1,
"triton.autotune_pointwise": False, # too slow
"implicit_fallbacks": False,
}
)
)
@classmethod
def tearDownClass(cls):
cls._stack.close()
super().tearDownClass()
def setUp(self):
torch._dynamo.reset()
super().setUp()
def tearDown(self):
super().tearDown()
torch._dynamo.reset()
if HAS_CUDA and not TEST_WITH_ASAN:
def get_all_cudagraph_segments():
    """
    Return the CUDA allocator snapshot segments whose ``segment_pool_id`` is
    not ``(0, 0)``.

    NOTE(review): ``(0, 0)`` is presumably the default (non-graph) pool, so
    what remains are the private-pool segments used by cudagraphs.
    """
    return list(
        filter(
            lambda seg: seg["segment_pool_id"] != (0, 0),
            torch.cuda.memory_snapshot(),
        )
    )
def all_live_blocks():
    """
    Return the start address of every ``active_allocated`` block in the
    segments reported by ``get_all_cudagraph_segments``.

    Each block's address is computed as its segment's base ``address`` plus
    the sizes of the blocks that precede it in that segment.
    """
    live = []
    for seg in get_all_cudagraph_segments():
        cursor = seg["address"]
        for blk in seg["blocks"]:
            if blk["state"] == "active_allocated":
                live.append(cursor)
            cursor += blk["size"]
    return live
def all_live_block_count():
    """Return how many blocks ``all_live_blocks`` reports as allocated."""
    live_blocks = all_live_blocks()
    return len(live_blocks)
class CudaGraphTreeTests(TestCase):
def setUp(self):
    """
    Enable cudagraph trees, both cudagraph assert paths, and automatic dynamic
    shapes for this test, and silence warnings.

    Every patch, plus a ``warnings.catch_warnings`` snapshot, is entered on
    ``self.graph_stack`` so ``tearDown`` undoes all of them with one ``close()``.
    """
    super().setUp()
    self.graph_stack = contextlib.ExitStack()
    self.graph_stack.enter_context(
        config.patch(
            {
                "triton.cudagraphs": True,
                "triton.cudagraph_trees": True,
                # Both the fast-path and slow-path cudagraph assert checks are on.
                "triton.fast_path_cudagraph_asserts": True,
                "triton.slow_path_cudagraph_asserts": True,
            }
        )
    )
    self.graph_stack.enter_context(
        dynamo_config.patch(automatic_dynamic_shapes=True)
    )
    # Creating a tensor on "cuda" resolves the current device index.
    self.device_idx = torch.rand([0], device="cuda").device.index
    # Snapshot the warning filters so closing graph_stack restores them exactly,
    # instead of warnings.resetwarnings(), which would also wipe filters
    # installed by the test runner or -W options.
    self.graph_stack.enter_context(warnings.catch_warnings())
    warnings.filterwarnings("ignore")

def tearDown(self):
    """
    Reset dynamo, free cached CUDA memory, undo this test's patches and
    warning filters, then assert that no tree manager remains and that no
    cudagraph-pool blocks or segments are still allocated.
    """
    super().tearDown()
    torch._dynamo.reset()
    gc.collect()
    torch.cuda.empty_cache()
    # Also restores the warning filters saved in setUp, even if an assertion
    # below fails.
    self.graph_stack.close()
    self.assertIsNone(self.get_manager())
    self.assertEqual(all_live_block_count(), 0)
    self.assertEqual(len(get_all_cudagraph_segments()), 0)
def get_manager(self, device_index=None):
    """
    Return the cudagraph-tree manager for a CUDA device.

    Args:
        device_index: device to query; ``None`` (the default) means the device
            this test runs on (``self.device_idx``).

    Returns:
        The ``tree_manager`` of that device's container; ``None`` when no
        manager exists (``tearDown`` asserts this after cleanup).
    """
    # Compare against None explicitly: device index 0 is falsy, so a
    # truthiness check would replace an explicit ``device_index=0`` with
    # ``self.device_idx``.
    if device_index is None:
        device_index = self.device_idx
    return torch._inductor.cudagraph_trees.get_container(device_index).tree_manager
def get_roots(self):
    """Return ``get_roots()`` of this test device's cudagraph tree manager."""
    manager = self.get_manager()
    return manager.get_roots()
def curr_node(self):
    """Return the tree manager's ``current_node``."""
    manager = self.get_manager()
    return manager.current_node
def get_root_children(self):
    """Return a list with each tree root's ``num_descendants()`` count, in root order."""
    counts = []
    for root in self.get_roots():
        counts.append(root.num_descendants())
    return counts
def cudagraphify_impl(
    self, *args, is_inference=True, is_backward=False, **kwargs
):
    """
    Forward to the cudagraph-trees ``cudagraphify_impl`` on this test's device.

    Positional and extra keyword arguments pass through unchanged;
    ``device_index`` is always ``self.device_idx`` (passing it again raises
    ``TypeError``).
    """
    return tree_cudagraphify_impl(
        *args,
        device_index=self.device_idx,
        is_inference=is_inference,
        is_backward=is_backward,
        **kwargs,
    )
@staticmethod
def run_twc(fn, *args, **kwargs):
    """
    Call ``fn`` twice with identical arguments and return the second result.

    The first call's result is discarded as soon as that call returns, so it
    is not referenced while the second call runs.
    """
    invoke = functools.partial(fn, *args, **kwargs)
    invoke()
    return invoke()
def num_checkpoints(self):
    """Return the tree manager's ``debug_checkpointing_counter``."""
    manager = self.get_manager()
    return manager.debug_checkpointing_counter
def test_run_simple(self):
    """
    Run a compiled pointwise function twice each on a [4, 4] and a [5, 5]
    input; expect two tree roots, neither with descendants.
    """
    def foo(x):
        return x * x * x

    foo_opt = torch.compile(foo)
    ones = torch.ones([4, 4], device="cuda")
    zeros = torch.zeros([5, 5], device="cuda")
    # run_twc calls each input twice.
    self.run_twc(foo_opt, ones)
    self.run_twc(foo_opt, zeros)
    self.assertEqual(self.get_root_children(), [0, 0])
def check_rng(self):
    """
    Shared body of the RNG tests: after re-seeding with the same seed, three
    more calls of a compiled ``torch.rand`` function must reproduce the first
    three outputs, call for call.
    """
    # NOTE(review): torch.rand([20]) passes no device, so unless a default
    # device is configured elsewhere this allocates on CPU; confirm that this
    # exercises the cudagraph path as intended.
    @torch.compile(mode="reduce-overhead")
    def foo():
        return torch.rand([20])

    torch.manual_seed(0)
    out = foo()
    out2 = foo()
    out3 = foo()
    # Re-seed: the replayed sequence must match the first one in order.
    torch.manual_seed(0)
    self.assertEqual(out, foo())
    self.assertEqual(out2, foo())
    self.assertEqual(out3, foo())
@torch._inductor.config.patch("fallback_random", True)
def test_rng_trees(self):
    """Run ``check_rng`` with ``fallback_random`` enabled and cudagraph trees on (as patched in setUp)."""
    self.check_rng()
@torch._inductor.config.patch("triton.cudagraph_trees", False)
@torch._inductor.config.patch("fallback_random", True)
def test_rng_non_trees(self):
    """Run ``check_rng`` with ``fallback_random`` on and ``triton.cudagraph_trees`` patched off."""
    self.check_rng()
def test_mutation_reinplaced(self):
    """
    ``torch.logical_xor(..., out=out)`` writes into an input, so compiling it
    with ``reduce-overhead`` must skip cudagraphs, print a mutated-input skip
    message that points at the ``torch.logical_xor`` line, and count one skip.
    """
    # NOTE: redundant with the module-level ``import torch.nn as nn``.
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()

        def forward(self, input, other, out):
            input = torch.logical_xor(input=input, other=other, out=out)
            return input

    x = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
    y = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float32).cuda()
    # The ``out`` buffer is float16 while x and y are float32.
    z = torch.rand([1, 2, 1, 4, 9, 7], dtype=torch.float16).cuda()
    model = Model().cuda()
    # Eager call; ``eag`` and ``opt`` are not compared below.
    eag = model(x, y, z)
    with capture_stderr() as captured_output:
        opt = torch.compile(model.forward, mode="reduce-overhead")(x, y, z)
    FileCheck().check(
        "skipping cudagraphs due to mutated inputs (1 instances). Found from"
    ).check("torch.logical_xor").run(captured_output[0])
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@requires_multigpu()
@parametrize("backend", ("inductor", "cudagraphs"))
def test_multiple_devices_msg(self, backend):
    """
    Check the cudagraph skip messages for (1) a CUDA + CPU argument mix and
    (2) arguments on two different CUDA devices; each case adds one skip.
    """
    def foo(x, y):
        return (x + 1, y + 2)

    foo = get_compile_fn(backend)(foo)
    # Second argument is on CPU: expect a "cpu device" skip naming arg1_1 and
    # pointing at the ``y + 2`` source line.
    with capture_stderr() as captured_output:
        foo(torch.ones([10], device="cuda"), torch.ones([20]))
    FileCheck().check(
        "skipping cudagraphs due to cpu device (arg1_1). Found from"
    ).check("y + 2").run(captured_output[0])
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
    # Arguments on cuda:0 and cuda:1: expect a "multiple devices" skip.
    with capture_stderr() as captured_output:
        foo(
            torch.ones([10], device="cuda:0"), torch.ones([10], device="cuda:1")
        )
    FileCheck().check("skipping cudagraphs due to multiple devices").run(
        captured_output[0]
    )
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 2)
@torch._inductor.config.patch("triton.cudagraph_skip_dynamic_graphs", True)
def test_skip_symbolic(self):
    """
    With ``triton.cudagraph_skip_dynamic_graphs`` on, a function compiled with
    ``dynamic=True`` must be skipped with a symbolic-shapes message pointing
    at ``x + y``.
    """
    @torch.compile(dynamic=True)
    def foo(x, y):
        return x + y

    with capture_stderr() as captured_output:
        foo(torch.rand([10], device="cuda"), torch.rand([10], device="cuda"))
    FileCheck().check(
        "skipping cudagraphs due to graph with symbolic shapes inputs"
    ).check("x + y").run(captured_output[0])
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_on_inp(self, backend):
    """
    Even with input-mutation support enabled, mutating a plain graph input
    must skip cudagraphs (no tree roots are created). Mutating a module buffer
    must still match eager over three calls, with a tree manager present.
    """
    def foo(x):
        x.add_(2)
        return x

    foo = get_compile_fn(backend)(foo)

    def inp():
        return torch.ones([10], device="cuda")

    with capture_stderr() as captured_output:
        foo(inp())
    FileCheck().check(
        "skipping cudagraphs due to mutated inputs (1 instances). Found from"
    ).check(".add_(2)").run(captured_output[0])
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
    # mutation on an input doesn't hit cudagraphs
    self.assertEqual(len(self.get_manager().roots), 0)

    # mutation on parameters/buffers hits cudagraphs
    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.buf = torch.ones([10], device="cuda")

        def forward(self, x):
            self.buf.add_(x)
            return self.buf + x

    def foo(mod, x):
        return mod(x)

    foo = get_compile_fn(backend)(foo)
    # mod runs compiled and mod2 runs eagerly; their outputs and buffers must agree.
    mod = Mod()
    mod2 = Mod()
    for _ in range(3):
        self.assertEqual(foo(mod, inp()), mod2(inp()))
        self.assertEqual(mod.buf, mod2.buf)
    self.assertIsNotNone(self.get_manager())
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", False)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", False)
def test_mutation_cudagraph_managed_tensors_config(self, backend):
    """
    With input-mutation support disabled, ``mut`` mutating the output of
    ``foo`` is reported as a mutated-input skip exactly once across three
    steps, and results still match the non-mutating reference.
    """
    def foo(x):
        return x + 1

    def mut(x):
        x.add_(2)
        return x

    def non_mut(x):
        return x.add(2)

    mut = get_compile_fn(backend)(mut)
    foo = get_compile_fn(backend)(foo)
    with capture_stderr() as captured_output:
        # The loop index is unused.
        for i in range(3):
            torch.compiler.cudagraph_mark_step_begin()
            inp = torch.rand([4], device="cuda")
            tmp = foo(inp)
            mut_out = mut(tmp)
            self.assertEqual(mut_out, non_mut(foo(inp)))
    FileCheck().check_count(
        "skipping cudagraphs due to mutated inputs (1 instances). Found from",
        1,
        exactly=True,
    ).run(captured_output[0])
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensors(self, backend):
    """
    With input-mutation support enabled, mutating an output of ``foo`` must
    not skip cudagraphs, but mutating an eager clone of such an output must
    (one skip message and one counted skip).
    """
    def foo(x):
        return x + 1

    def mut(x):
        x.add_(2)
        return x

    def non_mut(x):
        return x.add(2)

    mut = get_compile_fn(backend)(mut)
    foo = get_compile_fn(backend)(foo)
    with capture_stderr() as captured_output:
        # The loop index is unused.
        for i in range(3):
            torch.compiler.cudagraph_mark_step_begin()
            inp = torch.rand([4], device="cuda")
            tmp = foo(inp)
            mut_out = mut(tmp)
            self.assertEqual(mut_out, non_mut(foo(inp)))
    FileCheck().check_count(
        "skipping cudagraphs due to mutated inputs (1 instances). Found from",
        0,
        exactly=True,
    ).run(captured_output[0])
    self.assertTrue("cudagraph_skips" not in counters["inductor"])
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")
    tmp = foo(inp)
    mut_inp = tmp.clone()
    # Here the mutated tensor is no longer cudagraph-managed: it is an input
    # from eager, so we should fall back to inductor without cudagraphs.
    with capture_stderr() as captured_output:
        mut(mut_inp)
    FileCheck().check(
        "skipping cudagraphs due to mutated inputs (1 instances). Found from"
    ).check("x.add_(2)").run(captured_output[0])
    self.assertEqual(mut_inp, non_mut(foo(inp)))
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensor_warn(self, backend):
    """
    ``foo`` mutates a fresh eager input every step; across three steps the
    mutated-input skip message must appear exactly once, and one skip is counted.
    """
    def foo(x):
        return x.add_(1)

    def fee(y, z):
        return z.add(3)

    def inp():
        return torch.rand([4], device="cuda")

    foo = get_compile_fn(backend)(foo)
    fee = get_compile_fn(backend)(fee)
    with capture_stderr() as captured_output:
        for _ in range(3):
            torch.compiler.cudagraph_mark_step_begin()
            fee(inp(), foo(inp()))
    FileCheck().check_count(
        "skipping cudagraphs due to mutated inputs (1 instances). Found from",
        1,
        exactly=True,
    ).run(captured_output[0])
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
@parametrize("backend", ("inductor", "cudagraphs"))
@torch._dynamo.config.patch("cudagraph_backend_keep_input_mutation", True)
@torch._dynamo.config.patch("cudagraph_backend_support_input_mutation", True)
@torch._inductor.config.patch("triton.cudagraph_support_input_mutation", True)
def test_mutation_cudagraph_managed_tensor_warn_only_once(self, backend):
    """
    The mutated-input skip for ``mut`` is reported once, on the first call
    (with an eager input). It is not repeated for later mutations of ``foo``
    outputs, nor for a subsequent eager clone.
    """
    def foo(x):
        return x + 1

    def mut(x):
        x.add_(2)
        return x

    def inp():
        return torch.rand([4], device="cuda")

    mut = get_compile_fn(backend)(mut)
    foo = get_compile_fn(backend)(foo)
    with capture_stderr() as captured_output:
        # Should warn for current_node=None
        mut(inp())
        # The loop index is unused.
        for i in range(3):
            torch.compiler.cudagraph_mark_step_begin()
            tmp = foo(inp())
            mut(tmp)  # should not warn
        mut_inp = tmp.clone()
        mut(mut_inp)  # should not warn since mut has warned
    FileCheck().check_count(
        "skipping cudagraphs due to mutated inputs (1 instances). Found from",
        1,
        exactly=True,
    ).run(captured_output[0])
    self.assertEqual(counters["inductor"]["cudagraph_skips"], 1)
def test_function_compiled_multiple_times(self):
    """
    ``foo`` calls the graph-breaking ``foo2`` twice. After warm-up, the
    compiled pieces should form one tree root with two descendants.
    """
    def foo(x):
        y = foo2(x)
        y2 = foo2(y)
        return y + y2

    def foo2(x):
        torch._dynamo.graph_break()
        return x * x * x

    foo_opt = torch.compile(foo)
    ones = torch.ones([4, 4], device="cuda")
    foo(ones)
    foo_opt(ones)
    foo_opt(ones)
    self.assertEqual(foo_opt(ones), foo(ones))
    # per-root descendant counts
    children = self.get_root_children()
    # one root with two descendants
    self.assertEqual(children, [2])
def test_end_recording_early(self):
    """
    Alternate a graph-breaking compiled function with a second compiled
    function (decrementing ``GenerationTracker.generation`` in between).
    Expect two roots: one with a single descendant and one with none.
    """
    def foo(x):
        y = x * x * x
        torch._dynamo.graph_break()
        z = x + y
        return z

    @torch.compile
    def foo2(x):
        return x + 4

    foo_opt = torch.compile(foo)
    for _ in range(3):
        out = foo_opt(torch.ones([4, 4], device="cuda"))
        del out
        # When I tried inducing separate recordings via a graph break,
        # the frame kept interfering by keeping outputs alive.
        # This isn't great, but it simulates the logic.
        from torch._dynamo.mutation_guard import GenerationTracker

        GenerationTracker.generation -= 1
        out = foo2(torch.ones([4, 4], device="cuda"))
        del out
    foo_opt(torch.ones([4, 4], device="cuda"))
    # Two separate traces - one has a child, one doesn't
    self.assertEqual(self.get_root_children(), [1, 0])
def test_execution_into_recording(self):
    """
    Track live cudagraph-pool blocks as ``foo`` switches branches: the
    new path is warmed up, then recorded, then executed, with outputs
    released in between.
    """
    def foo(x):
        y = x + x
        if y.sum() > 0:
            return y + 10
        else:
            return y - 10

    foo_opt = torch.compile(foo)
    # All-zero input: y.sum() == 0, so the ``y - 10`` branch runs.
    inp = torch.zeros([4, 4], dtype=torch.float, device="cuda")
    self.assertEqual(foo_opt(inp), foo(inp))
    self.assertEqual(foo_opt(inp), foo(inp))
    # After add_(1), y.sum() > 0 and the ``y + 10`` branch runs instead.
    inp.add_(1)
    out_eager = foo(inp)
    out_warmup = foo_opt(inp)
    self.assertEqual(out_warmup, out_eager)
    # warmup should have the storage deallocator hooked on
    self.assertEqual(all_live_block_count(), 1)
    out_live = foo_opt(inp)
    self.assertEqual(out_live, out_eager)
    # should be in recording mode, with storage deallocator hooked on
    self.assertEqual(all_live_block_count(), 1)
    # warmup should have been freed
    del out_warmup
    # should be in recording mode, with storage deallocator hooked on
    self.assertEqual(all_live_block_count(), 1)
    del out_live
    self.assertEqual(all_live_block_count(), 0)
    out = foo_opt(inp)
    self.assertEqual(foo(inp), out)
    # should be in execution mode
    self.assertEqual(all_live_block_count(), 0)
def test_forward_with_skipped_cudagraphed_backward(self):
    """
    Patch ``compile_fx.complex_memory_overlap`` to always report an overlap
    while running backward through a stride-0 gradient. Afterwards only one
    cudagraph id must have been allocated (the backward was not cudagraphed),
    and no forwards may be left pending a backward.
    """
    @torch.compile(mode="reduce-overhead")
    def foo(x):
        return x * x * x

    def complex_memory_overlap_new(t):
        # Report every tensor as having overlapping memory.
        return True

    for _ in range(3):
        inp = torch.rand([20, 20], device="cuda", requires_grad=True)
        out = foo(inp)
        # Read the original before ``try`` so ``finally`` can never refer to
        # an unbound name if this lookup itself fails.
        prev = torch._inductor.compile_fx.complex_memory_overlap
        try:
            torch._inductor.compile_fx.complex_memory_overlap = (
                complex_memory_overlap_new
            )
            back_inp = torch.empty_strided([20, 20], [0, 1], device="cuda")
            out.backward(back_inp)
        finally:
            torch._inductor.compile_fx.complex_memory_overlap = prev
    # we should not have cudagraph'd the backwards
    new_id = self.get_manager().new_graph_id().id
    self.assertEqual(new_id, 1)
    self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
@parametrize("backend", ("inductor", "cudagraphs"))
def test_forward_backward_not_called(self, backend):
    """
    Each iteration runs backward only through ``x_out``. Afterwards no
    forwards may be pending a backward, and the next graph id must be 3.
    """
    def foo(x, y):
        x_out = x * x * x
        torch._dynamo.graph_break()
        y_out = y * y * y
        return x_out, y_out

    foo = get_compile_fn(backend)(foo)
    for _ in range(3):
        inps = [
            torch.rand([20, 20], requires_grad=True, device="cuda")
            for _ in range(2)
        ]
        x_out, y_out = foo(inps[0], inps[1])
        # y_out's backward is intentionally never run.
        x_out.sum().backward()
    self.assertFalse(self.get_manager().running_forwards_with_pending_backwards)
    # we should not have cudagraph'd the y backward
    new_id = self.get_manager().new_graph_id().id
    self.assertEqual(new_id, 3)
def _test_unaligned_static_input_impl(self, expected_clones):
    """
    Compile ``x + y`` with ``compile_fx_inner(num_fixed=1, cudagraphs=True)``
    against offset-view inputs, then count ``aten.clone`` calls while running
    it on unaligned and aligned inputs.

    Args:
        expected_clones: clone count expected after the unaligned call. The
            counter is cumulative within an iteration, so the same value is
            also expected after the aligned call that follows.
    """
    def fn(x, y):
        return (x + y,)

    def get_aligned_inputs():
        return [torch.rand([5, 5], device="cuda") for _ in range(2)]

    mod = make_fx(fn)(*get_aligned_inputs())
    # Fake tensors that are [1:] views of [6, 5] tensors, i.e. have a non-zero
    # storage offset. NOTE(review): num_fixed=1 presumably marks the first
    # input as static, per this test's name; confirm against compile_fx_inner.
    mode = torch._subclasses.FakeTensorMode()
    with mode:
        inps = [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]
    compiled_f = compile_fx_inner(mod, inps, num_fixed=1, cudagraphs=True)

    def get_unaligned_inputs():
        return [torch.rand([6, 5], device="cuda")[1:] for _ in range(2)]

    class CloneCounterMode(TorchDispatchMode):
        """Dispatch mode that counts ``aten.clone.default`` calls."""

        def __init__(self):
            # Let TorchDispatchMode initialise its own state first.
            super().__init__()
            self.count = 0

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = {} if kwargs is None else kwargs
            self.count += func is torch.ops.aten.clone.default
            return func(*args, **kwargs)

    for _ in range(3):
        with CloneCounterMode() as m:
            compiled_f(get_unaligned_inputs())
            self.assertEqual(m.count, expected_clones)
            compiled_f(get_aligned_inputs())
            self.assertEqual(m.count, expected_clones)
def test_unaligned_static_input_trees(self):
    """Unaligned static inputs with cudagraph trees: expect no clones."""
    self._test_unaligned_static_input_impl(expected_clones=0)
@torch._inductor.config.patch("triton.cudagraph_trees", False)
def test_unaligned_static_input_non_trees(self):
    """Unaligned static inputs with ``triton.cudagraph_trees`` patched off: expect no clones."""
    self._test_unaligned_static_input_impl(expected_clones=0)
@torch._inductor.config.patch("triton.cudagraphs", False)
def test_unaligned_static_input_no_cudagraphs(self):
    """Unaligned static inputs with ``triton.cudagraphs`` patched off: expect no clones."""
    self._test_unaligned_static_input_impl(expected_clones=0)
def test_sparsity(self):
    """
    A compiled function that builds a sparse COO tensor from CUDA indices and
    values must match eager over three calls.
    """
    def foo(view_6, buf31):
        return aten._sparse_coo_tensor_with_dims_and_tensors(
            1,
            1,
            [1000000, 64],
            view_6,
            buf31,
            dtype=torch.float32,
            layout=torch.sparse_coo,
            device="cuda",
            pin_memory=None,
        )

    foo_opt = torch.compile(foo)
    # All indices are zero; values are random [102397, 64] rows.
    view_6 = torch.zeros([1, 102397], dtype=torch.int64, device="cuda")
    buf31 = torch.rand([102397, 64], device="cuda")
    for _ in range(3):
        self.assertEqual(foo_opt(view_6, buf31), foo(view_6, buf31))
def test_accumulate_multiple_recordings(self):
    """
    Run ``foo_opt`` (twice each) on inputs of different shapes and values, then
    check the live cudagraph-pool block count as the outputs are released.
    """
    def foo(x):
        y = x + x + x
        torch._dynamo.graph_break()
        if y.sum() <= 0:
            return y
        else:
            return y * 10

    foo_opt = torch.compile(foo)
    # two separate compilations & recordings
    out1 = self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
    # out1 gets manually freed
    out2 = self.run_twc(foo_opt, torch.zeros([6], device="cuda"))
    self.assertEqual(all_live_block_count(), 1)
    # ones: y.sum() > 0, so the ``y * 10`` branch runs.
    out3 = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
    self.assertEqual(out3, foo(torch.ones([5], device="cuda")))
    self.assertEqual(all_live_block_count(), 1)
    del out1, out2
    self.assertEqual(all_live_block_count(), 1)
    del out3
    gc.collect()
    self.assertEqual(all_live_block_count(), 0)
@torch._inductor.config.patch("freezing", True)
def test_constant_output(self):
    """
    With ``freezing`` enabled, a module that returns its parameter, a slice of
    it, and a computed tensor must match eager across three compiled calls
    under ``no_grad``.
    """
    class Mod(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = torch.nn.Parameter(
                torch.tensor([float(i) for i in range(10)], device="cuda")
            )

        def forward(self, inp):
            return self.param, self.param[0:2], inp + 2

    inp = torch.tensor([2], device="cuda")
    m = Mod()
    with torch.no_grad():
        out_eager = m(inp)
        m_comp = torch.compile(m)
        for _ in range(3):
            self.assertEqual(out_eager, m_comp(inp))
def test_live_outputs_multiple_graphs(self):
    """
    Once the zero-input path has run, an input that takes the other branch
    while a first-graph output is live must trigger exactly two checkpoints.
    """
    def foo(x):
        x = x + x + x
        y = x + 1
        torch._dynamo.graph_break()
        z = x * x
        if z.sum() > 0:
            return y + 1
        else:
            return y

    foo_opt = torch.compile(foo)
    self.run_twc(foo_opt, torch.zeros([5], device="cuda"))
    self.assertEqual(self.num_checkpoints(), 0)
    out = self.run_twc(foo_opt, torch.ones([5], device="cuda"))
    self.assertEqual(all_live_block_count(), 1)
    del out
    self.assertEqual(all_live_block_count(), 0)
    # we need to checkpoint from function to warmup y + 1,
    # and then again to record it
    self.assertEqual(self.num_checkpoints(), 2)
def test_expanded_inputs(self):
    """
    An input produced by ``expand`` (a [1, 512] tensor broadcast to [4, 512])
    must compile, match eager, and cause a cudagraph id to be allocated.
    """
    x = torch.rand(1, 512, device="cuda").expand(4, 512)

    def foo(x):
        return x + 4 + torch.ones([4, 512], device="cuda")

    foo_opt = torch.compile()(foo)
    for _ in range(3):
        self.assertEqual(foo_opt(x), foo(x))
    self.assertFalse(self.get_manager().new_graph_id().id == 0)
@torch._inductor.config.patch("triton.skip_cudagraph_warmup", True)
def test_tensor_dies_between_checkpoint(self):
    """
    With warm-up skipped, an output of the first cudagraph (out2) dies before
    a second function is recorded. After the checkpoint for that recording,
    only the new output may remain live.
    """
    def foo(args):
        x = args[0]
        args.clear()
        return x + 1, x + 2

    inp = torch.rand([4], device="cuda")
    inp_list = [inp]
    foo_cg = self.cudagraphify_impl(foo, inp_list, ())
    foo_cg(inp_list)
    foo_cg([inp])
    out1, out2 = foo_cg([inp])
    inp = [out1]
    del out1, out2

    def foo2(args):
        x = args[0]
        args.clear()
        return [x * x * x]

    self.assertEqual(self.num_checkpoints(), 0)
    foo2_cg = self.cudagraphify_impl(foo2, inp, ())
    x = foo2_cg(inp)[0]
    self.assertEqual(self.num_checkpoints(), 1)
    # out2 dies between the previous recording and the new one,
    # so it needs to be manually deallocated after the checkpoint
    self.assertEqual(all_live_block_count(), 1)
    del x
    self.assertEqual(all_live_block_count(), 0)
def test_aliased_storage_single_weakref(self):
    """
    Outputs that alias two storages (views of x and y, some taken after a
    graph break) must map each storage data pointer to a single ``_cdata``.
    No path weakrefs may remain live once the outputs are dropped.
    """
    @torch.compile(mode="reduce-overhead")
    def foo(x):
        x = x * 20
        x_alias = x[0]
        y = x * 10
        y_alias = y[0]
        torch._dynamo.graph_break()
        ind = torch.tensor(4, device="cuda")
        x_alias2 = x[ind:]
        y_alias2 = y[ind:]
        return x, x_alias, x_alias2, y_alias, y_alias2

    for _ in range(4):
        outs = foo(torch.rand([20, 20], device="cuda"))
        ptr_to_ref = {
            out.untyped_storage().data_ptr(): out.untyped_storage()._cdata
            for out in outs
        }
        # Five outputs, but only two distinct storages (x's and y's).
        self.assertEqual(len(ptr_to_ref), 2)
        for out in outs:
            self.assertEqual(
                ptr_to_ref[out.untyped_storage().data_ptr()],
                out.untyped_storage()._cdata,
            )
        # Drop every reference, including the loop variable, before the next call.
        del outs
        del out
    node = self.get_manager().current_node
    self.assertEqual(len(list(node.path_live_weakrefs())), 0)
    self.assertFalse(self.get_manager().new_graph_id().id == 0)
def test_aliasing_static_ref(self):
    """
    Outputs that alias a module parameter must share its storage. The first
    tree node must mark output 0 as not unaliased and cache no tensor for it.
    """
    class Mod(torch.nn.Linear):
        def forward(self, x):
            return self.weight.T @ x, self.weight.T, self.weight[0:4]

    m = Mod(10, 10).cuda()

    @torch.compile(mode="reduce-overhead")
    def foo(mod, x):
        return mod(x)

    @torch.compile(mode="reduce-overhead")
    def foo2(x):
        return x[2:]

    x = torch.rand([10, 10], device="cuda", requires_grad=True)
    param_c = cdata(m.weight)
    for _ in range(3):
        out1, alias_1, alias_2 = foo(m, x)
        # Both aliases share the parameter's storage.
        self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1)
        out2 = foo2(out1)
        out2.sum().backward()
        self.assertEqual(cdata(out1), cdata(out2))
    node = self.curr_node()
    first_node = next(node._path_from_root)
    self.assertFalse(first_node.unaliased_in_all_paths[0])
    self.assertTrue(first_node.cached_tensor_outputs[0] is None)
@skipIfRocm
def test_checkpointing_resets_persistent_refs(self):
    """
    Before any checkpoint, freeing ``foo``'s output makes the next call return
    an object with the same ``id``. After checkpoints forced via ``foo2``, a
    new call must return an object with a different ``id``.
    """
    @torch.compile(mode="reduce-overhead")
    def foo(x):
        return x @ x

    def inp():
        return torch.rand([20, 20], device="cuda", requires_grad=False)

    for _ in range(3):
        foo(inp())
    self.assertEqual(self.num_checkpoints(), 0)
    out = foo(inp())
    out_id = id(out)
    del out
    self.assertEqual(id(foo(inp())), out_id)

    @torch.compile(mode="reduce-overhead")
    def foo2(x):
        return x[0], x @ x

    for i in range(2):
        out = foo(inp())
        from torch._dynamo.mutation_guard import GenerationTracker

        GenerationTracker.generation -= 1
        out_alias, out2 = foo2(out)
        del out_alias
        self.assertEqual(all_live_block_count(), 2)
        del out
        self.assertEqual(all_live_block_count(), 1)
        del out2
        self.assertEqual(all_live_block_count(), 0)
        # One additional checkpoint per iteration.
        self.assertEqual(self.num_checkpoints(), i + 1)
    new_out = foo(inp())
    curr_node = self.curr_node()
    self.assertFalse(curr_node.unaliased_in_all_paths[0])
    self.assertFalse(out_id == id(new_out))
def test_aliased_static_parameter(self):
    """
    An output that is a view of static input 0 must share its storage, and
    the node must cache no tensor for that output and mark it as not unaliased.
    """
    inp = torch.rand([20, 20], device="cuda")

    def foo(args):
        x = args[0]
        args.clear()
        return (x[0],)

    # (0,) lists input 0 as a static input.
    foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
    for _ in range(3):
        out = foo_cg([inp])[0]
        self.assertEqual(cdata(inp), cdata(out))
    node = self.curr_node()
    self.assertEqual(node.cached_tensor_outputs, [None])
    self.assertEqual(node.unaliased_in_all_paths, [False])
def test_warmup_stream_sync(self):
    """
    Cudagraphify and run a chain of 100 large matmuls entirely on a user
    stream, then check ``out + 1`` against an eager reference computed (and
    synchronized) beforehand.
    """
    def foo(args):
        x = args[0]
        args.clear()
        # NOTE(review): x_orig is never read; presumably it keeps a reference
        # to the input alive. Confirm, or remove it.
        x_orig = x
        for _ in range(100):
            x = x @ x
        return (x,)

    inp = torch.rand([4096, 4096], device="cuda")
    ref = foo([inp])[0]
    torch.cuda.synchronize()
    user_stream = torch.cuda.Stream()
    with torch.cuda.stream(user_stream):
        foo_cg = self.cudagraphify_impl(foo, [inp], (0,))
        out = foo_cg([inp])[0]
        y = out + 1
        self.assertEqual(y, ref + 1)
def test_unaligned_static_parameter(self):
def gen_inp():
inp = torch.ones([20], device="cuda")
return [inp[1:]]
def foo(args):
x = args[0]
args.clear()
return (x + x,)
foo_cg = self.cudagraphify_impl(foo, gen_inp(), (0,))
for _ in range(3):
out = foo_cg(gen_inp())
self.assertEqual(out, foo(gen_inp()))
del out
node = self.curr_node()
self.assertEqual(node.static_input_data_ptrs, [None])