<div class="section" id="distributed-training-with-uneven-inputs-using-the-join-context-manager">
<h1>Distributed Training with Uneven Inputs Using the Join Context Manager<a class="headerlink" href="#distributed-training-with-uneven-inputs-using-the-join-context-manager" title="Permalink to this headline">ยถ</a></h1>
<p><strong>Author</strong>: <a class="reference external" href="https://github.com/andwgu">Andrew Gu</a></p>
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last"><code class="docutils literal notranslate"><span class="pre">Join</span></code> is introduced in PyTorch 1.10 as a prototype feature. This
API is subject to change.</p>
</div>
In this tutorial, you will see:

- An overview of the [Join](https://pytorch.org/docs/master/distributed.algorithms.join.html) context manager.
- An example of how to use the context manager with `DistributedDataParallel`.
- An example of how to use the context manager with both `DistributedDataParallel` and `ZeroRedundancyOptimizer`.
- An example of passing in keyword arguments to the context manager.
- A dive into how the [Join](https://pytorch.org/docs/master/distributed.algorithms.join.html) context manager works.
- An example showing how to make a toy class compatible with the context manager.
<div class="section" id="requirements">
<h2>Requirements<a class="headerlink" href="#requirements" title="Permalink to this headline">ยถ</a></h2>
<ul class="simple">
<li>PyTorch 1.10+</li>
<li><a class="reference external" href="https://tutorials.pytorch.kr/intermediate/ddp_tutorial.html">Getting Started with Distributed Data Parallel</a></li>
<li><a class="reference external" href="https://tutorials.pytorch.kr/recipes/zero_redundancy_optimizer.html">Shard Optimizer States with ZeroRedundancyOptimizer</a></li>
</ul>
</div>
<div class="section" id="what-is-join">
<h2>What is <code class="docutils literal notranslate"><span class="pre">Join</span></code>?<a class="headerlink" href="#what-is-join" title="Permalink to this headline">ยถ</a></h2>
<p>In <a class="reference external" href="https://tutorials.pytorch.kr/intermediate/ddp_tutorial.html#basic-use-case">Getting Started with Distributed Data Parallel - Basic Use Case</a>, you saw
the general skeleton for using <a class="reference external" href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html">DistributedDataParallel</a> to perform data
parallel training. This implicitly schedules all-reduces in each backward pass
to synchronize gradients across ranks. Such <a class="reference external" href="https://pytorch.org/docs/stable/distributed.html">collective communications</a> require participation
from all ranks in the process group, so if a rank has fewer inputs, then the
other ranks will hang or error (depending on the backend). More generally, this
problem persists for any class that performs per-iteration synchronous
collective communications.</p>
<p><code class="docutils literal notranslate"><span class="pre">Join</span></code> is a context manager to be used around your per-rank training loop to
facilitate training with uneven inputs. The context manager allows the ranks
that exhaust their inputs early (i.e. <em>join</em> early) to shadow the collective
communications performed by those that have not yet joined. The ways in which
the communications are shadowed are specified by hooks.</p>
</div>
<div class="section" id="using-join-with-distributeddataparallel">
<h2>Using <code class="docutils literal notranslate"><span class="pre">Join</span></code> with <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code><a class="headerlink" href="#using-join-with-distributeddataparallel" title="Permalink to this headline">ยถ</a></h2>
<p>PyTorchโs <a class="reference external" href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html">DistributedDataParallel</a> works out-of-the-box with the <code class="docutils literal notranslate"><span class="pre">Join</span></code>
context manager. Here is an example usage:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
<span class="kn">import</span> <span class="nn">torch.multiprocessing</span> <span class="k">as</span> <span class="nn">mp</span>
<span class="kn">from</span> <span class="nn">torch.distributed.algorithms.join</span> <span class="kn">import</span> <span class="n">Join</span>
<span class="kn">from</span> <span class="nn">torch.nn.parallel</span> <span class="kn">import</span> <span class="n">DistributedDataParallel</span> <span class="k">as</span> <span class="n">DDP</span>
<span class="n">BACKEND</span> <span class="o">=</span> <span class="s2">"nccl"</span>
<span class="n">WORLD_SIZE</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">NUM_INPUTS</span> <span class="o">=</span> <span class="mi">5</span>
<span class="k">def</span> <span class="nf">worker</span><span class="p">(</span><span class="n">rank</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'MASTER_ADDR'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'localhost'</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'MASTER_PORT'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'29500'</span>
<span class="n">dist</span><span class="o">.</span><span class="n">init_process_group</span><span class="p">(</span><span class="n">BACKEND</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="o">=</span><span class="n">WORLD_SIZE</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">DDP</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">rank</span><span class="p">),</span> <span class="n">device_ids</span><span class="o">=</span><span class="p">[</span><span class="n">rank</span><span class="p">])</span>
<span class="c1"># Rank 1 gets one more input than rank 0</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">NUM_INPUTS</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)]</span>
<span class="n">num_inputs</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">with</span> <span class="n">Join</span><span class="p">([</span><span class="n">model</span><span class="p">]):</span>
<span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="n">num_inputs</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Rank </span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s2"> has exhausted all </span><span class="si">{</span><span class="n">num_inputs</span><span class="si">}</span><span class="s2"> of its inputs!"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
<span class="n">mp</span><span class="o">.</span><span class="n">spawn</span><span class="p">(</span><span class="n">worker</span><span class="p">,</span> <span class="n">nprocs</span><span class="o">=</span><span class="n">WORLD_SIZE</span><span class="p">,</span> <span class="n">join</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span>
<span class="n">main</span><span class="p">()</span>
</pre></div>
</div>
This produces the following output (where the `print()` s from rank 0 and rank 1 may be arbitrarily ordered):

```
Rank 0 has exhausted all 5 of its inputs!
Rank 1 has exhausted all 6 of its inputs!
```
<div class="admonition note">
<p class="first admonition-title">Note</p>
<p class="last"><a class="reference external" href="https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html">DistributedDataParallel</a> provided its own <a class="reference external" href="https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.join">join()</a> context manager
prior to the introduction of this generic <code class="docutils literal notranslate"><span class="pre">Join</span></code> context manager. In the
above example, using <code class="docutils literal notranslate"><span class="pre">with</span> <span class="pre">Join([model]):</span></code> is equivalent to using
<code class="docutils literal notranslate"><span class="pre">with</span> <span class="pre">model.join():</span></code>. One limitation of the existing
<code class="docutils literal notranslate"><span class="pre">DistributedDataParallel.join()</span></code> is that it does not allow multiple
participating classes, e.g. <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> and
<a class="reference external" href="https://pytorch.org/docs/stable/distributed.optim.html">ZeroRedundancyOptimizer</a> together.</p>
</div>
</div>
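To illustrate the equivalence mentioned in the note, here is a minimal sketch of the `DistributedDataParallel.join()` form, reusing `model` and `inputs` from the example above:

```python
# Equivalent to ``with Join([model]):`` when DDP is the only participant
with model.join():
    for input in inputs:
        loss = model(input).sum()
        loss.backward()
```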
<div class="section" id="using-join-with-distributeddataparallel-and-zeroredundancyoptimizer">
<h2>Using <code class="docutils literal notranslate"><span class="pre">Join</span></code> with <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> and <code class="docutils literal notranslate"><span class="pre">ZeroRedundancyOptimizer</span></code><a class="headerlink" href="#using-join-with-distributeddataparallel-and-zeroredundancyoptimizer" title="Permalink to this headline">ยถ</a></h2>
<p>The <code class="docutils literal notranslate"><span class="pre">Join</span></code> context manager works not only with a single class but also with
multiple classes together. PyTorchโs <code class="docutils literal notranslate"><span class="pre">ZeroRedundancyOptimizer</span></code> is also
compatible with the context manager, so here, we examine how to modify the
previous example to use both <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> and
<code class="docutils literal notranslate"><span class="pre">ZeroRedundancyOptimizer</span></code>:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">torch.distributed.optim</span> <span class="kn">import</span> <span class="n">ZeroRedundancyOptimizer</span> <span class="k">as</span> <span class="n">ZeRO</span>
<span class="kn">from</span> <span class="nn">torch.optim</span> <span class="kn">import</span> <span class="n">Adam</span>
<span class="k">def</span> <span class="nf">worker</span><span class="p">(</span><span class="n">rank</span><span class="p">):</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'MASTER_ADDR'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'localhost'</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'MASTER_PORT'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'29500'</span>
<span class="n">dist</span><span class="o">.</span><span class="n">init_process_group</span><span class="p">(</span><span class="n">BACKEND</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="o">=</span><span class="n">WORLD_SIZE</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">DDP</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">rank</span><span class="p">),</span> <span class="n">device_ids</span><span class="o">=</span><span class="p">[</span><span class="n">rank</span><span class="p">])</span>
<span class="n">optim</span> <span class="o">=</span> <span class="n">ZeRO</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">Adam</span><span class="p">,</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">)</span>
<span class="c1"># Rank 1 gets one more input than rank 0</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">NUM_INPUTS</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)]</span>
<span class="n">num_inputs</span> <span class="o">=</span> <span class="mi">0</span>
<span class="c1"># Pass both `model` and `optim` into `Join()`</span>
<span class="k">with</span> <span class="n">Join</span><span class="p">([</span><span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">]):</span>
<span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="n">num_inputs</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span>
<span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">optim</span><span class="o">.</span><span class="n">step</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Rank </span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s2"> has exhausted all </span><span class="si">{</span><span class="n">num_inputs</span><span class="si">}</span><span class="s2"> of its inputs!"</span><span class="p">)</span>
</pre></div>
</div>
<p>This will yield the same output as before. The notable change was
additionally passing in the <code class="docutils literal notranslate"><span class="pre">ZeroRedundancyOptimizer</span></code> instance into
<code class="docutils literal notranslate"><span class="pre">Join()</span></code>.</p>
</div>
<div class="section" id="passing-keyword-arguments">
<h2>Passing Keyword Arguments<a class="headerlink" href="#passing-keyword-arguments" title="Permalink to this headline">ยถ</a></h2>
<p>Classes may provide keyword arguments that modify their behavior in the context
manager at run time. For example, <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> provides an
argument <code class="docutils literal notranslate"><span class="pre">divide_by_initial_world_size</span></code>, which determines if gradients are
divided by the initial world size or by the effective world size (i.e. number
of non-joined ranks). Such keyword arguments can be passed directly into the
context manager.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">with</span> <span class="n">Join</span><span class="p">([</span><span class="n">model</span><span class="p">,</span> <span class="n">optim</span><span class="p">],</span> <span class="n">divide_by_initial_world_size</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
<span class="k">for</span> <span class="nb">input</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="o">...</span>
</pre></div>
</div>
<div class="admonition warning">
<p class="first admonition-title">Warning</p>
<p class="last">The keyword arguments passed into the context manager are shared across
all participating classes. This should not be a limitation since we do
not expect cases where multiple <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> s need differing settings
of the same argument. Nonetheless, this is something to keep in mind.</p>
</div>
</div>
<div class="section" id="how-does-join-work">
<h2>How Does <code class="docutils literal notranslate"><span class="pre">Join</span></code> Work?<a class="headerlink" href="#how-does-join-work" title="Permalink to this headline">ยถ</a></h2>
<p>Now that we have seen some preliminary examples of how to use the <code class="docutils literal notranslate"><span class="pre">Join</span></code>
context manager, let us delve deeper into how it works. This will provide a
greater insight into the full capability that it offers and prepare you to make
your own custom classes compatible. Here, we will go over the <code class="docutils literal notranslate"><span class="pre">Join</span></code> class as
well as the supporting classes <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> and <code class="docutils literal notranslate"><span class="pre">JoinHook</span></code>.</p>
<div class="section" id="joinable">
<h3><code class="docutils literal notranslate"><span class="pre">Joinable</span></code><a class="headerlink" href="#joinable" title="Permalink to this headline">ยถ</a></h3>
<p>To begin, classes compatible with the <code class="docutils literal notranslate"><span class="pre">Join</span></code> context manager must inherit
from the abstract base class <code class="docutils literal notranslate"><span class="pre">Joinable</span></code>. In particular, a <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> must
implement:</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">join_hook(self,</span> <span class="pre">**kwargs)</span> <span class="pre">-></span> <span class="pre">JoinHook</span></code></li>
</ul>
<p>This returns the <code class="docutils literal notranslate"><span class="pre">JoinHook</span></code> instance for the <code class="docutils literal notranslate"><span class="pre">Joinable</span></code>, determining how
joined processes should shadow the per-iteration collective communications
performed by the <code class="docutils literal notranslate"><span class="pre">Joinable</span></code>.</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">join_device(self)</span> <span class="pre">-></span> <span class="pre">torch.device</span></code></li>
</ul>
<p>This returns a device to be used by the <code class="docutils literal notranslate"><span class="pre">Join</span></code> context manager to perform
collective communications, e.g. <code class="docutils literal notranslate"><span class="pre">torch.device("cuda:0")</span></code> or
<code class="docutils literal notranslate"><span class="pre">torch.device("cpu")</span></code>.</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">join_process_group(self)</span> <span class="pre">-></span> <span class="pre">ProcessGroup</span></code></li>
</ul>
<p>This returns the process group to be used by the <code class="docutils literal notranslate"><span class="pre">Join</span></code> context manager to
perform collective communications.</p>
<p>In particular, the <code class="docutils literal notranslate"><span class="pre">join_device</span></code> and <code class="docutils literal notranslate"><span class="pre">join_process_group</span></code> are required
attributes to ensure that the context manager can schedule collective
communications between joined and non-joined processes. One usage is to count
the number of non-joined processes on each iteration using an all-reduce.
Another usage is for implementing the mechanism required for
<code class="docutils literal notranslate"><span class="pre">throw_on_early_termination=True</span></code>, which we will explain later below.</p>
<p><code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> and <code class="docutils literal notranslate"><span class="pre">ZeroRedundancyOptimizer</span></code> already inherit
from <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> and implement the above methods, which is why we could
directly use them in the previous examples.</p>
<p><code class="docutils literal notranslate"><span class="pre">Joinable</span></code> classes should make sure to call the <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> constructor
since it initializes a <code class="docutils literal notranslate"><span class="pre">JoinConfig</span></code> instance, which is used internally by
the context manager to ensure correctness. This will be saved in each
<code class="docutils literal notranslate"><span class="pre">Joinable</span></code> as a field <code class="docutils literal notranslate"><span class="pre">_join_config</span></code>.</p>
</div>
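To make these requirements concrete, here is a minimal, hypothetical skeleton (the name `SkeletonJoinable` and its constructor arguments are illustrative, not part of the API). Note that `join_hook()` receives the keyword arguments that were passed into `Join()`, as discussed in the previous section:

```python
import torch
from torch.distributed.algorithms.join import Joinable, JoinHook

class SkeletonJoinable(Joinable):
    """Hypothetical minimal Joinable; names here are illustrative."""
    def __init__(self, device: torch.device, process_group):
        super().__init__()  # required: initializes the internal ``_join_config``
        self._device = device
        self._process_group = process_group

    def join_hook(self, **kwargs) -> JoinHook:
        # Receives the keyword arguments passed into ``Join()``; read any
        # recognized ones here. The base ``JoinHook`` has no-op hooks.
        return JoinHook()

    @property
    def join_device(self) -> torch.device:
        return self._device

    @property
    def join_process_group(self):
        return self._process_group
```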
<div class="section" id="joinhook">
<h3><code class="docutils literal notranslate"><span class="pre">JoinHook</span></code><a class="headerlink" href="#joinhook" title="Permalink to this headline">ยถ</a></h3>
<p>Next, let us break down the <code class="docutils literal notranslate"><span class="pre">JoinHook</span></code> class. A <code class="docutils literal notranslate"><span class="pre">JoinHook</span></code> provides two
entry points into a context manager:</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">main_hook(self)</span> <span class="pre">-></span> <span class="pre">None</span></code></li>
</ul>
<p>This hook is called repeatedly by each joined rank while there exists a rank
that has not yet joined. It is meant to shadow the collective communications
performed by the <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> in each training iteration (e.g. in one forward
pass, backward pass, and optimizer step).</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">post_hook(self,</span> <span class="pre">is_last_joiner:</span> <span class="pre">bool)</span> <span class="pre">-></span> <span class="pre">None</span></code></li>
</ul>
<p>This hook is called once all ranks have joined. It is passed an additional
<code class="docutils literal notranslate"><span class="pre">bool</span></code> argument <code class="docutils literal notranslate"><span class="pre">is_last_joiner</span></code>, which indicates if the rank was one of
the last to join. The argument may be useful for synchronization.</p>
<p>To give concrete examples of what these hooks may look like, the provided
<code class="docutils literal notranslate"><span class="pre">ZeroRedundancyOptimizer</span></code> main hook performs an optimizer step per normal
since the joined rank is still responsible for updating and synchronizing its
shard of the parameters, and the provided <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> post-hook
broadcasts the final updated model from one of the last joining ranks to ensure
that it is the same across all ranks.</p>
</div>
<div class="section" id="join">
<h3><code class="docutils literal notranslate"><span class="pre">Join</span></code><a class="headerlink" href="#join" title="Permalink to this headline">ยถ</a></h3>
<p>Finally, let us examine how these fit into the <code class="docutils literal notranslate"><span class="pre">Join</span></code> class itself.</p>
<ul class="simple">
<li><code class="docutils literal notranslate"><span class="pre">__init__(self,</span> <span class="pre">joinables:</span> <span class="pre">List[Joinable],</span> <span class="pre">enable:</span> <span class="pre">bool</span> <span class="pre">=</span> <span class="pre">True,</span> <span class="pre">throw_on_early_termination:</span> <span class="pre">bool</span> <span class="pre">=</span> <span class="pre">False)</span></code></li>
</ul>
<p>As we saw in the previous examples, the constructor takes in a list of the
<code class="docutils literal notranslate"><span class="pre">Joinable</span></code> s that participate in the training loop. These should be the
classes that perform collective communications in each iteration.</p>
<p><code class="docutils literal notranslate"><span class="pre">enable</span></code> is a <code class="docutils literal notranslate"><span class="pre">bool</span></code> that can be set to <code class="docutils literal notranslate"><span class="pre">False</span></code> if you know that there
will not be uneven inputs, in which case the context manager becomes vacuous
similar to <code class="docutils literal notranslate"><span class="pre">contextlib.nullcontext()</span></code>. This also may disable join-related
computation in the participating <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> s.</p>
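For instance, a phase of training that is known to have even inputs across ranks could run with the manager disabled; a sketch reusing `model` and `inputs` from the earlier examples:

```python
# Inputs are known to be even across ranks here, so ``Join`` is a no-op
with Join([model], enable=False):
    for input in inputs:
        loss = model(input).sum()
        loss.backward()
```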
<p><code class="docutils literal notranslate"><span class="pre">throw_on_early_termination</span></code> is a <code class="docutils literal notranslate"><span class="pre">bool</span></code> that can be set to <code class="docutils literal notranslate"><span class="pre">True</span></code> to
have each rank raise an exception the moment that uneven inputs are detected.
This is useful for cases that do not conform to the context managerโs
requirements, which is most typically when there are collective communications
from different classes that may be arbitrarily interleaved, such as when using
<code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> with a model that has <code class="docutils literal notranslate"><span class="pre">SyncBatchNorm</span></code> layers. In
such cases, this argument should be set to <code class="docutils literal notranslate"><span class="pre">True</span></code> so that the application
logic can catch the exception and determine how to proceed.</p>
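A sketch of that catch-and-proceed pattern inside the earlier `worker()` function, assuming the raised exception is a `RuntimeError`:

```python
try:
    with Join([model], throw_on_early_termination=True):
        for input in inputs:
            loss = model(input).sum()
            loss.backward()
except RuntimeError:
    # Some rank detected uneven inputs; decide how to proceed here,
    # e.g. break out of the current epoch on every rank.
    pass
```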
<ul class="simple">
<li>The core logic occurs in the <code class="docutils literal notranslate"><span class="pre">__exit__()</span></code> method, which loops while there
exists a non-joined rank, calling each <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> โs main hook, and
then once all ranks have joined, calls their post hooks. Both the main hooks
and post-hooks are iterated over in the order that the <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> s are
passed in.</li>
<li>The context manager requires a heartbeat from non-joined processes. As such,
each <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> class should make a call to <code class="docutils literal notranslate"><span class="pre">Join.notify_join_context()</span></code>
before its per-iteration collective communications. The context manager will
ensure that only the first <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> passed in actually sends the
heartbeat.</li>
</ul>
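A rough sketch of where that heartbeat call sits, extending the hypothetical `SkeletonJoinable` from earlier (the `step()` method and its tensor argument are illustrative):

```python
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join

class SkeletonJoinableWithStep(SkeletonJoinable):
    def step(self, t: torch.Tensor):
        # Heartbeat: notify the context manager before this iteration's
        # collective communications; only the first ``Joinable`` passed
        # into ``Join()`` actually sends it.
        Join.notify_join_context(self)
        dist.all_reduce(t, group=self.join_process_group)
```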
<div class="admonition warning">
<p class="first admonition-title">Warning</p>
<p class="last">As mentioned above regarding <code class="docutils literal notranslate"><span class="pre">throw_on_early_termination</span></code>, the
<code class="docutils literal notranslate"><span class="pre">Join</span></code> context manager is not compatible with certain compositions of
classes. The <code class="docutils literal notranslate"><span class="pre">Joinable</span></code> โs <code class="docutils literal notranslate"><span class="pre">JoinHook</span></code> s must be serializable since each
hook is fully executed before proceeding to the next. In other words, two
hooks cannot overlap. Moreover, currently, both the main hooks and post-
hooks are iterated over in the same deterministic order. If this appears to
be a major limitation, we may modify the API to permit a customizable
ordering.</p>
</div>
</div>
</div>
<div class="section" id="making-a-toy-class-work-with-join">
<h2>Making a Toy Class Work with <code class="docutils literal notranslate"><span class="pre">Join</span></code><a class="headerlink" href="#making-a-toy-class-work-with-join" title="Permalink to this headline">ยถ</a></h2>
<p>Since the previous section introduced several concepts, let us see them in
practice with a toy example. Here, we will implement a class that counts the
number of inputs that are seen across all ranks before its rank joins. This
should provide a basic idea of how you may make your own class compatible
with the <code class="docutils literal notranslate"><span class="pre">Join</span></code> context manager.</p>
<p>Specifically, the following code has each rank print out (1) the number of
inputs across all ranks that seen before it joins and (2) the total number
of inputs across all ranks.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.distributed</span> <span class="k">as</span> <span class="nn">dist</span>
<span class="kn">import</span> <span class="nn">torch.multiprocessing</span> <span class="k">as</span> <span class="nn">mp</span>
<span class="kn">from</span> <span class="nn">torch.distributed.algorithms.join</span> <span class="kn">import</span> <span class="n">Join</span><span class="p">,</span> <span class="n">Joinable</span><span class="p">,</span> <span class="n">JoinHook</span>
<span class="n">BACKEND</span> <span class="o">=</span> <span class="s2">"nccl"</span>
<span class="n">WORLD_SIZE</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">NUM_INPUTS</span> <span class="o">=</span> <span class="mi">5</span>
<span class="k">class</span> <span class="nc">CounterJoinHook</span><span class="p">(</span><span class="n">JoinHook</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Join hook for :class:`Counter`.</span>
<span class="sd"> Arguments:</span>
<span class="sd"> counter (Counter): the :class:`Counter` object using this hook.</span>
<span class="sd"> sync_max_count (bool): whether to sync the max count once all ranks</span>
<span class="sd"> join.</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">counter</span><span class="p">,</span>
<span class="n">sync_max_count</span>
<span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">counter</span> <span class="o">=</span> <span class="n">counter</span>
<span class="bp">self</span><span class="o">.</span><span class="n">sync_max_count</span> <span class="o">=</span> <span class="n">sync_max_count</span>
<span class="k">def</span> <span class="nf">main_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Shadows the counter's all-reduce by all-reducing a dim-1 zero tensor.</span>
<span class="sd"> """</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">counter</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">post_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">is_last_joiner</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Synchronizes the max count across all :class:`Counter` s if</span>
<span class="sd"> ``sync_max_count=True``.</span>
<span class="sd"> """</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">sync_max_count</span><span class="p">:</span>
<span class="k">return</span>
<span class="n">rank</span> <span class="o">=</span> <span class="n">dist</span><span class="o">.</span><span class="n">get_rank</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">counter</span><span class="o">.</span><span class="n">process_group</span><span class="p">)</span>
<span class="n">common_rank</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">counter</span><span class="o">.</span><span class="n">find_common_rank</span><span class="p">(</span><span class="n">rank</span><span class="p">,</span> <span class="n">is_last_joiner</span><span class="p">)</span>
<span class="k">if</span> <span class="n">rank</span> <span class="o">==</span> <span class="n">common_rank</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">counter</span><span class="o">.</span><span class="n">max_count</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">counter</span><span class="o">.</span><span class="n">count</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span>
<span class="n">dist</span><span class="o">.</span><span class="n">broadcast</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">counter</span><span class="o">.</span><span class="n">max_count</span><span class="p">,</span> <span class="n">src</span><span class="o">=</span><span class="n">common_rank</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">Counter</span><span class="p">(</span><span class="n">Joinable</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Example :class:`Joinable` that counts the number of training iterations</span>
<span class="sd"> that it participates in.</span>
<span class="sd"> """</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">device</span><span class="p">,</span> <span class="n">process_group</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">Counter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">device</span> <span class="o">=</span> <span class="n">device</span>
<span class="bp">self</span><span class="o">.</span><span class="n">process_group</span> <span class="o">=</span> <span class="n">process_group</span>
<span class="bp">self</span><span class="o">.</span><span class="n">count</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">max_count</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
<span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Counts the number of inputs processed on this iteration by all ranks</span>
<span class="sd"> by all-reducing a dim-1 one tensor; increments its own internal count.</span>
<span class="sd"> """</span>
<span class="n">Join</span><span class="o">.</span><span class="n">notify_join_context</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()</span>
<span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">t</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">count</span> <span class="o">+=</span> <span class="n">t</span>
<span class="k">def</span> <span class="nf">join_hook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="o">-></span> <span class="n">JoinHook</span><span class="p">:</span>
<span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Return a join hook that shadows the all-reduce in :meth:`__call__`.</span>
<span class="sd"> This join hook supports the following keyword arguments:</span>
<span class="sd"> sync_max_count (bool, optional): whether to synchronize the maximum</span>
<span class="sd"> count across all ranks once all ranks join; default is ``False``.</span>
<span class="sd"> """</span>
<span class="n">sync_max_count</span> <span class="o">=</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"sync_max_count"</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">return</span> <span class="n">CounterJoinHook</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sync_max_count</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">join_device</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-></span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">device</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">join_process_group</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_group</span>
<span class="k">def</span> <span class="nf">find_common_rank</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rank</span><span class="p">,</span> <span class="n">to_consider</span><span class="p">):</span>
<span class="sa">r</span><span class="sd">"""</span>
<span class="sd"> Returns the max rank of the ones to consider over the process group.</span>
<span class="sd"> """</span>
<span class="n">common_rank</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="n">rank</span> <span class="k">if</span> <span class="n">to_consider</span> <span class="k">else</span> <span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">)</span>
<span class="n">dist</span><span class="o">.</span><span class="n">all_reduce</span><span class="p">(</span><span class="n">common_rank</span><span class="p">,</span> <span class="n">op</span><span class="o">=</span><span class="n">dist</span><span class="o">.</span><span class="n">ReduceOp</span><span class="o">.</span><span class="n">MAX</span><span class="p">,</span> <span class="n">group</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">process_group</span><span class="p">)</span>
<span class="n">common_rank</span> <span class="o">=</span> <span class="n">common_rank</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
<span class="k">return</span> <span class="n">common_rank</span>
<span class="k">def</span> <span class="nf">worker</span><span class="p">(</span><span class="n">rank</span><span class="p">):</span>
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">device_count</span><span class="p">()</span> <span class="o">>=</span> <span class="n">WORLD_SIZE</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'MASTER_ADDR'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'localhost'</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">'MASTER_PORT'</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'29500'</span>
<span class="n">dist</span><span class="o">.</span><span class="n">init_process_group</span><span class="p">(</span><span class="n">BACKEND</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">rank</span><span class="p">,</span> <span class="n">world_size</span><span class="o">=</span><span class="n">WORLD_SIZE</span><span class="p">)</span>
<span class="n">counter</span> <span class="o">=</span> <span class="n">Counter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="sa">f</span><span class="s2">"cuda:</span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s2">"</span><span class="p">),</span> <span class="n">dist</span><span class="o">.</span><span class="n">group</span><span class="o">.</span><span class="n">WORLD</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">float</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">NUM_INPUTS</span> <span class="o">+</span> <span class="n">rank</span><span class="p">)]</span>
<span class="k">with</span> <span class="n">Join</span><span class="p">([</span><span class="n">counter</span><span class="p">],</span> <span class="n">sync_max_count</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span>
<span class="n">counter</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">int</span><span class="p">(</span><span class="n">counter</span><span class="o">.</span><span class="n">count</span><span class="o">.</span><span class="n">item</span><span class="p">())</span><span class="si">}</span><span class="s2"> inputs processed before rank </span><span class="si">{</span><span class="n">rank</span><span class="si">}</span><span class="s2"> joined!"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">int</span><span class="p">(</span><span class="n">counter</span><span class="o">.</span><span class="n">max_count</span><span class="o">.</span><span class="n">item</span><span class="p">())</span><span class="si">}</span><span class="s2"> inputs processed across all ranks!"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
<span class="n">mp</span><span class="o">.</span><span class="n">spawn</span><span class="p">(</span><span class="n">worker</span><span class="p">,</span> <span class="n">nprocs</span><span class="o">=</span><span class="n">WORLD_SIZE</span><span class="p">,</span> <span class="n">join</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span>
<span class="n">main</span><span class="p">()</span>
</pre></div>
</div>
<p>Since rank 0 sees 5 inputs and rank 1 sees 6, this yields the output:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>10 inputs processed before rank 0 joined!
11 inputs processed across all ranks!
11 inputs processed before rank 1 joined!
11 inputs processed across all ranks!
</pre></div>
</div>
<p>Some key points to highlight:</p>
<ul class="simple">
<li>A <code class="docutils literal notranslate"><span class="pre">Counter</span></code> instance performs a single all-reduce per iteration, so the
main hook performs a single all-reduce as well to shadow it.</li>
<li>The <code class="docutils literal notranslate"><span class="pre">Counter</span></code> class calls <code class="docutils literal notranslate"><span class="pre">Join.notify_join_context()</span></code> at the
beginning of its <code class="docutils literal notranslate"><span class="pre">__call__()</span></code> method since that call must precede its
per-iteration collective communications (i.e. its all-reduce).</li>
<li>The <code class="docutils literal notranslate"><span class="pre">is_last_joiner</span></code> argument is used to determine the broadcast source in
the post-hooks.</li>
<li>We pass the <code class="docutils literal notranslate"><span class="pre">sync_max_count</span></code> keyword argument to the context manager,
which then forwards it to <code class="docutils literal notranslate"><span class="pre">Counter</span></code>'s join hook (see the sketch below).</li>
</ul>
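<p>As a closing illustration of that last point, the forwarding itself amounts to
roughly the following (a hypothetical re-implementation, not the
<code class="docutils literal notranslate"><span class="pre">torch.distributed</span></code> source):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span># Sketch of how Join's constructor forwards its extra keyword arguments
# to every participating Joinable's join_hook().
def collect_join_hooks(joinables, **kwargs):
    return [joinable.join_hook(**kwargs) for joinable in joinables]

# With the toy class above, Join([counter], sync_max_count=True) thus
# amounts to counter.join_hook(sync_max_count=True), which returns
# CounterJoinHook(counter, sync_max_count=True).
</pre></div>
</div>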
</div>
</div>
</article>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="../beginner/deeplabv3_on_ios.html" class="btn btn-neutral float-right" title="iOS์์์ ์ด๋ฏธ์ง ๋ถํ DeepLapV3" accesskey="n" rel="next">Next <img src="../_static/images/chevron-right-orange.svg" class="next-page"></a>
<a href="ddp_pipeline.html" class="btn btn-neutral" title="๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ํ์ต" accesskey="p" rel="prev"><img src="../_static/images/chevron-right-orange.svg" class="previous-page"> Previous</a>
</div>
<hr class="rating-hr hr-top">
<div class="rating-container">
<div class="rating-prompt">Rate this Tutorial</div>
<div class="stars-outer">
<i class="far fa-star" title="1 Star" data-behavior="tutorial-rating" data-count="1"></i>
<i class="far fa-star" title="2 Stars" data-behavior="tutorial-rating" data-count="2"></i>
<i class="far fa-star" title="3 Stars" data-behavior="tutorial-rating" data-count="3"></i>
<i class="far fa-star" title="4 Stars" data-behavior="tutorial-rating" data-count="4"></i>
<i class="far fa-star" title="5 Stars" data-behavior="tutorial-rating" data-count="5"></i>
</div>
</div>
<hr class="rating-hr hr-bottom"/>
<div role="contentinfo">
<p>
© Copyright 2021, PyTorch & PyTorch Korea Community.
</p>
</div>
<div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</div>
</footer>
</div>
</div>
<div class="pytorch-content-right" id="pytorch-content-right">
<div class="pytorch-right-menu" id="pytorch-right-menu">
<div class="pytorch-side-scroll" id="pytorch-side-scroll-right">
<ul>
<li><a class="reference internal" href="#">Distributed Training with Uneven Inputs Using the Join Context Manager</a><ul>
<li><a class="reference internal" href="#requirements">Requirements</a></li>
<li><a class="reference internal" href="#what-is-join">What is <code class="docutils literal notranslate"><span class="pre">Join</span></code>?</a></li>
<li><a class="reference internal" href="#using-join-with-distributeddataparallel">Using <code class="docutils literal notranslate"><span class="pre">Join</span></code> with <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code></a></li>
<li><a class="reference internal" href="#using-join-with-distributeddataparallel-and-zeroredundancyoptimizer">Using <code class="docutils literal notranslate"><span class="pre">Join</span></code> with <code class="docutils literal notranslate"><span class="pre">DistributedDataParallel</span></code> and <code class="docutils literal notranslate"><span class="pre">ZeroRedundancyOptimizer</span></code></a></li>
<li><a class="reference internal" href="#passing-keyword-arguments">Passing Keyword Arguments</a></li>
<li><a class="reference internal" href="#how-does-join-work">How Does <code class="docutils literal notranslate"><span class="pre">Join</span></code> Work?</a><ul>
<li><a class="reference internal" href="#joinable"><code class="docutils literal notranslate"><span class="pre">Joinable</span></code></a></li>
<li><a class="reference internal" href="#joinhook"><code class="docutils literal notranslate"><span class="pre">JoinHook</span></code></a></li>
<li><a class="reference internal" href="#join"><code class="docutils literal notranslate"><span class="pre">Join</span></code></a></li>
</ul>
</li>
<li><a class="reference internal" href="#making-a-toy-class-work-with-join">Making a Toy Class Work with <code class="docutils literal notranslate"><span class="pre">Join</span></code></a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</section>
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script type="text/javascript" src="../_static/jquery.js"></script>
<script type="text/javascript" src="../_static/underscore.js"></script>
<script type="text/javascript" src="../_static/doctools.js"></script>
<script type="text/javascript" src="../_static/clipboard.min.js"></script>
<script type="text/javascript" src="../_static/copybutton.js"></script>
<script type="text/javascript" src="https://cdn.jsdelivr.net/npm/katex@0.13.11/dist/katex.min.js"></script>
<script type="text/javascript" src="https://cdn.jsdelivr.net/npm/katex@0.13.11/dist/contrib/auto-render.min.js"></script>
<script type="text/javascript" src="../_static/katex_autorenderer.js"></script>
<script type="text/javascript" src="../_static/js/vendor/popper.min.js"></script>
<script type="text/javascript" src="../_static/js/vendor/bootstrap.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/list.js/1.5.0/list.min.js"></script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<script type="text/javascript">
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
</script>
<script>
//add microsoft link
if(window.location.href.indexOf("/beginner/basics/")!= -1)
{
var url="https://docs.microsoft.com/learn/paths/pytorch-fundamentals/?wt.mc_id=aiml-7486-cxa";
switch(window.location.pathname.split("/").pop().replace('.html',''))
{
case"quickstart_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/9-quickstart?WT.mc_id=aiml-7486-cxa";
break;
case"tensorqs_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/2-tensors?WT.mc_id=aiml-7486-cxa";
break;
case"data_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/3-data?WT.mc_id=aiml-7486-cxa";
break;
case"transforms_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/4-transforms?WT.mc_id=aiml-7486-cxa";
break;
case"buildmodel_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/5-model?WT.mc_id=aiml-7486-cxa";
break;
case"autogradqs_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/6-autograd?WT.mc_id=aiml-7486-cxa";
break;
case"optimization_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/7-optimization?WT.mc_id=aiml-7486-cxa";
break;
case"saveloadrun_tutorial":
url="https://docs.microsoft.com/learn/modules/intro-machine-learning-pytorch/8-inference?WT.mc_id=aiml-7486-cxa";
}
$(".pytorch-call-to-action-links").children().first().before("<a href="+url+' data-behavior="call-to-action-event" data-response="Run in Microsoft Learn" target="_blank"><div id="microsoft-learn-link" style="padding-bottom: 0.625rem;border-bottom: 1px solid #f3f4f7;padding-right: 2.5rem;display: -webkit-box; display: -ms-flexbox; isplay: flex; -webkit-box-align: center;-ms-flex-align: center;align-items: center;"><img class="call-to-action-img" src="../../_static/images/microsoft-logo.svg"/><div class="call-to-action-desktop-view">Run in Microsoft Learn</div><div class="call-to-action-mobile-view">Learn</div></div></a>')
}
</script>
<script async src="https://www.googletagmanager.com/gtag/js?id=UA-71919972-3"></script>
<script>
window.dataLayer = window.dataLayer || [];
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
gtag('config', 'UA-71919972-3');
</script>
<script>
$("[data-behavior='call-to-action-event']").on('click', function(){
ga('send', {
hitType: 'event',
eventCategory: $(this).attr("data-response"),
eventAction: 'click',
eventLabel: window.location.href
});
gtag('event', 'click', {
'event_category': $(this).attr("data-response"),
'event_label': $("h1").first().text(),
'tutorial_link': window.location.href
});
});
$("[data-behavior='tutorial-rating']").on('click', function(){
gtag('event', 'click', {
'event_category': 'Tutorial Rating',
'event_label': $("h1").first().text(),
'value': $(this).attr("data-count")
});
});
if (location.pathname == "/") {
$(".rating-container").hide();
$(".hr-bottom").hide();
}
</script>
<script type="text/javascript">
var collapsedSections = ['PyTorch Recipes', 'Learn PyTorch', 'Image/Video', 'Audio', 'Text', 'Reinforcement Learning', 'Deploying PyTorch Models in Production', 'Code Transforms with FX', 'Frontend APIs', 'Extending PyTorch', 'Model Optimization', 'Parallel and Distributed Training', 'Mobile'];
</script>
<!-- Begin Footer -->
<div class="container-fluid docs-tutorials-resources" id="docs-tutorials-resources">
<div class="container">
<div class="row">
<div class="col-md-4 text-center">
<h2>Official Docs (English)</h2>
<p>The official PyTorch documentation.</p>
<a id="orgTutorialLink" class="with-right-arrow" href="https://pytorch.org/docs/stable/index.html" target="_blank">Go to the official docs</a>
</div>
<div class="col-md-4 text-center">
<h2>Korean Tutorials</h2>
<p>PyTorch tutorials being translated into Korean.</p>
<a class="with-right-arrow" href="https://tutorials.pytorch.kr">Go to the tutorials</a>
</div>
<div class="col-md-4 text-center">
<h2>Community</h2>