 from torch import Tensor, device
 from torch.nn import Module
 
-from .common import _reduce_list, _run_forward, _sort_key_list, _verify_select_column
+from .common import _reduce_list, _run_forward, _sort_key_list, _verify_select_neuron
 from .typing import (
     Literal,
     ModuleOrModuleList,
@@ -125,22 +125,20 @@ def _neuron_gradients(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     saved_layer: Dict[device, Tuple[Tensor, ...]],
     key_list: List[device],
-    gradient_neuron_index: Union[int, Tuple[Union[int, slice], ...]],
+    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
 ) -> Tuple[Tensor, ...]:
     with torch.autograd.set_grad_enabled(True):
         gradient_tensors = []
         for key in key_list:
-            assert (
-                len(saved_layer[key]) == 1
-            ), "Cannot compute neuron gradients for layer with multiple tensors."
-            current_out_tensor = _verify_select_column(
-                saved_layer[key][0], gradient_neuron_index
+            current_out_tensor = _verify_select_neuron(
+                saved_layer[key], gradient_neuron_selector
             )
             gradient_tensors.append(
                 torch.autograd.grad(
-                    torch.unbind(current_out_tensor),
+                    torch.unbind(current_out_tensor)
+                    if current_out_tensor.numel() > 1
+                    else current_out_tensor,
                     inputs,
-                    grad_outputs=torch.unbind(torch.ones_like(current_out_tensor)),
                 )
             )
         _total_gradients = _reduce_list(gradient_tensors, sum)
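For context on this hunk: the Callable form of `gradient_neuron_selector` lets callers aggregate several neurons (one scalar per example) before differentiating, instead of indexing a single column. Below is a minimal sketch, assuming `_verify_select_neuron` dispatches roughly along these lines; `select_neuron` and `weighted_selector` are hypothetical names, not taken from this diff.

```python
# Sketch only: not Captum's implementation of _verify_select_neuron.
from typing import Callable, Tuple, Union

from torch import Tensor


def select_neuron(
    layer_out: Tuple[Tensor, ...],
    selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tensor:
    if callable(selector):
        # A Callable receives the layer output (tensor or tuple of tensors)
        # and returns one scalar per example, e.g. a weighted neuron sum.
        return selector(layer_out[0] if len(layer_out) == 1 else layer_out)
    assert len(layer_out) == 1, "Index selectors need a single output tensor."
    idx = (selector,) if isinstance(selector, int) else selector
    # Keep the batch dimension, index the remaining (neuron) dimensions.
    return layer_out[0][(slice(None),) + tuple(idx)]


# Example Callable selector: differentiate a weighted sum of two neurons.
def weighted_selector(out: Tensor) -> Tensor:
    return 3 * out[:, 0] + out[:, 1]
```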
@@ -187,7 +185,7 @@ def _forward_layer_eval(
         inputs,
         layer,
         additional_forward_args=additional_forward_args,
-        gradient_neuron_index=None,
+        gradient_neuron_selector=None,
         grad_enabled=grad_enabled,
         device_ids=device_ids,
         attribute_to_layer_input=attribute_to_layer_input,
@@ -369,7 +367,7 @@ def _forward_layer_eval_with_neuron_grads(
     layer: Module,
     additional_forward_args: Any = None,
     *,
-    gradient_neuron_index: Union[int, Tuple[Union[int, slice], ...]],
+    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -383,7 +381,7 @@ def _forward_layer_eval_with_neuron_grads(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     layer: Module,
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -397,7 +395,7 @@ def _forward_layer_eval_with_neuron_grads(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     layer: List[Module],
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -410,7 +408,9 @@ def _forward_layer_eval_with_neuron_grads(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     layer: ModuleOrModuleList,
     additional_forward_args: Any = None,
-    gradient_neuron_index: Union[None, int, Tuple[Union[int, slice], ...]] = None,
+    gradient_neuron_selector: Union[
+        None, int, Tuple[Union[int, slice], ...], Callable
+    ] = None,
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -421,7 +421,7 @@ def _forward_layer_eval_with_neuron_grads(
 ]:
     """
     This method computes forward evaluation for a particular layer using a
-    forward hook. If a gradient_neuron_index is provided, then gradients with
+    forward hook. If a gradient_neuron_selector is provided, then gradients with
     respect to that neuron in the layer output are also returned.
 
     These functionalities are combined due to the behavior of DataParallel models
@@ -435,7 +435,7 @@ def _forward_layer_eval_with_neuron_grads(
     evals in a dictionary protected by a lock, analogous to the gather implementation
     for the core PyTorch DataParallel implementation.
     """
-    grad_enabled = True if gradient_neuron_index is not None or grad_enabled else False
+    grad_enabled = True if gradient_neuron_selector is not None else grad_enabled
 
     with torch.autograd.set_grad_enabled(grad_enabled):
         saved_layer = _forward_layer_distributed_eval(
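The docstring in the two hunks above describes the pattern this helper relies on: a forward hook records each DataParallel replica's layer output keyed by device, protected by a lock, so the results can be gathered in device order like DataParallel's own gather step. A rough sketch of that pattern follows; the names `saved_outputs` and `save_hook` are hypothetical, and a single output tensor per layer is assumed.

```python
# Rough sketch of the device-keyed, lock-protected gather pattern
# described in the docstring; not the helper's actual code.
import threading
from typing import Dict

import torch
import torch.nn as nn
from torch import Tensor

saved_outputs: Dict[torch.device, Tensor] = {}
lock = threading.Lock()


def save_hook(module: nn.Module, inp, out: Tensor) -> None:
    # Each DataParallel replica runs on its own device in its own thread,
    # so the dictionary write must be synchronized.
    with lock:
        saved_outputs[out.device] = out


# Usage: register on the target layer, run a forward pass, then concatenate
# the per-device outputs in a deterministic device order, e.g.:
# handle = model.some_layer.register_forward_hook(save_hook)
# _ = model(batch)
# handle.remove()
# evals = torch.cat([saved_outputs[d] for d in sorted(saved_outputs, key=str)])
```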
@@ -450,12 +450,12 @@ def _forward_layer_eval_with_neuron_grads(
     # key_list is a list of devices in appropriate ordering for concatenation.
     # If only one key exists (standard model), key list simply has one element.
     key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids)
-    if gradient_neuron_index is not None:
+    if gradient_neuron_selector is not None:
         assert isinstance(
             layer, Module
         ), "Cannot compute neuron gradients for multiple layers simultaneously!"
         inp_grads = _neuron_gradients(
-            inputs, saved_layer[layer], key_list, gradient_neuron_index
+            inputs, saved_layer[layer], key_list, gradient_neuron_selector
         )
         return (
             _gather_distributed_tensors(saved_layer[layer], key_list=key_list),
@@ -479,7 +479,7 @@ def compute_layer_gradients_and_eval(
     target_ind: TargetType = None,
     additional_forward_args: Any = None,
     *,
-    gradient_neuron_index: Union[int, Tuple[int, ...]],
+    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -494,7 +494,7 @@ def compute_layer_gradients_and_eval(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     target_ind: TargetType = None,
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -509,7 +509,7 @@ def compute_layer_gradients_and_eval(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     target_ind: TargetType = None,
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -523,7 +523,9 @@ def compute_layer_gradients_and_eval(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     target_ind: TargetType = None,
     additional_forward_args: Any = None,
-    gradient_neuron_index: Union[None, int, Tuple[int, ...]] = None,
+    gradient_neuron_selector: Union[
+        None, int, Tuple[Union[int, slice], ...], Callable
+    ] = None,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -659,12 +661,12 @@ def compute_layer_gradients_and_eval(
         if isinstance(layer, Module):
             layer_grads = all_grads[0]
 
-        if gradient_neuron_index is not None:
+        if gradient_neuron_selector is not None:
             assert isinstance(
                 layer, Module
             ), "Cannot compute neuron gradients for multiple layers simultaneously!"
             inp_grads = _neuron_gradients(
-                inputs, saved_layer[layer], key_list, gradient_neuron_index
+                inputs, saved_layer[layer], key_list, gradient_neuron_selector
             )
             return (
                 cast(Tuple[Tensor, ...], layer_grads),
@@ -676,7 +678,7 @@ def compute_layer_gradients_and_eval(
 
 def construct_neuron_grad_fn(
     layer: Module,
-    neuron_index: Union[int, Tuple[Union[int, slice], ...]],
+    neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
     device_ids: Union[None, List[int]] = None,
     attribute_to_neuron_input: bool = False,
 ) -> Callable:
@@ -691,7 +693,7 @@ def grad_fn(
             inputs,
             layer,
             additional_forward_args,
-            gradient_neuron_index=neuron_index,
+            gradient_neuron_selector=neuron_selector,
             device_ids=device_ids,
             attribute_to_layer_input=attribute_to_neuron_input,
         )
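For a sense of what this rename enables downstream, here is a hedged usage sketch. It assumes the public neuron attribution classes (e.g. `captum.attr.NeuronGradient`) pass their selector argument through to these helpers; only the internal utilities are changed in this diff, so the public-facing behavior shown is an assumption.

```python
# Hedged usage sketch: assumes the public NeuronGradient API forwards its
# selector to the renamed gradient_neuron_selector parameter.
import torch
import torch.nn as nn
from captum.attr import NeuronGradient

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
inputs = torch.randn(5, 8, requires_grad=True)

neuron_grad = NeuronGradient(model, model[2])

# Index selector: gradients of output neuron 3 with respect to the inputs.
attr_index = neuron_grad.attribute(inputs, 3)

# Callable selector: gradients of an aggregate of several neurons
# (one scalar per example), enabled by the Callable option in this change.
attr_callable = neuron_grad.attribute(inputs, lambda out: out[:, 1] + out[:, 2])
```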