 <div class="pytorch-left-menu-search">

   <div class="version">
-    <a href='https://pytorch.org/docs/versions.html'>main (2.1.0a0+git0c4fa02) ▼</a>
+    <a href='https://pytorch.org/docs/versions.html'>main (2.1.0a0+git3828cd4) ▼</a>
   </div>

@@ -559,7 +559,8 @@ <h1>Source code for torch._tensor</h1><div class="highlight"><pre>
         # Update the test in test_serialization if you remove 'meta' from here
         if (
             self.is_sparse
-            or self.device.type in ["lazy", "xla", "mps", "ort", "meta", "ipu"]
+            or self.device.type
+            in ["lazy", "xla", "mtia", "mps", "ort", "meta", "ipu"]
             or (
                 not torch._C._has_storage(self)
                 and self.device.type == torch._C._get_privateuse1_backend_name()
@@ -707,7 +708,7 @@ <h1>Source code for torch._tensor</h1><div class="highlight"><pre>
         # See Note [Don't serialize hooks]
         torch.utils.hooks.warn_if_has_hooks(self)
         backward_hooks: Dict[Any, Any] = OrderedDict()
-        # Note: Numpy array is chosen to be the rebuild component for XLA, ORT Tensors.
+        # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, ORT Tensors.
         # We considered a few options:
         # 1. CPU tensor can't be used here.
         #    Otherwise in torch.load CPU storage is reconstructed with randomly
@@ -717,7 +718,7 @@ <h1>Source code for torch._tensor</h1><div class="highlight"><pre>
         # 2. Python list is not a good fit due to performance reason.
         #    `tolist()` converts every single element in the tensor into python objects
         #    and serialize them one by one.
-        if self.device.type in ["xla", "ort"] or (
+        if self.device.type in ["xla", "mtia", "ort"] or (
             not torch._C._has_storage(self)
             and self.device.type == torch._C._get_privateuse1_backend_name()
         ):
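For context, the branch above is the one that reroutes these backends through a numpy rebuild when pickling. A minimal sketch of the observable behavior, using only the public torch.save / torch.load API and a CPU tensor as a stand-in (xla / mtia / ort devices are assumptions and are rarely available locally):

import io
import torch

# Round-trip a tensor through torch.save / torch.load; for the device types
# listed in the branch above, the data is captured as a numpy array and the
# tensor is rebuilt from it on load.
t = torch.arange(4, dtype=torch.float32)  # stand-in; a real case would sit on e.g. an "xla" device
buf = io.BytesIO()
torch.save(t, buf)
buf.seek(0)
restored = torch.load(buf)
assert torch.equal(t, restored)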
@@ -988,7 +989,7 @@ <h1>Source code for torch._tensor</h1><div class="highlight"><pre>
             return handle_torch_function(Tensor.register_hook, (self,), self, hook)
         if not self.requires_grad:
             raise RuntimeError(
-                "cannot register a hook on a tensor that doesn't require gradient"
+                "cannot register a hook on a tensor that " "doesn't require gradient"
             )
         if self._backward_hooks is None:
             self._backward_hooks = OrderedDict()
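As an aside, the guard above is easy to exercise directly; the two adjacent string literals on the new line are concatenated by the Python parser, so the message a caller sees is unchanged. A small sketch using the public Tensor.register_hook API:

import torch

t = torch.zeros(3)  # requires_grad defaults to False
try:
    t.register_hook(lambda grad: grad * 2)
except RuntimeError as err:
    print(err)  # cannot register a hook on a tensor that doesn't require gradient

t.requires_grad_(True)
handle = t.register_hook(lambda grad: grad * 2)  # now succeeds
handle.remove()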
@@ -998,62 +999,6 @@ <h1>Source code for torch._tensor</h1><div class="highlight"><pre>
         self._backward_hooks[handle.id] = hook
         return handle

-    def register_post_accumulate_grad_hook(self, hook):
-        r"""Registers a backward hook that runs after grad accumulation.
-
-        The hook will be called after all gradients for a tensor have been accumulated,
-        meaning that the .grad field has been updated on that tensor. The post
-        accumulate grad hook is ONLY applicable for leaf tensors (tensors without a
-        .grad_fn field). Registering this hook on a non-leaf tensor will error!
-
-        The hook should have the following signature::
-
-            hook(param: Tensor) -> None
-
-        Note that, unlike other autograd hooks, this hook operates on the tensor
-        that requires grad and not the grad itself. The hook can in-place modify
-        and access its Tensor argument, including its .grad field.
-
-        This function returns a handle with a method ``handle.remove()``
-        that removes the hook from the module.
-
-        .. note::
-            See :ref:`backward-hooks-execution` for more information on how when this hook
-            is executed, and how its execution is ordered relative to other hooks. Since
-            this hook runs during the backward pass, it will run in no_grad mode (unless
-            create_graph is True). You can use torch.enable_grad() to re-enable autograd
-            within the hook if you need it.
-
-        Example::
-
-            >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
-            >>> lr = 0.01
-            >>> # simulate a simple SGD update
-            >>> h = v.register_post_accumulate_grad_hook(lambda p: p.add_(p.grad, alpha=-lr))
-            >>> v.backward(torch.tensor([1., 2., 3.]))
-            >>> v
-            tensor([-0.0100, -0.0200, -0.0300], requires_grad=True)
-
-            >>> h.remove()  # removes the hook
-        """
-        if has_torch_function_unary(self):
-            return handle_torch_function(
-                Tensor.register_post_accumulate_grad_hook, (self,), self, hook
-            )
-        if not self.requires_grad:
-            raise RuntimeError(
-                "cannot register a hook on a tensor that doesn't require gradient"
-            )
-        if self.grad_fn is not None:
-            raise RuntimeError(
-                "post accumulate grad hooks cannot be registered on non-leaf tensors"
-            )
-        if self._post_accumulate_grad_hooks is None:
-            self._post_accumulate_grad_hooks: Dict[Any, Any] = OrderedDict()
-        handle = hooks.RemovableHandle(self._post_accumulate_grad_hooks)
-        self._post_accumulate_grad_hooks[handle.id] = hook
-        return handle
-
     def reinforce(self, reward):
         def trim(str):
             return "\n".join([line.strip() for line in str.split("\n")])