235
235
< div class ="pytorch-left-menu-search ">
236
236
237
237
< div class ="version ">
238
- < a href ='https://pytorch.org/docs/versions.html '> master (2.0.0a0+git5ed7c70 ) ▼</ a >
238
+ < a href ='https://pytorch.org/docs/versions.html '> master (2.0.0a0+git65b9983 ) ▼</ a >
239
239
</ div >
240
240
241
241
@@ -494,8 +494,6 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
494
494
< span class ="k "> else</ span > < span class ="p "> :</ span >
495
495
< span class ="kn "> from</ span > < span class ="nn "> .torch_version</ span > < span class ="kn "> import</ span > < span class ="n "> __version__</ span > < span class ="k "> as</ span > < span class ="n "> __version__</ span >
496
496
497
- < span class ="kn "> from</ span > < span class ="nn "> ._six</ span > < span class ="kn "> import</ span > < span class ="n "> string_classes</ span > < span class ="k "> as</ span > < span class ="n "> _string_classes</ span >
498
-
499
497
< span class ="kn "> from</ span > < span class ="nn "> typing</ span > < span class ="kn "> import</ span > < span class ="n "> Any</ span > < span class ="p "> ,</ span > < span class ="n "> Callable</ span > < span class ="p "> ,</ span > < span class ="n "> Dict</ span > < span class ="p "> ,</ span > < span class ="n "> Optional</ span > < span class ="p "> ,</ span > < span class ="n "> Set</ span > < span class ="p "> ,</ span > < span class ="n "> Type</ span > < span class ="p "> ,</ span > < span class ="n "> TYPE_CHECKING</ span > < span class ="p "> ,</ span > < span class ="n "> Union</ span >
500
498
< span class ="kn "> import</ span > < span class ="nn "> builtins</ span >
501
499
@@ -603,29 +601,24 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
603
601
< span class ="n "> kernel32</ span > < span class ="o "> .</ span > < span class ="n "> SetErrorMode</ span > < span class ="p "> (</ span > < span class ="n "> prev_error_mode</ span > < span class ="p "> )</ span >
604
602
605
603
606
- < span class ="k "> def</ span > < span class ="nf "> _preload_cuda_deps</ span > < span class ="p "> ():</ span >
607
- < span class ="sd "> """Preloads cudnn/cublas deps if they could not be found otherwise."""</ span >
604
+ < span class ="k "> def</ span > < span class ="nf "> _preload_cuda_deps</ span > < span class ="p "> (</ span > < span class =" n " > lib_folder </ span > < span class =" p " > , </ span > < span class =" n " > lib_name </ span > < span class =" p " > ):</ span >
605
+ < span class ="sd "> """Preloads cuda deps if they could not be found otherwise."""</ span >
608
606
< span class ="c1 "> # Should only be called on Linux if default path resolution have failed</ span >
609
607
< span class ="k "> assert</ span > < span class ="n "> platform</ span > < span class ="o "> .</ span > < span class ="n "> system</ span > < span class ="p "> ()</ span > < span class ="o "> ==</ span > < span class ="s1 "> 'Linux'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'Should only be called on Linux'</ span >
610
- < span class ="n " > cublas_path </ span > < span class ="o " > = </ span > < span class =" kc " > None </ span >
611
- < span class ="n "> cudnn_path </ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
608
+ < span class ="kn " > import </ span > < span class ="nn " > glob </ span >
609
+ < span class ="n "> lib_path </ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
612
610
< span class ="k "> for</ span > < span class ="n "> path</ span > < span class ="ow "> in</ span > < span class ="n "> sys</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="p "> :</ span >
613
611
< span class ="n "> nvidia_path</ span > < span class ="o "> =</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> join</ span > < span class ="p "> (</ span > < span class ="n "> path</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'nvidia'</ span > < span class ="p "> )</ span >
614
612
< span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> exists</ span > < span class ="p "> (</ span > < span class ="n "> nvidia_path</ span > < span class ="p "> ):</ span >
615
613
< span class ="k "> continue</ span >
616
- < span class ="n "> candidate_cublas_path</ span > < span class ="o "> =</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> join</ span > < span class ="p "> (</ span > < span class ="n "> nvidia_path</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'cublas'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'lib'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'libcublas.so.11'</ span > < span class ="p "> )</ span >
617
- < span class ="k "> if</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> exists</ span > < span class ="p "> (</ span > < span class ="n "> candidate_cublas_path</ span > < span class ="p "> )</ span > < span class ="ow "> and</ span > < span class ="ow "> not</ span > < span class ="n "> cublas_path</ span > < span class ="p "> :</ span >
618
- < span class ="n "> cublas_path</ span > < span class ="o "> =</ span > < span class ="n "> candidate_cublas_path</ span >
619
- < span class ="n "> candidate_cudnn_path</ span > < span class ="o "> =</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> join</ span > < span class ="p "> (</ span > < span class ="n "> nvidia_path</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'cudnn'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'lib'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'libcudnn.so.8'</ span > < span class ="p "> )</ span >
620
- < span class ="k "> if</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> exists</ span > < span class ="p "> (</ span > < span class ="n "> candidate_cudnn_path</ span > < span class ="p "> )</ span > < span class ="ow "> and</ span > < span class ="ow "> not</ span > < span class ="n "> cudnn_path</ span > < span class ="p "> :</ span >
621
- < span class ="n "> cudnn_path</ span > < span class ="o "> =</ span > < span class ="n "> candidate_cudnn_path</ span >
622
- < span class ="k "> if</ span > < span class ="n "> cublas_path</ span > < span class ="ow "> and</ span > < span class ="n "> cudnn_path</ span > < span class ="p "> :</ span >
614
+ < span class ="n "> candidate_lib_paths</ span > < span class ="o "> =</ span > < span class ="n "> glob</ span > < span class ="o "> .</ span > < span class ="n "> glob</ span > < span class ="p "> (</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> join</ span > < span class ="p "> (</ span > < span class ="n "> nvidia_path</ span > < span class ="p "> ,</ span > < span class ="n "> lib_folder</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'lib'</ span > < span class ="p "> ,</ span > < span class ="n "> lib_name</ span > < span class ="p "> ))</ span >
615
+ < span class ="k "> if</ span > < span class ="n "> candidate_lib_paths</ span > < span class ="ow "> and</ span > < span class ="ow "> not</ span > < span class ="n "> lib_path</ span > < span class ="p "> :</ span >
616
+ < span class ="n "> lib_path</ span > < span class ="o "> =</ span > < span class ="n "> candidate_lib_paths</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
617
+ < span class ="k "> if</ span > < span class ="n "> lib_path</ span > < span class ="p "> :</ span >
623
618
< span class ="k "> break</ span >
624
- < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> cublas_path</ span > < span class ="ow "> or</ span > < span class ="ow "> not</ span > < span class ="n "> cudnn_path</ span > < span class ="p "> :</ span >
625
- < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "cublas and cudnn not found in the system path </ span > < span class ="si "> {</ span > < span class ="n "> sys</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
626
-
627
- < span class ="n "> ctypes</ span > < span class ="o "> .</ span > < span class ="n "> CDLL</ span > < span class ="p "> (</ span > < span class ="n "> cublas_path</ span > < span class ="p "> )</ span >
628
- < span class ="n "> ctypes</ span > < span class ="o "> .</ span > < span class ="n "> CDLL</ span > < span class ="p "> (</ span > < span class ="n "> cudnn_path</ span > < span class ="p "> )</ span >
619
+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> lib_path</ span > < span class ="p "> :</ span >
620
+ < span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="sa "> f</ span > < span class ="s2 "> "</ span > < span class ="si "> {</ span > < span class ="n "> lib_name</ span > < span class ="si "> }</ span > < span class ="s2 "> not found in the system path </ span > < span class ="si "> {</ span > < span class ="n "> sys</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="si "> }</ span > < span class ="s2 "> "</ span > < span class ="p "> )</ span >
621
+ < span class ="n "> ctypes</ span > < span class ="o "> .</ span > < span class ="n "> CDLL</ span > < span class ="p "> (</ span > < span class ="n "> lib_path</ span > < span class ="p "> )</ span >
629
622
630
623
631
624
< span class ="c1 "> # See Note [Global dependencies]</ span >
@@ -640,11 +633,26 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
640
633
< span class ="k "> try</ span > < span class ="p "> :</ span >
641
634
< span class ="n "> ctypes</ span > < span class ="o "> .</ span > < span class ="n "> CDLL</ span > < span class ="p "> (</ span > < span class ="n "> lib_path</ span > < span class ="p "> ,</ span > < span class ="n "> mode</ span > < span class ="o "> =</ span > < span class ="n "> ctypes</ span > < span class ="o "> .</ span > < span class ="n "> RTLD_GLOBAL</ span > < span class ="p "> )</ span >
642
635
< span class ="k "> except</ span > < span class ="ne "> OSError</ span > < span class ="k "> as</ span > < span class ="n "> err</ span > < span class ="p "> :</ span >
643
- < span class ="c1 "> # Can only happen of wheel with cublas as PYPI deps</ span >
644
- < span class ="c1 "> # As PyTorch is not purelib, but nvidia-cublas-cu11 is</ span >
645
- < span class ="k "> if</ span > < span class ="s1 "> 'libcublas.so.11'</ span > < span class ="ow "> not</ span > < span class ="ow "> in</ span > < span class ="n "> err</ span > < span class ="o "> .</ span > < span class ="n "> args</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ]:</ span >
636
+ < span class ="c1 "> # Can only happen for wheel with cuda libs as PYPI deps</ span >
637
+ < span class ="c1 "> # As PyTorch is not purelib, but nvidia-*-cu11 is</ span >
638
+ < span class ="n "> cuda_libs</ span > < span class ="p "> :</ span > < span class ="n "> Dict</ span > < span class ="p "> [</ span > < span class ="nb "> str</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="p "> {</ span >
639
+ < span class ="s1 "> 'cublas'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcublas.so.*[0-9]'</ span > < span class ="p "> ,</ span >
640
+ < span class ="s1 "> 'cudnn'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcudnn.so.*[0-9]'</ span > < span class ="p "> ,</ span >
641
+ < span class ="s1 "> 'cuda_nvrtc'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libnvrtc.so.*[0-9].*[0-9]'</ span > < span class ="p "> ,</ span >
642
+ < span class ="s1 "> 'cuda_runtime'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcudart.so.*[0-9].*[0-9]'</ span > < span class ="p "> ,</ span >
643
+ < span class ="s1 "> 'cuda_cupti'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcupti.so.*[0-9].*[0-9]'</ span > < span class ="p "> ,</ span >
644
+ < span class ="s1 "> 'cufft'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcufft.so.*[0-9]'</ span > < span class ="p "> ,</ span >
645
+ < span class ="s1 "> 'curand'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcurand.so.*[0-9]'</ span > < span class ="p "> ,</ span >
646
+ < span class ="s1 "> 'cusolver'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcusolver.so.*[0-9]'</ span > < span class ="p "> ,</ span >
647
+ < span class ="s1 "> 'cusparse'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libcusparse.so.*[0-9]'</ span > < span class ="p "> ,</ span >
648
+ < span class ="s1 "> 'nccl'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libnccl.so.*[0-9]'</ span > < span class ="p "> ,</ span >
649
+ < span class ="s1 "> 'nvtx'</ span > < span class ="p "> :</ span > < span class ="s1 "> 'libnvToolsExt.so.*[0-9]'</ span > < span class ="p "> ,</ span >
650
+ < span class ="p "> }</ span >
651
+ < span class ="n "> is_cuda_lib_err</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="n "> lib</ span > < span class ="k "> for</ span > < span class ="n "> lib</ span > < span class ="ow "> in</ span > < span class ="n "> cuda_libs</ span > < span class ="o "> .</ span > < span class ="n "> values</ span > < span class ="p "> ()</ span > < span class ="k "> if</ span > < span class ="p "> (</ span > < span class ="n "> lib</ span > < span class ="o "> .</ span > < span class ="n "> split</ span > < span class ="p "> (</ span > < span class ="s1 "> '.'</ span > < span class ="p "> )[</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span > < span class ="ow "> in</ span > < span class ="n "> err</ span > < span class ="o "> .</ span > < span class ="n "> args</ span > < span class ="p "> [</ span > < span class ="mi "> 0</ span > < span class ="p "> ])]</ span >
652
+ < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> is_cuda_lib_err</ span > < span class ="p "> :</ span >
646
653
< span class ="k "> raise</ span > < span class ="n "> err</ span >
647
- < span class ="n "> _preload_cuda_deps</ span > < span class ="p "> ()</ span >
654
+ < span class ="k "> for</ span > < span class ="n "> lib_folder</ span > < span class ="p "> ,</ span > < span class ="n "> lib_name</ span > < span class ="ow "> in</ span > < span class ="n "> cuda_libs</ span > < span class ="o "> .</ span > < span class ="n "> items</ span > < span class ="p "> ():</ span >
655
+ < span class ="n "> _preload_cuda_deps</ span > < span class ="p "> (</ span > < span class ="n "> lib_folder</ span > < span class ="p "> ,</ span > < span class ="n "> lib_name</ span > < span class ="p "> )</ span >
648
656
< span class ="n "> ctypes</ span > < span class ="o "> .</ span > < span class ="n "> CDLL</ span > < span class ="p "> (</ span > < span class ="n "> lib_path</ span > < span class ="p "> ,</ span > < span class ="n "> mode</ span > < span class ="o "> =</ span > < span class ="n "> ctypes</ span > < span class ="o "> .</ span > < span class ="n "> RTLD_GLOBAL</ span > < span class ="p "> )</ span >
649
657
650
658
@@ -1059,7 +1067,7 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
1059
1067
< span class ="sd "> torch.float64</ span >
1060
1068
1061
1069
< span class ="sd "> """</ span >
1062
- < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n " > _string_classes </ span > < span class ="p "> ):</ span >
1070
+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="nb " > str </ span > < span class ="p "> ):</ span >
1063
1071
< span class ="n "> t</ span > < span class ="o "> =</ span > < span class ="n "> _import_dotted_name</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> )</ span >
1064
1072
< span class ="n "> _C</ span > < span class ="o "> .</ span > < span class ="n "> _set_default_tensor_type</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> )</ span > </ div >
1065
1073
0 commit comments