<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Custom C++ and CUDA Extensions — PyTorch Tutorials 1.10.2+cu102 documentation</title>
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<!-- <link rel="stylesheet" href="../_static/pygments.css" type="text/css" /> -->
<link rel="stylesheet" href="../_static/copybutton.css" type="text/css" />
<link rel="stylesheet" href="../_static/gallery.css" type="text/css" />
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/katex.min.css" type="text/css" />
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.11/dist/katex.min.css" type="text/css" />
<link rel="stylesheet" href="../_static/katex-math.css" type="text/css" />
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Extending TorchScript with Custom C++ Operators" href="torch_script_custom_ops.html" />
<link rel="prev" title="Fusing Convolution and Batch Norm using Custom Function" href="../intermediate/custom_function_conv_bn_tutorial.html" />
<script src="../_static/js/modernizr.min.js"></script>
<!-- Preload the theme fonts -->
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-book.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium-italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<!-- Preload the katex fonts -->
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Math-Italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Main-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Main-Bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size1-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size4-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size2-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size3-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Caligraphic-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.2/css/all.css" integrity="sha384-vSIIfh2YWi9wW0r9iZe7RJPrKwp6bG+s9QZMoITbCckVJqGCCRhc+ccxNcdpHuYu" crossorigin="anonymous">
</head>
<div class="container-fluid header-holder tutorials-header" id="header-holder">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.kr/" aria-label="PyTorch"></a>
<div class="main-menu">
<ul>
<li>
<a href="https://pytorch.kr/get-started">์์ํ๊ธฐ</a>
</li>
<li class="active">
<a href="https://tutorials.pytorch.kr">ํํ ๋ฆฌ์ผ</a>
</li>
<li>
<a href="https://pytorch.kr/hub">ํ๋ธ</a>
</li>
<li>
<a href="https://discuss.pytorch.kr">์ปค๋ฎค๋ํฐ</a>
</li>
</ul>
</div>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu"></a>
</div>
</div>
</div>
<body class="pytorch-body">
<div class="table-of-contents-link-wrapper">
<span>Table of Contents</span>
<a href="#" class="toggle-table-of-contents" data-behavior="toggle-table-of-contents"></a>
</div>
<nav data-toggle="wy-nav-shift" class="pytorch-left-menu" id="pytorch-left-menu">
<div class="pytorch-side-scroll">
<div class="pytorch-menu pytorch-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<div class="pytorch-left-menu-search">
<div class="version">
1.10.2+cu102
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search Tutorials" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<p class="caption"><span class="caption-text">ํ์ดํ ์น(PyTorch) ๋ ์ํผ</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../recipes/recipes_index.html">๋ชจ๋ ๋ ์ํผ ๋ณด๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../prototype/prototype_index.html">๋ชจ๋ ํ๋กํ ํ์
๋ ์ํผ ๋ณด๊ธฐ</a></li>
</ul>
<p class="caption"><span class="caption-text">ํ์ดํ ์น(PyTorch) ์์ํ๊ธฐ</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/intro.html">ํ์ดํ ์น(PyTorch) ๊ธฐ๋ณธ ์ตํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/quickstart_tutorial.html">๋น ๋ฅธ ์์(Quickstart)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/tensorqs_tutorial.html">ํ
์(Tensor)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/data_tutorial.html">Dataset๊ณผ DataLoader</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/transforms_tutorial.html">๋ณํ(Transform)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/buildmodel_tutorial.html">์ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์ฑํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/autogradqs_tutorial.html"><code class="docutils literal notranslate"><span class="pre">torch.autograd</span></code>๋ฅผ ์ฌ์ฉํ ์๋ ๋ฏธ๋ถ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/optimization_tutorial.html">๋ชจ๋ธ ๋งค๊ฐ๋ณ์ ์ต์ ํํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/basics/saveloadrun_tutorial.html">๋ชจ๋ธ ์ ์ฅํ๊ณ ๋ถ๋ฌ์ค๊ธฐ</a></li>
</ul>
<p class="caption"><span class="caption-text">Introduction to PyTorch on YouTube</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt.html">Introduction to PyTorch - YouTube Series</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt/introyt1_tutorial.html">Introduction to PyTorch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt/tensors_deeper_tutorial.html">Introduction to PyTorch Tensors</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt/autogradyt_tutorial.html">The Fundamentals of Autograd</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt/modelsyt_tutorial.html">Building Models with PyTorch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt/tensorboardyt_tutorial.html">PyTorch TensorBoard Support</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt/trainingyt.html">Training with PyTorch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/introyt/captumyt.html">Model Understanding with Captum</a></li>
</ul>
<p class="caption"><span class="caption-text">ํ์ดํ ์น(PyTorch) ๋ฐฐ์ฐ๊ธฐ</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/deep_learning_60min_blitz.html">PyTorch๋ก ๋ฅ๋ฌ๋ํ๊ธฐ: 60๋ถ๋ง์ ๋์ฅ๋ด๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/pytorch_with_examples.html">์์ ๋ก ๋ฐฐ์ฐ๋ ํ์ดํ ์น(PyTorch)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/nn_tutorial.html"><cite>torch.nn</cite> ์ด <em>์ค์ ๋ก</em> ๋ฌด์์ธ๊ฐ์?</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/tensorboard_tutorial.html">TensorBoard๋ก ๋ชจ๋ธ, ๋ฐ์ดํฐ, ํ์ต ์๊ฐํํ๊ธฐ</a></li>
</ul>
<p class="caption"><span class="caption-text">์ด๋ฏธ์ง/๋น๋์ค</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/torchvision_tutorial.html">TorchVision ๊ฐ์ฒด ๊ฒ์ถ ๋ฏธ์ธ์กฐ์ (Finetuning) ํํ ๋ฆฌ์ผ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/transfer_learning_tutorial.html">์ปดํจํฐ ๋น์ (Vision)์ ์ํ ์ ์ดํ์ต(Transfer Learning)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/fgsm_tutorial.html">์ ๋์ ์์ ์์ฑ(Adversarial Example Generation)</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/dcgan_faces_tutorial.html">DCGAN ํํ ๋ฆฌ์ผ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/vt_tutorial.html">๋ฐฐํฌ๋ฅผ ์ํ ๋น์ ํธ๋์คํฌ๋จธ(Vision Transformer) ๋ชจ๋ธ ์ต์ ํํ๊ธฐ</a></li>
</ul>
<p class="caption"><span class="caption-text">์ค๋์ค</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/audio_io_tutorial.html">Audio I/O</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/audio_resampling_tutorial.html">Audio Resampling</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/audio_data_augmentation_tutorial.html">Audio Data Augmentation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/audio_feature_extractions_tutorial.html">Audio Feature Extractions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/audio_feature_augmentation_tutorial.html">Audio Feature Augmentation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/audio_datasets_tutorial.html">Audio Datasets</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/speech_recognition_pipeline_tutorial.html">Speech Recognition with Wav2Vec2</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/speech_command_classification_with_torchaudio_tutorial.html">Speech Command Classification with torchaudio</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/text_to_speech_with_torchaudio.html">Text-to-speech with torchaudio</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/forced_alignment_with_torchaudio_tutorial.html">Forced Alignment with Wav2Vec2</a></li>
</ul>
<p class="caption"><span class="caption-text">ํ
์คํธ</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/transformer_tutorial.html">nn.Transformer ์ TorchText ๋ก ์ํ์ค-ํฌ-์ํ์ค(Sequence-to-Sequence) ๋ชจ๋ธ๋งํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/char_rnn_classification_tutorial.html">๊ธฐ์ด๋ถํฐ ์์ํ๋ NLP: ๋ฌธ์-๋จ์ RNN์ผ๋ก ์ด๋ฆ ๋ถ๋ฅํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/char_rnn_generation_tutorial.html">๊ธฐ์ด๋ถํฐ ์์ํ๋ NLP: ๋ฌธ์-๋จ์ RNN์ผ๋ก ์ด๋ฆ ์์ฑํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/seq2seq_translation_tutorial.html">๊ธฐ์ด๋ถํฐ ์์ํ๋ NLP: Sequence to Sequence ๋คํธ์ํฌ์ Attention์ ์ด์ฉํ ๋ฒ์ญ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/text_sentiment_ngrams_tutorial.html">torchtext ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก ํ
์คํธ ๋ถ๋ฅํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/translation_transformer.html">nn.Transformer์ torchtext๋ก ์ธ์ด ๋ฒ์ญํ๊ธฐ</a></li>
</ul>
<p class="caption"><span class="caption-text">๊ฐํํ์ต</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/reinforcement_q_learning.html">๊ฐํ ํ์ต (DQN) ํํ ๋ฆฌ์ผ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/mario_rl_tutorial.html">Train a Mario-playing RL Agent</a></li>
</ul>
<p class="caption"><span class="caption-text">PyTorch ๋ชจ๋ธ์ ํ๋ก๋์
ํ๊ฒฝ์ ๋ฐฐํฌํ๊ธฐ</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/flask_rest_api_tutorial.html">Flask๋ฅผ ์ฌ์ฉํ์ฌ Python์์ PyTorch๋ฅผ REST API๋ก ๋ฐฐํฌํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/Intro_to_TorchScript_tutorial.html">TorchScript ์๊ฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="cpp_export.html">C++์์ TorchScript ๋ชจ๋ธ ๋ก๋ฉํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="super_resolution_with_onnxruntime.html">(์ ํ) PyTorch ๋ชจ๋ธ์ ONNX์ผ๋ก ๋ณํํ๊ณ ONNX ๋ฐํ์์์ ์คํํ๊ธฐ</a></li>
</ul>
<p class="caption"><span class="caption-text">Code Transforms with FX</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/fx_conv_bn_fuser.html">(๋ฒ ํ) FX์์ ํฉ์ฑ๊ณฑ/๋ฐฐ์น ์ ๊ทํ(Convolution/Batch Norm) ๊ฒฐํฉ๊ธฐ(Fuser) ๋ง๋ค๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/fx_profiling_tutorial.html">(beta) Building a Simple CPU Performance Profiler with FX</a></li>
</ul>
<p class="caption"><span class="caption-text">ํ๋ก ํธ์๋ API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/memory_format_tutorial.html">(๋ฒ ํ) PyTorch๋ฅผ ์ฌ์ฉํ Channels Last ๋ฉ๋ชจ๋ฆฌ ํ์</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/forward_ad_usage.html">Forward-mode Automatic Differentiation (Beta)</a></li>
<li class="toctree-l1"><a class="reference internal" href="cpp_frontend.html">PyTorch C++ ํ๋ก ํธ์๋ ์ฌ์ฉํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="torch-script-parallelism.html">TorchScript์ ๋์ ๋ณ๋ ฌ ์ฒ๋ฆฌ(Dynamic Parallelism)</a></li>
<li class="toctree-l1"><a class="reference internal" href="cpp_autograd.html">C++ ํ๋ก ํธ์๋์ ์๋ ๋ฏธ๋ถ (autograd)</a></li>
</ul>
<p class="caption"><span class="caption-text">PyTorch ํ์ฅํ๊ธฐ</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../intermediate/custom_function_double_backward_tutorial.html">Double Backward with Custom Functions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/custom_function_conv_bn_tutorial.html">Fusing Convolution and Batch Norm using Custom Function</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Custom C++ and CUDA Extensions</a></li>
<li class="toctree-l1"><a class="reference internal" href="torch_script_custom_ops.html">Extending TorchScript with Custom C++ Operators</a></li>
<li class="toctree-l1"><a class="reference internal" href="torch_script_custom_classes.html">์ปค์คํ
C++ ํด๋์ค๋ก TorchScript ํ์ฅํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="dispatcher.html">Registering a Dispatched Operator in C++</a></li>
<li class="toctree-l1"><a class="reference internal" href="extend_dispatcher.html">Extending dispatcher for a new backend in C++</a></li>
</ul>
<p class="caption"><span class="caption-text">๋ชจ๋ธ ์ต์ ํ</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/profiler.html">PyTorch ๋ชจ๋ ํ๋กํ์ผ๋ง ํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/tensorboard_profiler_tutorial.html">PyTorch Profiler With TensorBoard</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/hyperparameter_tuning_tutorial.html">Hyperparameter tuning with Ray Tune</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/vt_tutorial.html">๋ฐฐํฌ๋ฅผ ์ํ ๋น์ ํธ๋์คํฌ๋จธ(Vision Transformer) ๋ชจ๋ธ ์ต์ ํํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/parametrizations.html">Parametrizations Tutorial</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/pruning_tutorial.html">๊ฐ์ง์น๊ธฐ ๊ธฐ๋ฒ(Pruning) ํํ ๋ฆฌ์ผ</a></li>
<li class="toctree-l1"><a class="reference internal" href="dynamic_quantization_tutorial.html">(๋ฒ ํ) LSTM ๊ธฐ๋ฐ ๋จ์ด ๋จ์ ์ธ์ด ๋ชจ๋ธ์ ๋์ ์์ํ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/dynamic_quantization_bert_tutorial.html">(๋ฒ ํ) BERT ๋ชจ๋ธ ๋์ ์์ํํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/quantized_transfer_learning_tutorial.html">(๋ฒ ํ) ์ปดํจํฐ ๋น์ ํํ ๋ฆฌ์ผ์ ์ํ ์์ํ๋ ์ ์ดํ์ต(Quantized Transfer Learning)</a></li>
<li class="toctree-l1"><a class="reference internal" href="static_quantization_tutorial.html">(beta) Static Quantization with Eager Mode in PyTorch</a></li>
</ul>
<p class="caption"><span class="caption-text">๋ณ๋ ฌ ๋ฐ ๋ถ์ฐ ํ์ต</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/dist_overview.html">PyTorch Distributed Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/model_parallel_tutorial.html">๋จ์ผ ๋จธ์ ์ ์ฌ์ฉํ ๋ชจ๋ธ ๋ณ๋ ฌํ ๋ชจ๋ฒ ์ฌ๋ก</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/ddp_tutorial.html">๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ ์์ํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/dist_tuto.html">PyTorch๋ก ๋ถ์ฐ ์ดํ๋ฆฌ์ผ์ด์
๊ฐ๋ฐํ๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/rpc_tutorial.html">Getting Started with Distributed RPC Framework</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/rpc_param_server_tutorial.html">Implementing a Parameter Server Using Distributed RPC Framework</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/dist_pipeline_parallel_tutorial.html">Distributed Pipeline Parallelism Using RPC</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/rpc_async_execution.html">Implementing Batch RPC Processing Using Asynchronous Executions</a></li>
<li class="toctree-l1"><a class="reference internal" href="rpc_ddp_tutorial.html">๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ(DDP)๊ณผ ๋ถ์ฐ RPC ํ๋ ์์ํฌ ๊ฒฐํฉ</a></li>
<li class="toctree-l1"><a class="reference internal" href="../intermediate/pipeline_tutorial.html">ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ๋ก ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ํ์ต์ํค๊ธฐ</a></li>
<li class="toctree-l1"><a class="reference internal" href="ddp_pipeline.html">๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ๋ณ๋ ฌ ์ฒ๋ฆฌ ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ํ์ต</a></li>
<li class="toctree-l1"><a class="reference internal" href="generic_join.html">Distributed Training with Uneven Inputs Using the Join Context Manager</a></li>
</ul>
<p class="caption"><span class="caption-text">Mobile</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../beginner/deeplabv3_on_ios.html">iOS์์์ ์ด๋ฏธ์ง ๋ถํ DeepLapV3</a></li>
<li class="toctree-l1"><a class="reference internal" href="../beginner/deeplabv3_on_android.html">์๋๋ก์ด๋์์์ ์ด๋ฏธ์ง ๋ถํ DeepLapV3</a></li>
</ul>
</div>
</div>
</nav>
<div class="pytorch-container">
<div class="pytorch-page-level-bar" id="pytorch-page-level-bar">
<div class="pytorch-breadcrumbs-wrapper">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="pytorch-breadcrumbs">
<li>
<a href="../index.html">
Tutorials
</a> >
</li>
<li>Custom C++ and CUDA Extensions</li>
<li class="pytorch-breadcrumbs-aside">
<a href="../_sources/advanced/cpp_extension.rst.txt" rel="nofollow"><img src="../_static/images/view-page-source-icon.svg"></a>
</li>
</ul>
</div>
</div>
<div class="pytorch-shortcuts-wrapper" id="pytorch-shortcuts-wrapper">
Shortcuts
</div>
</div>
<section data-toggle="wy-nav-shift" id="pytorch-content-wrap" class="pytorch-content-wrap">
<div class="pytorch-content-left">
<div class="pytorch-call-to-action-links">
<div id="tutorial-type">advanced/cpp_extension</div>
<div id="google-colab-link">
<img class="call-to-action-img" src="../_static/images/pytorch-colab.svg"/>
<div class="call-to-action-desktop-view">Run in Google Colab</div>
<div class="call-to-action-mobile-view">Colab</div>
</div>
<div id="download-notebook-link">
<img class="call-to-action-notebook-img" src="../_static/images/pytorch-download.svg"/>
<div class="call-to-action-desktop-view">Download Notebook</div>
<div class="call-to-action-mobile-view">Notebook</div>
</div>
<div id="github-view-link">
<img class="call-to-action-img" src="../_static/images/pytorch-github.svg"/>
<div class="call-to-action-desktop-view">View on GitHub</div>
<div class="call-to-action-mobile-view">GitHub</div>
</div>
</div>
<div class="rst-content">
<div role="main" class="main-content" itemscope="itemscope" itemtype="http://schema.org/Article">
<article itemprop="articleBody" id="pytorch-article" class="pytorch-article">
<div class="section" id="custom-c-and-cuda-extensions">
<h1>Custom C++ and CUDA Extensions<a class="headerlink" href="#custom-c-and-cuda-extensions" title="Permalink to this headline">¶</a></h1>
<p><strong>Author</strong>: <a class="reference external" href="https://www.goldsborough.me/">Peter Goldsborough</a></p>
<p>PyTorch provides a plethora of operations related to neural networks, arbitrary
tensor algebra, data wrangling and other purposes. However, you may still find
yourself in need of a more customized operation. For example, you might want to
use a novel activation function you found in a paper, or implement an operation
you developed as part of your research.</p>
<p>The easiest way of integrating such a custom operation in PyTorch is to write it
in Python by extending <code class="xref py py-class docutils literal notranslate"><span class="pre">Function</span></code> and <code class="xref py py-class docutils literal notranslate"><span class="pre">Module</span></code> as outlined <a class="reference external" href="https://pytorch.org/docs/master/notes/extending.html">here</a>. This gives you the full
power of automatic differentiation (spares you from writing derivative
functions) as well as the usual expressiveness of Python. However, there may be
times when your operation is better implemented in C++. For example, your code
may need to be <em>really</em> fast because it is called very frequently in your model
or is very expensive even for few calls. Another plausible reason is that it
depends on or interacts with other C or C++ libraries. To address such cases,
PyTorch provides a very easy way of writing custom <em>C++ extensions</em>.</p>
<p>C++ extensions are a mechanism we have developed to allow users (you) to create
PyTorch operators defined <em>out-of-source</em>, i.e. separate from the PyTorch
backend. This approach is <em>different</em> from the way native PyTorch operations are
implemented. C++ extensions are intended to spare you much of the boilerplate
associated with integrating an operation with PyTorch's backend while providing
you with a high degree of flexibility for your PyTorch-based projects.
Nevertheless, once you have defined your operation as a C++ extension, turning
it into a native PyTorch function is largely a matter of code organization,
which you can tackle after the fact if you decide to contribute your operation
upstream.</p>
<div class="section" id="motivation-and-example">
<h2>Motivation and Example<a class="headerlink" href="#motivation-and-example" title="Permalink to this headline">¶</a></h2>
<p>The rest of this note will walk through a practical example of writing and using
a C++ (and CUDA) extension. If you are being chased or someone will fire you if
you don't get that op done by the end of the day, you can skip this section and
head straight to the implementation details in the next section.</p>
<p>Let's say you've come up with a new kind of recurrent unit that you found to
have superior properties compared to the state of the art. This recurrent unit
is similar to an LSTM, but differs in that it lacks a <em>forget gate</em> and uses an
<em>Exponential Linear Unit</em> (ELU) as its internal activation function. Because
this unit never forgets, we'll call it <em>LLTM</em>, or <em>Long-Long-Term-Memory</em> unit.</p>
<p>The two ways in which LLTMs differ from vanilla LSTMs are significant enough
that we can't configure PyTorch's <code class="docutils literal notranslate"><span class="pre">LSTMCell</span></code> for our purposes, so we'll have to
create a custom cell. The first and easiest approach for this (and likely in
all cases a good first step) is to implement our desired functionality in
plain PyTorch with Python. For this, we need to subclass
<code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.Module</span></code> and implement the forward pass of the LLTM. This would
look something like this:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LLTM</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_features</span><span class="p">,</span> <span class="n">state_size</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">LLTM</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_features</span> <span class="o">=</span> <span class="n">input_features</span>
<span class="bp">self</span><span class="o">.</span><span class="n">state_size</span> <span class="o">=</span> <span class="n">state_size</span>
<span class="c1"># 3 * state_size for input gate, output gate and candidate cell gate.</span>
<span class="c1"># input_features + state_size because we will multiply with [input, h].</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weights</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="n">state_size</span><span class="p">,</span> <span class="n">input_features</span> <span class="o">+</span> <span class="n">state_size</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="n">state_size</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reset_parameters</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">reset_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">stdv</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">state_size</span><span class="p">)</span>
<span class="k">for</span> <span class="n">weight</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
<span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="n">stdv</span><span class="p">,</span> <span class="o">+</span><span class="n">stdv</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
<span class="n">old_h</span><span class="p">,</span> <span class="n">old_cell</span> <span class="o">=</span> <span class="n">state</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">old_h</span><span class="p">,</span> <span class="nb">input</span><span class="p">],</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="c1"># Compute the input, output and candidate cell gates with one MM.</span>
<span class="n">gate_weights</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">linear</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">)</span>
<span class="c1"># Split the combined gate weight matrix into its components.</span>
<span class="n">gates</span> <span class="o">=</span> <span class="n">gate_weights</span><span class="o">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">input_gate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">output_gate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="c1"># Here we use an ELU instead of the usual tanh.</span>
<span class="n">candidate_cell</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">elu</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span>
<span class="c1"># Compute the new cell state.</span>
<span class="n">new_cell</span> <span class="o">=</span> <span class="n">old_cell</span> <span class="o">+</span> <span class="n">candidate_cell</span> <span class="o">*</span> <span class="n">input_gate</span>
<span class="c1"># Compute the new hidden state and output.</span>
<span class="n">new_h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">new_cell</span><span class="p">)</span> <span class="o">*</span> <span class="n">output_gate</span>
<span class="k">return</span> <span class="n">new_h</span><span class="p">,</span> <span class="n">new_cell</span>
</pre></div>
</div>
<p>which we could then use as expected:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">input_features</span><span class="p">)</span>
<span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
<span class="n">rnn</span> <span class="o">=</span> <span class="n">LLTM</span><span class="p">(</span><span class="n">input_features</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
<span class="n">new_h</span><span class="p">,</span> <span class="n">new_C</span> <span class="o">=</span> <span class="n">rnn</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">C</span><span class="p">))</span>
</pre></div>
</div>
<p>Naturally, if at all possible and plausible, you should use this approach to
extend PyTorch. Since PyTorch has highly optimized implementations of its
operations for CPU <em>and</em> GPU, powered by libraries such as <a class="reference external" href="https://developer.nvidia.com/cudnn">NVIDIA cuDNN</a>, <a class="reference external" href="https://software.intel.com/en-us/mkl">Intel MKL</a> or <a class="reference external" href="https://github.com/Maratyszcza/NNPACK">NNPACK</a>, PyTorch code like above will often be
fast enough. However, we can also see why, under certain circumstances, there is
room for further performance improvements. The most obvious reason is that
PyTorch has no knowledge of the <em>algorithm</em> you are implementing. It knows only
of the individual operations you use to compose your algorithm. As such, PyTorch
must execute your operations individually, one after the other. Since each
individual call to the implementation (or <em>kernel</em>) of an operation, which may
involve the launch of a CUDA kernel, has a certain amount of overhead, this
overhead may become significant across many function calls. Furthermore, the
Python interpreter that is running our code can itself slow down our program.</p>
<p>A definite method of speeding things up is therefore to rewrite parts in C++ (or
CUDA) and <em>fuse</em> particular groups of operations. Fusing means combining the
implementations of many functions into a single function, which profits from
fewer kernel launches as well as other optimizations we can perform with
increased visibility of the global flow of data.</p>
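<p>Before optimizing anything, it helps to have a baseline to compare against.
Below is a minimal timing sketch for the Python LLTM; the sizes and iteration
count are made-up illustrative values, and it assumes the <code class="docutils literal notranslate"><span class="pre">LLTM</span></code> module defined
above is in scope:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>import time

import torch

batch_size, input_features, state_size = 16, 32, 128
X = torch.randn(batch_size, input_features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)
rnn = LLTM(input_features, state_size)

start = time.time()
for _ in range(10000):
    # Each iteration runs a fresh forward pass, so backward() is valid every time.
    new_h, new_C = rnn(X, (h, C))
    (new_h.sum() + new_C.sum()).backward()
print('10000 forward/backward iterations: {:.3f} s'.format(time.time() - start))
</pre></div>
</div>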
<p>Let's see how we can use C++ extensions to implement a <em>fused</em> version of the
LLTM. We'll begin by writing it in plain C++, using the <a class="reference external" href="https://github.com/zdevito/ATen">ATen</a> library that powers much of PyTorch's
backend, and see how easily it lets us translate our Python code. We'll then
speed things up even more by moving parts of the model to CUDA kernels to benefit
from the massive parallelism GPUs provide.</p>
</div>
<div class="section" id="writing-a-c-extension">
<h2>Writing a C++ Extension<a class="headerlink" href="#writing-a-c-extension" title="Permalink to this headline">¶</a></h2>
<p>C++ extensions come in two flavors: they can be built “ahead of time” with
<code class="xref py py-mod docutils literal notranslate"><span class="pre">setuptools</span></code>, or “just in time” via
<code class="xref py py-func docutils literal notranslate"><span class="pre">torch.utils.cpp_extension.load()</span></code>. We'll begin with the first approach and
discuss the latter later.</p>
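<p>As a quick preview, the “just in time” flavor is essentially a one-liner. A
minimal sketch, assuming the <code class="docutils literal notranslate"><span class="pre">lltm.cpp</span></code> file developed below sits in the
working directory:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>from torch.utils.cpp_extension import load

# Compiles and loads the extension on first call; later calls hit a build cache.
lltm_cpp = load(name='lltm_cpp', sources=['lltm.cpp'], verbose=True)
</pre></div>
</div>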
<div class="section" id="building-with-setuptools">
<h3>Building with <code class="xref py py-mod docutils literal notranslate"><span class="pre">setuptools</span></code><a class="headerlink" href="#building-with-setuptools" title="Permalink to this headline">¶</a></h3>
<p>For the “ahead of time” flavor, we build our C++ extension by writing a
<code class="docutils literal notranslate"><span class="pre">setup.py</span></code> script that uses setuptools to compile our C++ code. For the LLTM, it
looks as simple as this:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">setuptools</span> <span class="kn">import</span> <span class="n">setup</span><span class="p">,</span> <span class="n">Extension</span>
<span class="kn">from</span> <span class="nn">torch.utils</span> <span class="kn">import</span> <span class="n">cpp_extension</span>
<span class="n">setup</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'lltm_cpp'</span><span class="p">,</span>
<span class="n">ext_modules</span><span class="o">=</span><span class="p">[</span><span class="n">cpp_extension</span><span class="o">.</span><span class="n">CppExtension</span><span class="p">(</span><span class="s1">'lltm_cpp'</span><span class="p">,</span> <span class="p">[</span><span class="s1">'lltm.cpp'</span><span class="p">])],</span>
<span class="n">cmdclass</span><span class="o">=</span><span class="p">{</span><span class="s1">'build_ext'</span><span class="p">:</span> <span class="n">cpp_extension</span><span class="o">.</span><span class="n">BuildExtension</span><span class="p">})</span>
</pre></div>
</div>
<p>In this code, <code class="xref py py-class docutils literal notranslate"><span class="pre">CppExtension</span></code> is a convenience wrapper around
<code class="xref py py-class docutils literal notranslate"><span class="pre">setuptools.Extension</span></code> that passes the correct include paths and sets
the language of the extension to C++. The equivalent vanilla <code class="xref py py-mod docutils literal notranslate"><span class="pre">setuptools</span></code>
code would simply be:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Extension</span><span class="p">(</span>
<span class="n">name</span><span class="o">=</span><span class="s1">'lltm_cpp'</span><span class="p">,</span>
<span class="n">sources</span><span class="o">=</span><span class="p">[</span><span class="s1">'lltm.cpp'</span><span class="p">],</span>
<span class="n">include_dirs</span><span class="o">=</span><span class="n">cpp_extension</span><span class="o">.</span><span class="n">include_paths</span><span class="p">(),</span>
<span class="n">language</span><span class="o">=</span><span class="s1">'c++'</span><span class="p">)</span>
</pre></div>
</div>
<p><code class="xref py py-class docutils literal notranslate"><span class="pre">BuildExtension</span></code> performs a number of required configuration steps and
checks and also manages mixed compilation in the case of mixed C++/CUDA
extensions. And thatโs all we really need to know about building C++ extensions
for now! Letโs now take a look at the implementation of our C++ extension,
which goes into <code class="docutils literal notranslate"><span class="pre">lltm.cpp</span></code>.</p>
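<p>For the mixed C++/CUDA case there is an analogous <code class="xref py py-class docutils literal notranslate"><span class="pre">CUDAExtension</span></code> helper. As a
rough sketch only (the file names here are hypothetical; a CUDA version of the
LLTM is developed later in this note):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Hypothetical sources: a C++ binding file plus a .cu file with the kernels.
setup(name='lltm_cuda',
      ext_modules=[CUDAExtension('lltm_cuda',
                                 ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'])],
      cmdclass={'build_ext': BuildExtension})
</pre></div>
</div>
<p>And that's all we really need to know about building C++ extensions
for now! Let's now take a look at the implementation of our C++ extension,
which goes into <code class="docutils literal notranslate"><span class="pre">lltm.cpp</span></code>.</p>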
</div>
<div class="section" id="writing-the-c-op">
<h3>Writing the C++ Op<a class="headerlink" href="#writing-the-c-op" title="Permalink to this headline">¶</a></h3>
<p>Let's start implementing the LLTM in C++! One function we'll need for the
backward pass is the derivative of the sigmoid. This is a small enough piece of
code to discuss the overall environment that is available to us when writing C++
extensions:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="cp">#include</span> <span class="cpf"><torch/extension.h></span><span class="cp"></span>
<span class="cp">#include</span> <span class="cpf"><iostream></span><span class="cp"></span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">d_sigmoid</span><span class="p">(</span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">z</span><span class="p">)</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">s</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">z</span><span class="p">);</span>
<span class="k">return</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">s</span><span class="p">)</span> <span class="o">*</span> <span class="n">s</span><span class="p">;</span>
<span class="p">}</span>
</pre></div>
</div>
<p><code class="docutils literal notranslate"><span class="pre"><torch/extension.h></span></code> is the one-stop header to include all the necessary PyTorch
bits to write C++ extensions. It includes:</p>
<ul class="simple">
<li>The ATen library, which is our primary API for tensor computation,</li>
<li><a class="reference external" href="https://github.com/pybind/pybind11">pybind11</a>, which is how we create Python bindings for our C++ code,</li>
<li>Headers that manage the details of interaction between ATen and pybind11.</li>
</ul>
<p>The implementation of <code class="xref py py-func docutils literal notranslate"><span class="pre">d_sigmoid()</span></code> shows how to use the ATen API.
PyTorch's tensor and variable interface is generated automatically from the
ATen library, so we can more or less translate our Python implementation 1:1
into C++. Our primary datatype for all computations will be
<code class="xref py py-class docutils literal notranslate"><span class="pre">torch::Tensor</span></code>. Its full API can be inspected <a class="reference external" href="https://pytorch.org/cppdocs/api/classat_1_1_tensor.html">here</a>. Notice
also that we can include <code class="docutils literal notranslate"><span class="pre"><iostream></span></code> or <em>any other C or C++ header</em>; we have
the full power of C++11 at our disposal.</p>
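<p>The derivative formula itself is easy to sanity-check from Python before
porting anything; a small sketch verifying <code class="docutils literal notranslate"><span class="pre">(1 - s) * s</span></code> against autograd:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre>import torch

# Verify the analytic sigmoid derivative used in d_sigmoid against autograd.
z = torch.randn(8, requires_grad=True)
s = torch.sigmoid(z)
analytic = (1 - s) * s
(autograd_grad,) = torch.autograd.grad(s.sum(), z)
assert torch.allclose(analytic, autograd_grad)
</pre></div>
</div>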
<div class="section" id="forward-pass">
<h4>Forward Pass<a class="headerlink" href="#forward-pass" title="Permalink to this headline">¶</a></h4>
<p>Next we can port our entire forward pass to C++:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="cp">#include</span> <span class="cpf"><vector></span><span class="cp"></span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">at</span><span class="o">::</span><span class="n">Tensor</span><span class="o">></span> <span class="n">lltm_forward</span><span class="p">(</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">input</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">weights</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">bias</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">old_h</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">old_cell</span><span class="p">)</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">cat</span><span class="p">({</span><span class="n">old_h</span><span class="p">,</span> <span class="n">input</span><span class="p">},</span> <span class="cm">/*dim=*/</span><span class="mi">1</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">gate_weights</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">addmm</span><span class="p">(</span><span class="n">bias</span><span class="p">,</span> <span class="n">X</span><span class="p">,</span> <span class="n">weights</span><span class="p">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">));</span>
<span class="k">auto</span> <span class="n">gates</span> <span class="o">=</span> <span class="n">gate_weights</span><span class="p">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="cm">/*dim=*/</span><span class="mi">1</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">input_gate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span>
<span class="k">auto</span> <span class="n">output_gate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">1</span><span class="p">]);</span>
<span class="k">auto</span> <span class="n">candidate_cell</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">elu</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="cm">/*alpha=*/</span><span class="mf">1.0</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">new_cell</span> <span class="o">=</span> <span class="n">old_cell</span> <span class="o">+</span> <span class="n">candidate_cell</span> <span class="o">*</span> <span class="n">input_gate</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">new_h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">tanh</span><span class="p">(</span><span class="n">new_cell</span><span class="p">)</span> <span class="o">*</span> <span class="n">output_gate</span><span class="p">;</span>
<span class="k">return</span> <span class="p">{</span><span class="n">new_h</span><span class="p">,</span>
<span class="n">new_cell</span><span class="p">,</span>
<span class="n">input_gate</span><span class="p">,</span>
<span class="n">output_gate</span><span class="p">,</span>
<span class="n">candidate_cell</span><span class="p">,</span>
<span class="n">X</span><span class="p">,</span>
<span class="n">gate_weights</span><span class="p">};</span>
<span class="p">}</span>
</pre></div>
</div>
</div>
<div class="section" id="backward-pass">
<h4>Backward Pass<a class="headerlink" href="#backward-pass" title="Permalink to this headline">¶</a></h4>
<p>The C++ extension API currently does not provide a way of automatically
generating a backwards function for us. As such, we have to also implement the
backward pass of our LLTM, which computes the derivative of the loss with
respect to each input of the forward pass. Ultimately, we will plop both the
forward and backward function into a <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.autograd.Function</span></code> to create
a nice Python binding. The backward function is slightly more involved, so
we'll not dig deeper into the code (if you are interested, <a class="reference external" href="https://www.cs.toronto.edu/~graves/phd.pdf">Alex Graves' thesis</a> is a good read for more
information on this):</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="c1">// tanh'(z) = 1 - tanh^2(z)</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">d_tanh</span><span class="p">(</span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">z</span><span class="p">)</span> <span class="p">{</span>
<span class="k">return</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">z</span><span class="p">.</span><span class="n">tanh</span><span class="p">().</span><span class="n">pow</span><span class="p">(</span><span class="mi">2</span><span class="p">);</span>
<span class="p">}</span>
<span class="c1">// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">d_elu</span><span class="p">(</span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">z</span><span class="p">,</span> <span class="n">torch</span><span class="o">::</span><span class="n">Scalar</span> <span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">e</span> <span class="o">=</span> <span class="n">z</span><span class="p">.</span><span class="n">exp</span><span class="p">();</span>
<span class="k">auto</span> <span class="n">mask</span> <span class="o">=</span> <span class="p">(</span><span class="n">alpha</span> <span class="o">*</span> <span class="p">(</span><span class="n">e</span> <span class="o">-</span> <span class="mi">1</span><span class="p">))</span> <span class="o"><</span> <span class="mi">0</span><span class="p">;</span>
<span class="k">return</span> <span class="p">(</span><span class="n">z</span> <span class="o">></span> <span class="mi">0</span><span class="p">).</span><span class="n">type_as</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="o">+</span> <span class="n">mask</span><span class="p">.</span><span class="n">type_as</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">alpha</span> <span class="o">*</span> <span class="n">e</span><span class="p">);</span>
<span class="p">}</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span><span class="o">></span> <span class="n">lltm_backward</span><span class="p">(</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">grad_h</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">grad_cell</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">new_cell</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">input_gate</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">output_gate</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">candidate_cell</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">X</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">gate_weights</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">weights</span><span class="p">)</span> <span class="p">{</span>
<span class="k">auto</span> <span class="n">d_output_gate</span> <span class="o">=</span> <span class="n">torch</span><span class="o">::</span><span class="n">tanh</span><span class="p">(</span><span class="n">new_cell</span><span class="p">)</span> <span class="o">*</span> <span class="n">grad_h</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">d_tanh_new_cell</span> <span class="o">=</span> <span class="n">output_gate</span> <span class="o">*</span> <span class="n">grad_h</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">d_new_cell</span> <span class="o">=</span> <span class="n">d_tanh</span><span class="p">(</span><span class="n">new_cell</span><span class="p">)</span> <span class="o">*</span> <span class="n">d_tanh_new_cell</span> <span class="o">+</span> <span class="n">grad_cell</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">d_old_cell</span> <span class="o">=</span> <span class="n">d_new_cell</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">d_candidate_cell</span> <span class="o">=</span> <span class="n">input_gate</span> <span class="o">*</span> <span class="n">d_new_cell</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">d_input_gate</span> <span class="o">=</span> <span class="n">candidate_cell</span> <span class="o">*</span> <span class="n">d_new_cell</span><span class="p">;</span>
<span class="k">auto</span> <span class="n">gates</span> <span class="o">=</span> <span class="n">gate_weights</span><span class="p">.</span><span class="n">chunk</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="cm">/*dim=*/</span><span class="mi">1</span><span class="p">);</span>
<span class="n">d_input_gate</span> <span class="o">*=</span> <span class="n">d_sigmoid</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">0</span><span class="p">]);</span>
<span class="n">d_output_gate</span> <span class="o">*=</span> <span class="n">d_sigmoid</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">1</span><span class="p">]);</span>
<span class="n">d_candidate_cell</span> <span class="o">*=</span> <span class="n">d_elu</span><span class="p">(</span><span class="n">gates</span><span class="p">[</span><span class="mi">2</span><span class="p">]);</span>
<span class="k">auto</span> <span class="n">d_gates</span> <span class="o">=</span>
<span class="n">torch</span><span class="o">::</span><span class="n">cat</span><span class="p">({</span><span class="n">d_input_gate</span><span class="p">,</span> <span class="n">d_output_gate</span><span class="p">,</span> <span class="n">d_candidate_cell</span><span class="p">},</span> <span class="cm">/*dim=*/</span><span class="mi">1</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">d_weights</span> <span class="o">=</span> <span class="n">d_gates</span><span class="p">.</span><span class="n">t</span><span class="p">().</span><span class="n">mm</span><span class="p">(</span><span class="n">X</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">d_bias</span> <span class="o">=</span> <span class="n">d_gates</span><span class="p">.</span><span class="n">sum</span><span class="p">(</span><span class="cm">/*dim=*/</span><span class="mi">0</span><span class="p">,</span> <span class="cm">/*keepdim=*/</span><span class="nb">true</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">d_X</span> <span class="o">=</span> <span class="n">d_gates</span><span class="p">.</span><span class="n">mm</span><span class="p">(</span><span class="n">weights</span><span class="p">);</span>
<span class="k">const</span> <span class="k">auto</span> <span class="n">state_size</span> <span class="o">=</span> <span class="n">grad_h</span><span class="p">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">d_old_h</span> <span class="o">=</span> <span class="n">d_X</span><span class="p">.</span><span class="n">slice</span><span class="p">(</span><span class="cm">/*dim=*/</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">state_size</span><span class="p">);</span>
<span class="k">auto</span> <span class="n">d_input</span> <span class="o">=</span> <span class="n">d_X</span><span class="p">.</span><span class="n">slice</span><span class="p">(</span><span class="cm">/*dim=*/</span><span class="mi">1</span><span class="p">,</span> <span class="n">state_size</span><span class="p">);</span>
<span class="k">return</span> <span class="p">{</span><span class="n">d_old_h</span><span class="p">,</span> <span class="n">d_input</span><span class="p">,</span> <span class="n">d_weights</span><span class="p">,</span> <span class="n">d_bias</span><span class="p">,</span> <span class="n">d_old_cell</span><span class="p">};</span>
<span class="p">}</span>
</pre></div>
</div>
</div>
</div>
<div class="section" id="binding-to-python">
<h3>Binding to Python<a class="headerlink" href="#binding-to-python" title="Permalink to this headline">¶</a></h3>
<p>Once you have your operation written in C++ and ATen, you can use pybind11 to
bind your C++ functions or classes into Python in a very simple manner.
Questions or issues you have about this part of PyTorch C++ extensions will
largely be addressed by the <a class="reference external" href="https://pybind11.readthedocs.io/en/stable/">pybind11 documentation</a>.</p>
<p>For our extension, the necessary binding code spans only four lines:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="n">PYBIND11_MODULE</span><span class="p">(</span><span class="n">TORCH_EXTENSION_NAME</span><span class="p">,</span> <span class="n">m</span><span class="p">)</span> <span class="p">{</span>
<span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span><span class="s">"forward"</span><span class="p">,</span> <span class="o">&</span><span class="n">lltm_forward</span><span class="p">,</span> <span class="s">"LLTM forward"</span><span class="p">);</span>
<span class="n">m</span><span class="p">.</span><span class="n">def</span><span class="p">(</span><span class="s">"backward"</span><span class="p">,</span> <span class="o">&</span><span class="n">lltm_backward</span><span class="p">,</span> <span class="s">"LLTM backward"</span><span class="p">);</span>
<span class="p">}</span>
</pre></div>
</div>
<p>One bit to note here is the macro <code class="docutils literal notranslate"><span class="pre">TORCH_EXTENSION_NAME</span></code>. The torch extension
build will define it as the name you give your extension in the <code class="docutils literal notranslate"><span class="pre">setup.py</span></code>
script. In this case, the value of <code class="docutils literal notranslate"><span class="pre">TORCH_EXTENSION_NAME</span></code> would be “lltm_cpp”.
This avoids having to maintain the name of the extension in two places (the
build script and your C++ code), as a mismatch between the two can lead to
nasty and hard-to-track issues.</p>
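<p>As a refresher, a minimal <code class="docutils literal notranslate"><span class="pre">setup.py</span></code> of the kind assumed here looks
roughly like this (a sketch; the first argument to <code class="docutils literal notranslate"><span class="pre">CppExtension</span></code> is the
name that <code class="docutils literal notranslate"><span class="pre">TORCH_EXTENSION_NAME</span></code> expands to):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(
    name='lltm_cpp',
    ext_modules=[CppExtension('lltm_cpp', ['lltm.cpp'])],
    cmdclass={'build_ext': BuildExtension})
</pre></div>
</div>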
</div>
<div class="section" id="using-your-extension">
<h3>Using Your Extension<a class="headerlink" href="#using-your-extension" title="Permalink to this headline">¶</a></h3>
<p>We are now set to import our extension in PyTorch. At this point, your directory
structure could look something like this:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">pytorch</span><span class="o">/</span>
<span class="n">lltm</span><span class="o">-</span><span class="n">extension</span><span class="o">/</span>
<span class="n">lltm</span><span class="o">.</span><span class="n">cpp</span>
<span class="n">setup</span><span class="o">.</span><span class="n">py</span>
</pre></div>
</div>
<p>Now, run <code class="docutils literal notranslate"><span class="pre">python</span> <span class="pre">setup.py</span> <span class="pre">install</span></code> to build and install your extension. This
should look something like this:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>running install
running bdist_egg
running egg_info
creating lltm_cpp.egg-info
writing lltm_cpp.egg-info/PKG-INFO
writing dependency_links to lltm_cpp.egg-info/dependency_links.txt
writing top-level names to lltm_cpp.egg-info/top_level.txt
writing manifest file 'lltm_cpp.egg-info/SOURCES.txt'
reading manifest file 'lltm_cpp.egg-info/SOURCES.txt'
writing manifest file 'lltm_cpp.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_ext
building 'lltm_cpp' extension
creating build
creating build/temp.linux-x86_64-3.7
gcc -pthread -B ~/local/miniconda/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I~/local/miniconda/lib/python3.7/site-packages/torch/include -I~/local/miniconda/lib/python3.7/site-packages/torch/include/torch/csrc/api/include -I~/local/miniconda/lib/python3.7/site-packages/torch/include/TH -I~/local/miniconda/lib/python3.7/site-packages/torch/include/THC -I~/local/miniconda/include/python3.7m -c lltm.cpp -o build/temp.linux-x86_64-3.7/lltm.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=lltm_cpp -D_GLIBCXX_USE_CXX11_ABI=1 -std=c++11
cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid for C/ObjC but not for C++
creating build/lib.linux-x86_64-3.7
g++ -pthread -shared -B ~/local/miniconda/compiler_compat -L~/local/miniconda/lib -Wl,-rpath=~/local/miniconda/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.7/lltm.o -o build/lib.linux-x86_64-3.7/lltm_cpp.cpython-37m-x86_64-linux-gnu.so
creating build/bdist.linux-x86_64
creating build/bdist.linux-x86_64/egg
copying build/lib.linux-x86_64-3.7/lltm_cpp.cpython-37m-x86_64-linux-gnu.so -> build/bdist.linux-x86_64/egg
creating stub loader for lltm_cpp.cpython-37m-x86_64-linux-gnu.so
byte-compiling build/bdist.linux-x86_64/egg/lltm_cpp.py to lltm_cpp.cpython-37.pyc
creating build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/PKG-INFO -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/SOURCES.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/dependency_links.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
copying lltm_cpp.egg-info/top_level.txt -> build/bdist.linux-x86_64/egg/EGG-INFO
writing build/bdist.linux-x86_64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.lltm_cpp.cpython-37: module references __file__
creating 'dist/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg' and adding 'build/bdist.linux-x86_64/egg' to it
removing 'build/bdist.linux-x86_64/egg' (and everything under it)
Processing lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
removing '~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg' (and everything under it)
creating ~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
Extracting lltm_cpp-0.0.0-py3.7-linux-x86_64.egg to ~/local/miniconda/lib/python3.7/site-packages
lltm-cpp 0.0.0 is already the active version in easy-install.pth
Installed ~/local/miniconda/lib/python3.7/site-packages/lltm_cpp-0.0.0-py3.7-linux-x86_64.egg
Processing dependencies for lltm-cpp==0.0.0
Finished processing dependencies for lltm-cpp==0.0.0
</pre></div>
</div>
<p>A small note on compilers: due to ABI versioning issues, the compiler you use to
build your C++ extension must be <em>ABI-compatible</em> with the compiler PyTorch was
built with. In practice, this means that you must use GCC version 4.9 or above on
Linux. For Ubuntu 16.04 and other more recent Linux distributions, this should
already be the default compiler. On macOS, you must use Clang (which does not
have any ABI versioning issues). In the worst case, you can build PyTorch from
source with your compiler and then build the extension with that same
compiler.</p>
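<p>If you are unsure which compiler and ABI settings your PyTorch binary was built
with, you can print its build configuration as a quick check (the exact output
varies by build):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>import torch

# Prints the build configuration of the installed PyTorch binary,
# including the compiler and C++ ABI flags it was built with.
print(torch.__config__.show())
</pre></div>
</div>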
<p>Once your extension is built, you can simply import it in Python, using the
name you specified in your <code class="docutils literal notranslate"><span class="pre">setup.py</span></code> script. Just be sure to <code class="docutils literal notranslate"><span class="pre">import</span>
<span class="pre">torch</span></code> first, as this will resolve some symbols that the dynamic linker must
see:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">In</span> <span class="p">[</span><span class="mi">1</span><span class="p">]:</span> <span class="kn">import</span> <span class="nn">torch</span>
<span class="n">In</span> <span class="p">[</span><span class="mi">2</span><span class="p">]:</span> <span class="kn">import</span> <span class="nn">lltm_cpp</span>
<span class="n">In</span> <span class="p">[</span><span class="mi">3</span><span class="p">]:</span> <span class="n">lltm_cpp</span><span class="o">.</span><span class="n">forward</span>
<span class="n">Out</span><span class="p">[</span><span class="mi">3</span><span class="p">]:</span> <span class="o"><</span><span class="n">function</span> <span class="n">lltm</span><span class="o">.</span><span class="n">PyCapsule</span><span class="o">.</span><span class="n">forward</span><span class="o">></span>
</pre></div>
</div>
<p>If we call <code class="docutils literal notranslate"><span class="pre">help()</span></code> on the function or module, we can see that its signature
matches our C++ code:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">In</span><span class="p">[</span><span class="mi">4</span><span class="p">]</span> <span class="n">help</span><span class="p">(</span><span class="n">lltm_cpp</span><span class="o">.</span><span class="n">forward</span><span class="p">)</span>
<span class="n">forward</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="n">method</span> <span class="n">of</span> <span class="n">builtins</span><span class="o">.</span><span class="n">PyCapsule</span> <span class="n">instance</span>
<span class="n">forward</span><span class="p">(</span><span class="n">arg0</span><span class="p">:</span> <span class="n">torch</span><span class="p">::</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">arg1</span><span class="p">:</span> <span class="n">torch</span><span class="p">::</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">arg2</span><span class="p">:</span> <span class="n">torch</span><span class="p">::</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">arg3</span><span class="p">:</span> <span class="n">torch</span><span class="p">::</span><span class="n">Tensor</span><span class="p">,</span> <span class="n">arg4</span><span class="p">:</span> <span class="n">torch</span><span class="p">::</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-></span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="p">::</span><span class="n">Tensor</span><span class="p">]</span>
<span class="n">LLTM</span> <span class="n">forward</span>
</pre></div>
</div>
<p>Since we are now able to call our C++ functions from Python, we can wrap them
with <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.autograd.Function</span></code> and <code class="xref py py-class docutils literal notranslate"><span class="pre">torch.nn.Module</span></code> to make them first-class
citizens of PyTorch:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">math</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="c1"># Our module!</span>
<span class="kn">import</span> <span class="nn">lltm_cpp</span>
<span class="k">class</span> <span class="nc">LLTMFunction</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">autograd</span><span class="o">.</span><span class="n">Function</span><span class="p">):</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">old_h</span><span class="p">,</span> <span class="n">old_cell</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">lltm_cpp</span><span class="o">.</span><span class="n">forward</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">weights</span><span class="p">,</span> <span class="n">bias</span><span class="p">,</span> <span class="n">old_h</span><span class="p">,</span> <span class="n">old_cell</span><span class="p">)</span>
<span class="n">new_h</span><span class="p">,</span> <span class="n">new_cell</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span>
<span class="n">variables</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="o">+</span> <span class="p">[</span><span class="n">weights</span><span class="p">]</span>
<span class="n">ctx</span><span class="o">.</span><span class="n">save_for_backward</span><span class="p">(</span><span class="o">*</span><span class="n">variables</span><span class="p">)</span>
<span class="k">return</span> <span class="n">new_h</span><span class="p">,</span> <span class="n">new_cell</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">(</span><span class="n">ctx</span><span class="p">,</span> <span class="n">grad_h</span><span class="p">,</span> <span class="n">grad_cell</span><span class="p">):</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">lltm_cpp</span><span class="o">.</span><span class="n">backward</span><span class="p">(</span>
<span class="n">grad_h</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span> <span class="n">grad_cell</span><span class="o">.</span><span class="n">contiguous</span><span class="p">(),</span> <span class="o">*</span><span class="n">ctx</span><span class="o">.</span><span class="n">saved_tensors</span><span class="p">)</span>
<span class="n">d_old_h</span><span class="p">,</span> <span class="n">d_input</span><span class="p">,</span> <span class="n">d_weights</span><span class="p">,</span> <span class="n">d_bias</span><span class="p">,</span> <span class="n">d_old_cell</span> <span class="o">=</span> <span class="n">outputs</span>
<span class="k">return</span> <span class="n">d_input</span><span class="p">,</span> <span class="n">d_weights</span><span class="p">,</span> <span class="n">d_bias</span><span class="p">,</span> <span class="n">d_old_h</span><span class="p">,</span> <span class="n">d_old_cell</span>
<span class="k">class</span> <span class="nc">LLTM</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_features</span><span class="p">,</span> <span class="n">state_size</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">LLTM</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">input_features</span> <span class="o">=</span> <span class="n">input_features</span>
<span class="bp">self</span><span class="o">.</span><span class="n">state_size</span> <span class="o">=</span> <span class="n">state_size</span>
<span class="bp">self</span><span class="o">.</span><span class="n">weights</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span>
<span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="n">state_size</span><span class="p">,</span> <span class="n">input_features</span> <span class="o">+</span> <span class="n">state_size</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="n">state_size</span><span class="p">))</span>
<span class="bp">self</span><span class="o">.</span><span class="n">reset_parameters</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">reset_parameters</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="n">stdv</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">state_size</span><span class="p">)</span>
<span class="k">for</span> <span class="n">weight</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
<span class="n">weight</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">uniform_</span><span class="p">(</span><span class="o">-</span><span class="n">stdv</span><span class="p">,</span> <span class="o">+</span><span class="n">stdv</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">state</span><span class="p">):</span>
<span class="k">return</span> <span class="n">LLTMFunction</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span><span class="p">,</span> <span class="o">*</span><span class="n">state</span><span class="p">)</span>
</pre></div>
</div>
<div class="section" id="performance-comparison">
<h4>Performance Comparison<a class="headerlink" href="#performance-comparison" title="Permalink to this headline">¶</a></h4>
<p>Now that we are able to use and call our C++ code from PyTorch, we can run a
small benchmark to see how much performance we gained from rewriting our op in
C++. We’ll run the LLTM forwards and backwards a few times and measure the
duration:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">input_features</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">state_size</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">input_features</span><span class="p">)</span>
<span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
<span class="n">rnn</span> <span class="o">=</span> <span class="n">LLTM</span><span class="p">(</span><span class="n">input_features</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span>
<span class="n">forward</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">backward</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100000</span><span class="p">):</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">new_h</span><span class="p">,</span> <span class="n">new_C</span> <span class="o">=</span> <span class="n">rnn</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">C</span><span class="p">))</span>
<span class="n">forward</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="p">(</span><span class="n">new_h</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">+</span> <span class="n">new_C</span><span class="o">.</span><span class="n">sum</span><span class="p">())</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">backward</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Forward: </span><span class="si">{:.3f}</span><span class="s1"> us | Backward </span><span class="si">{:.3f}</span><span class="s1"> us'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">forward</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="mf">1e5</span><span class="p">,</span> <span class="n">backward</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="mf">1e5</span><span class="p">))</span>
</pre></div>
</div>
<p>If we run this code with the original LLTM we wrote in pure Python at the start
of this post, we get the following numbers (on my machine):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Forward</span><span class="p">:</span> <span class="mf">506.480</span> <span class="n">us</span> <span class="o">|</span> <span class="n">Backward</span> <span class="mf">444.694</span> <span class="n">us</span>
</pre></div>
</div>
<p>and with our new C++ version:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Forward</span><span class="p">:</span> <span class="mf">349.335</span> <span class="n">us</span> <span class="o">|</span> <span class="n">Backward</span> <span class="mf">443.523</span> <span class="n">us</span>
</pre></div>
</div>
<p>We can already see a significant speedup for the forward function (more than
30%). For the backward function, a speedup is visible, albeit not a major one.
The backward pass I wrote above was not particularly optimized and could
definitely be improved. Also, PyTorch’s automatic differentiation engine can
automatically parallelize computation graphs, may use a more efficient flow of
operations overall, and is also implemented in C++, so it’s expected to be
fast. Nevertheless, this is a good start.</p>
</div>
<div class="section" id="performance-on-gpu-devices">
<h4>Performance on GPU Devices<a class="headerlink" href="#performance-on-gpu-devices" title="Permalink to this headline">¶</a></h4>
<p>A wonderful fact about PyTorch’s <em>ATen</em> backend is that it abstracts the
computing device you are running on. This means the same code we wrote for CPU
can <em>also</em> run on GPU, and individual operations will correspondingly dispatch
to GPU-optimized implementations. For certain operations like matrix multiply
(such as <code class="docutils literal notranslate"><span class="pre">mm</span></code> or <code class="docutils literal notranslate"><span class="pre">addmm</span></code>), this is a big win. Let’s take a look at how much
performance we gain from running our C++ code with CUDA tensors. No changes to
our implementation are required; we simply need to put our tensors in GPU
memory from Python, either by passing a <code class="docutils literal notranslate"><span class="pre">device=cuda_device</span></code> argument at
creation time or by calling <code class="docutils literal notranslate"><span class="pre">.to(cuda_device)</span></code> after creation:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
<span class="k">assert</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span>
<span class="n">cuda_device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">"cuda"</span><span class="p">)</span> <span class="c1"># device object representing GPU</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">input_features</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">state_size</span> <span class="o">=</span> <span class="mi">128</span>
<span class="c1"># Note the device=cuda_device arguments here</span>
<span class="n">X</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">input_features</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">cuda_device</span><span class="p">)</span>
<span class="n">h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">cuda_device</span><span class="p">)</span>
<span class="n">C</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">state_size</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">cuda_device</span><span class="p">)</span>
<span class="n">rnn</span> <span class="o">=</span> <span class="n">LLTM</span><span class="p">(</span><span class="n">input_features</span><span class="p">,</span> <span class="n">state_size</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">cuda_device</span><span class="p">)</span>
<span class="n">forward</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">backward</span> <span class="o">=</span> <span class="mi">0</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100000</span><span class="p">):</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="n">new_h</span><span class="p">,</span> <span class="n">new_C</span> <span class="o">=</span> <span class="n">rnn</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">C</span><span class="p">))</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">forward</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
<span class="p">(</span><span class="n">new_h</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">+</span> <span class="n">new_C</span><span class="o">.</span><span class="n">sum</span><span class="p">())</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>
<span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">synchronize</span><span class="p">()</span>
<span class="n">backward</span> <span class="o">+=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start</span>
<span class="nb">print</span><span class="p">(</span><span class="s1">'Forward: </span><span class="si">{:.3f}</span><span class="s1"> us | Backward </span><span class="si">{:.3f}</span><span class="s1"> us'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">forward</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="mf">1e5</span><span class="p">,</span> <span class="n">backward</span> <span class="o">*</span> <span class="mf">1e6</span><span class="o">/</span><span class="mf">1e5</span><span class="p">))</span>
</pre></div>
</div>
<p>Once more comparing our plain PyTorch code with our C++ version, now both
running on CUDA devices, we again see performance gains. For Python/PyTorch:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Forward</span><span class="p">:</span> <span class="mf">187.719</span> <span class="n">us</span> <span class="o">|</span> <span class="n">Backward</span> <span class="mf">410.815</span> <span class="n">us</span>
</pre></div>
</div>
<p>And C++/ATen:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Forward</span><span class="p">:</span> <span class="mf">149.802</span> <span class="n">us</span> <span class="o">|</span> <span class="n">Backward</span> <span class="mf">393.458</span> <span class="n">us</span>
</pre></div>
</div>
<p>That’s a great overall speedup compared to non-CUDA code. However, we can pull
even more performance out of our C++ code by writing custom CUDA kernels, which
we’ll dive into soon. Before that, let’s discuss another way of building your C++
extensions.</p>
</div>
</div>
<div class="section" id="jit-compiling-extensions">
<h3>JIT Compiling Extensions<a class="headerlink" href="#jit-compiling-extensions" title="Permalink to this headline">¶</a></h3>
<p>Previously, I mentioned there were two ways of building C++ extensions: using
<code class="xref py py-mod docutils literal notranslate"><span class="pre">setuptools</span></code> or just-in-time (JIT) compilation. Having covered the former, let’s
elaborate on the latter. The JIT compilation mechanism provides you with a way
of compiling and loading your extensions on the fly by calling a simple
function in PyTorch’s API called <code class="xref py py-func docutils literal notranslate"><span class="pre">torch.utils.cpp_extension.load()</span></code>. For
the LLTM, this is as simple as:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">torch.utils.cpp_extension</span> <span class="kn">import</span> <span class="n">load</span>
<span class="n">lltm_cpp</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"lltm_cpp"</span><span class="p">,</span> <span class="n">sources</span><span class="o">=</span><span class="p">[</span><span class="s2">"lltm.cpp"</span><span class="p">])</span>
</pre></div>
</div>
<p>Here, we provide the function with the same information as for
<code class="xref py py-mod docutils literal notranslate"><span class="pre">setuptools</span></code>. In the background, this will do the following:</p>
<ol class="arabic simple">
<li>Create a temporary directory <code class="docutils literal notranslate"><span class="pre">/tmp/torch_extensions/lltm_cpp</span></code>,</li>
<li>Emit a <a class="reference external" href="https://ninja-build.org/">Ninja</a> build file into that temporary directory,</li>
<li>Compile your source files into a shared library,</li>
<li>Import this shared library as a Python module.</li>
</ol>
<p>In fact, if you pass <code class="docutils literal notranslate"><span class="pre">verbose=True</span></code> to <code class="xref py py-func docutils literal notranslate"><span class="pre">cpp_extension.load()</span></code>, you will
be informed about the process:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">Using</span> <span class="o">/</span><span class="n">tmp</span><span class="o">/</span><span class="n">torch_extensions</span> <span class="k">as</span> <span class="n">PyTorch</span> <span class="n">extensions</span> <span class="n">root</span><span class="o">...</span>
<span class="n">Emitting</span> <span class="n">ninja</span> <span class="n">build</span> <span class="n">file</span> <span class="o">/</span><span class="n">tmp</span><span class="o">/</span><span class="n">torch_extensions</span><span class="o">/</span><span class="n">lltm_cpp</span><span class="o">/</span><span class="n">build</span><span class="o">.</span><span class="n">ninja</span><span class="o">...</span>
<span class="n">Building</span> <span class="n">extension</span> <span class="n">module</span> <span class="n">lltm_cpp</span><span class="o">...</span>
<span class="n">Loading</span> <span class="n">extension</span> <span class="n">module</span> <span class="n">lltm_cpp</span><span class="o">...</span>
</pre></div>
</div>
<p>The resulting Python module will be exactly the same as the one produced by
setuptools, but it removes the requirement of having to maintain a separate
<code class="docutils literal notranslate"><span class="pre">setup.py</span></code> build file. If your setup is more complicated and you do need the
full power of <code class="xref py py-mod docutils literal notranslate"><span class="pre">setuptools</span></code>, you <em>can</em> write your own <code class="docutils literal notranslate"><span class="pre">setup.py</span></code>, but in many
cases this JIT technique will do just fine. The first time you run through this
line, it will take some time, as the extension is compiled in the background.
Since we use the Ninja build system to build your sources, re-compilation is
incremental, so re-loading the extension when you run your Python module a
second time is fast and has low overhead if you didn’t change the extension’s
source files.</p>
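<p>The <code class="xref py py-func docutils literal notranslate"><span class="pre">load()</span></code> function also accepts optional arguments for customizing the
build. A sketch (the extra compiler flag here is purely illustrative):</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span>from torch.utils.cpp_extension import load

lltm_cpp = load(
    name="lltm_cpp",
    sources=["lltm.cpp"],
    extra_cflags=["-O3"],  # extra flags passed to the C++ compiler (illustrative)
    verbose=True)          # print the build steps as they run
</pre></div>
</div>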
</div>
</div>
<div class="section" id="writing-a-mixed-c-cuda-extension">
<h2>Writing a Mixed C++/CUDA extension<a class="headerlink" href="#writing-a-mixed-c-cuda-extension" title="Permalink to this headline">¶</a></h2>
<p>To really take our implementation to the next level, we can hand-write parts of
our forward and backward passes with custom CUDA kernels. For the LLTM, this has
the prospect of being particularly effective, as there are a large number of
pointwise operations in sequence that can all be fused and parallelized in a
single CUDA kernel. Let’s see how we could write such a CUDA kernel and
integrate it with PyTorch using this extension mechanism.</p>
<p>The general strategy for writing a CUDA extension is to first write a C++ file
which defines the functions that will be called from Python, and binds those
functions to Python with pybind11. Furthermore, this file will also <em>declare</em>
functions that are defined in CUDA (<code class="docutils literal notranslate"><span class="pre">.cu</span></code>) files. The C++ functions will then
do some checks and ultimately forward their calls to the CUDA functions. In the
CUDA files, we write our actual CUDA kernels. The <code class="xref py py-mod docutils literal notranslate"><span class="pre">cpp_extension</span></code> package
will then take care of compiling the C++ sources with a C++ compiler like
<code class="docutils literal notranslate"><span class="pre">gcc</span></code> and the CUDA sources with NVIDIA’s <code class="docutils literal notranslate"><span class="pre">nvcc</span></code> compiler. This ensures that
each compiler takes care of the files it knows best how to compile. Ultimately,
they will be linked into one shared library that is available to us from Python
code.</p>
<p>We’ll start with the C++ file, which we’ll call <code class="docutils literal notranslate"><span class="pre">lltm_cuda.cpp</span></code>, for example:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="cp">#include</span> <span class="cpf"><torch/extension.h></span><span class="cp"></span>
<span class="cp">#include</span> <span class="cpf"><vector></span><span class="cp"></span>
<span class="c1">// CUDA forward declarations</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span><span class="o">></span> <span class="n">lltm_cuda_forward</span><span class="p">(</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">input</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">weights</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">bias</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">old_h</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">old_cell</span><span class="p">);</span>
<span class="n">std</span><span class="o">::</span><span class="n">vector</span><span class="o"><</span><span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span><span class="o">></span> <span class="n">lltm_cuda_backward</span><span class="p">(</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">grad_h</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">grad_cell</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">new_cell</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">input_gate</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">output_gate</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">candidate_cell</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">X</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">gate_weights</span><span class="p">,</span>
<span class="n">torch</span><span class="o">::</span><span class="n">Tensor</span> <span class="n">weights</span><span class="p">);</span>
<span class="c1">// C++ interface</span>
<span class="cp">#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")</span>
<span class="cp">#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")</span>
<span class="cp">#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)</span>