235
235
< div class ="pytorch-left-menu-search ">
236
236
237
237
< div class ="version ">
238
- < a href ='https://pytorch.org/docs/versions.html '> master (2.0.0a0+git046e88a ) ▼</ a >
238
+ < a href ='https://pytorch.org/docs/versions.html '> master (2.0.0a0+git5ed7c70 ) ▼</ a >
239
239
</ div >
240
240
241
241
@@ -1786,11 +1786,8 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
1786
1786
< span class ="n "> compiler_name</ span > < span class ="o "> =</ span > < span class ="s2 "> "inductor"</ span >
1787
1787
1788
1788
< span class ="k "> def</ span > < span class ="fm "> __init__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> mode</ span > < span class ="p "> ,</ span > < span class ="n "> options</ span > < span class ="p "> ,</ span > < span class ="n "> dynamic</ span > < span class ="p "> ):</ span >
1789
- < span class ="kn "> from</ span > < span class ="nn "> torch._inductor.compile_fx</ span > < span class ="kn "> import</ span > < span class ="n "> compile_fx</ span >
1790
-
1791
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> compile_fn</ span > < span class ="o "> =</ span > < span class ="n "> compile_fx</ span >
1792
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> _torchdynamo_orig_callable</ span > < span class ="o "> =</ span > < span class ="n "> compile_fx</ span >
1793
1789
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> =</ span > < span class ="nb "> dict</ span > < span class ="p "> ()</ span >
1790
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> dynamic</ span > < span class ="o "> =</ span > < span class ="n "> dynamic</ span >
1794
1791
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> apply_mode</ span > < span class ="p "> (</ span > < span class ="n "> mode</ span > < span class ="p "> )</ span >
1795
1792
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> apply_options</ span > < span class ="p "> (</ span > < span class ="n "> options</ span > < span class ="p "> )</ span >
1796
1793
< span class ="k "> if</ span > < span class ="n "> dynamic</ span > < span class ="p "> :</ span >
@@ -1800,16 +1797,25 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
1800
1797
< span class ="n "> options</ span > < span class ="ow "> or</ span > < span class ="p "> ()</ span >
1801
1798
< span class ="p "> ),</ span > < span class ="s2 "> "triton.cudagraphs does not support dynamic shapes"</ span >
1802
1799
1800
+ < span class ="k "> def</ span > < span class ="fm "> __eq__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> other</ span > < span class ="p "> ):</ span >
1801
+ < span class ="k "> return</ span > < span class ="p "> (</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> other</ span > < span class ="p "> ,</ span > < span class ="n "> _TorchCompileInductorWrapper</ span > < span class ="p "> )</ span > < span class ="ow "> and</ span >
1802
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> ==</ span > < span class ="n "> other</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="ow "> and</ span >
1803
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> dynamic</ span > < span class ="o "> ==</ span > < span class ="n "> other</ span > < span class ="o "> .</ span > < span class ="n "> dynamic</ span > < span class ="p "> )</ span >
1804
+
1803
1805
< span class ="k "> def</ span > < span class ="nf "> apply_mode</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> mode</ span > < span class ="p "> :</ span > < span class ="n "> Optional</ span > < span class ="p "> [</ span > < span class ="nb "> str</ span > < span class ="p "> ]):</ span >
1804
- < span class ="k "> if</ span > < span class ="n "> mode</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
1805
- < span class ="k "> return</ span >
1806
- < span class ="k "> elif</ span > < span class ="n "> mode</ span > < span class ="o "> ==</ span > < span class ="s2 "> "default"</ span > < span class ="p "> :</ span >
1806
+ < span class ="k "> if</ span > < span class ="n "> mode</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="ow "> or</ span > < span class ="n "> mode</ span > < span class ="o "> ==</ span > < span class ="s2 "> "default"</ span > < span class ="p "> :</ span >
1807
1807
< span class ="k "> pass</ span >
1808
1808
< span class ="k "> elif</ span > < span class ="n "> mode</ span > < span class ="o "> ==</ span > < span class ="s2 "> "reduce-overhead"</ span > < span class ="p "> :</ span >
1809
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="s2 "> "triton.cudagraphs"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span >
1809
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> apply_options</ span > < span class ="p "> ({</ span >
1810
+ < span class ="s2 "> "triton.cudagraphs"</ span > < span class ="p "> :</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1811
+ < span class ="s2 "> "size_asserts"</ span > < span class ="p "> :</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span >
1812
+ < span class ="p "> })</ span >
1810
1813
< span class ="k "> elif</ span > < span class ="n "> mode</ span > < span class ="o "> ==</ span > < span class ="s2 "> "max-autotune"</ span > < span class ="p "> :</ span >
1811
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="s2 "> "max_autotune"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span >
1812
- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="s2 "> "triton.cudagraphs"</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span >
1814
+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> apply_options</ span > < span class ="p "> ({</ span >
1815
+ < span class ="s2 "> "epilogue_fusion"</ span > < span class ="p "> :</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1816
+ < span class ="s2 "> "max_autotune"</ span > < span class ="p "> :</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1817
+ < span class ="s2 "> "triton.cudagraphs"</ span > < span class ="p "> :</ span > < span class ="kc "> True</ span > < span class ="p "> ,</ span >
1818
+ < span class ="p "> })</ span >
1813
1819
< span class ="k "> else</ span > < span class ="p "> :</ span >
1814
1820
< span class ="k "> raise</ span > < span class ="ne "> RuntimeError</ span > < span class ="p "> (</ span >
1815
1821
< span class ="sa "> f</ span > < span class ="s2 "> "Unrecognized mode=</ span > < span class ="si "> {</ span > < span class ="n "> mode</ span > < span class ="si "> }</ span > < span class ="s2 "> , should be one of: default, reduce-overhead, max-autotune"</ span >
@@ -1837,7 +1843,9 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
1837
1843
< span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="n "> attr_name</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="n "> val</ span >
1838
1844
1839
1845
< span class ="k "> def</ span > < span class ="fm "> __call__</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="p "> ,</ span > < span class ="n "> model_</ span > < span class ="p "> ,</ span > < span class ="n "> inputs_</ span > < span class ="p "> ):</ span >
1840
- < span class ="k "> return</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> compile_fn</ span > < span class ="p "> (</ span > < span class ="n "> model_</ span > < span class ="p "> ,</ span > < span class ="n "> inputs_</ span > < span class ="p "> ,</ span > < span class ="n "> config_patches</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="p "> )</ span >
1846
+ < span class ="kn "> from</ span > < span class ="nn "> torch._inductor.compile_fx</ span > < span class ="kn "> import</ span > < span class ="n "> compile_fx</ span >
1847
+
1848
+ < span class ="k "> return</ span > < span class ="n "> compile_fx</ span > < span class ="p "> (</ span > < span class ="n "> model_</ span > < span class ="p "> ,</ span > < span class ="n "> inputs_</ span > < span class ="p "> ,</ span > < span class ="n "> config_patches</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="p "> )</ span >
1841
1849
1842
1850
1843
1851
< div class ="viewcode-block " id ="compile "> < a class ="viewcode-back " href ="../generated/torch.compile.html#torch.compile "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> compile</ span > < span class ="p "> (</ span > < span class ="n "> model</ span > < span class ="p "> :</ span > < span class ="n "> Optional</ span > < span class ="p "> [</ span > < span class ="n "> Callable</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span > < span class ="p "> ,</ span > < span class ="o "> *</ span > < span class ="p "> ,</ span >
0 commit comments