@@ -193,7 +193,7 @@
 <div class="pytorch-left-menu-search">

   <div class="version">
-    <a href='https://pytorch.org/docs/versions.html'>master (1.11.0a0+git125a559) ▼</a>
+    <a href='https://pytorch.org/docs/versions.html'>master (1.11.0a0+gita21f2ab) ▼</a>
   </div>

@@ -479,8 +479,8 @@ <h1>Source code for torch._vmap_internals</h1>
 # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
 def _unwrap_batched(
         batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
-        out_dims: out_dims_t,
-        vmap_level: int, batch_size: int, func: Callable) -> Tuple:
+        out_dims: out_dims_t, vmap_level: int, batch_size: int, func: Callable,
+        allow_none_pass_through: bool = False) -> Tuple:
     num_outputs = _num_outputs(batched_outputs)
     out_dims_as_tuple = _as_tuple(
         out_dims, num_outputs,
@@ -493,8 +493,12 @@ <h1>Source code for torch._vmap_internals</h1>
     if isinstance(batched_outputs, Tensor):
         out_dim = out_dims_as_tuple[0]
         return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore[return-value]
-    return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
-                 for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
+    if allow_none_pass_through:
+        return tuple((torch._remove_batch_dim(out, vmap_level, batch_size, out_dim) if out is not None else None)
+                     for out, out_dim in zip(batched_outputs, out_dims_as_tuple))
+    else:
+        return tuple(torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
+                     for out, out_dim in zip(batched_outputs, out_dims_as_tuple))

 # Checks that `fn` returned one or more Tensors and nothing else.
 # NB: A python function that return multiple arguments returns a single tuple,
@@ -645,16 +649,22 @@ <h1>Source code for torch._vmap_internals</h1>
     return _vmap(func, in_dims, out_dims)

 # A version of vmap but without the initial "experimental prototype" warning
-def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
+def _vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0, allow_none_pass_through: bool = False) -> Callable:
+    # The `allow_none_pass_through` argument is a temporary workaround and may be removed.
+    # Currently it enables us to wrap the call in `autograd.grad` to the autograd engine,
+    # which may return None if any of the inputs are unused. See the issue discussing this:
+    # https://github.com/facebookresearch/functorch/issues/159.
     @functools.wraps(func)
     def wrapped(*args):
         _check_out_dims_is_int_or_int_tuple(out_dims, func)
         vmap_level = torch._C._vmapmode_increment_nesting()
         try:
             batched_inputs, batch_size = _create_batched_inputs(in_dims, args, vmap_level, func)
             batched_outputs = func(*batched_inputs)
-            _validate_outputs(batched_outputs, func)
-            return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func)
+            if not allow_none_pass_through:
+                _validate_outputs(batched_outputs, func)
+            return _unwrap_batched(batched_outputs, out_dims, vmap_level, batch_size, func,
+                                   allow_none_pass_through=allow_none_pass_through)
         finally:
             torch._C._vmapmode_decrement_nesting()
     return wrapped
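For context, here is a minimal sketch of how the new flag might be exercised. It is not part of the commit; the tensors and the `vjp` helper are illustrative only, and `_vmap` is a private helper whose signature may change. It vmaps a function that calls `torch.autograd.grad` over several grad_outputs, where one input is unused and the autograd engine therefore returns None for it.

import torch
from torch._vmap_internals import _vmap  # private helper modified in this diff

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)   # deliberately unused below
out = x * 2                              # `y` does not contribute to `out`

def vjp(v):
    # The autograd engine returns None for `y` because it is unused; with
    # allow_none_pass_through=True, _vmap skips _validate_outputs and
    # _unwrap_batched forwards the None instead of raising.
    return torch.autograd.grad(out, (x, y), grad_outputs=v,
                               retain_graph=True, allow_unused=True)

jac_x, jac_y = _vmap(vjp, allow_none_pass_through=True)(torch.eye(3))
# jac_x is the 3x3 Jacobian of `out` w.r.t. `x` (2 * identity); jac_y is None.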