
Commit aae5228

vivekmig authored and facebook-github-bot committed Oct 29, 2020
Neuron Aggregation (#495)
Summary: This adds support for neuron aggregation: for all neuron methods other than NeuronConductance (which depends on output gradients), neuron_selector can be a function that returns a custom aggregate of a layer's neurons. The neuron_index argument was renamed to neuron_selector, and a deprecation decorator was added to warn when the old parameter is used as a keyword argument; this decorator can be removed prior to the 0.4.0 release. Documentation of the new callable functionality has been added to NeuronDeepLift and will be propagated to other relevant methods after review.

Pull Request resolved: #495

Reviewed By: miguelmartin75

Differential Revision: D24346065

Pulled By: vivekmig

fbshipit-source-id: c3853e19256de4c8c32a8ff615965bf513a5cd22
1 parent 31f266d · commit aae5228

22 files changed: +568 −221 lines changed

README.md (+1 −1)

@@ -324,7 +324,7 @@ In this case, we choose to analyze the first neuron in the linear layer.
 
 ```python
 nc = NeuronConductance(model, model.lin1)
-attributions = nc.attribute(input, neuron_index=1, target=0)
+attributions = nc.attribute(input, neuron_selector=1, target=0)
 print('Neuron Attributions:', attributions)
 ```
 Output
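
The summary above describes the new callable form of neuron_selector. Below is a minimal sketch of that usage with NeuronDeepLift (the method whose documentation this commit updates); `model`, `model.lin1`, and `input` are placeholders as in the README snippet, and the aggregate shown is only an example:

```python
import torch
from captum.attr import NeuronDeepLift

dl = NeuronDeepLift(model, model.lin1)

# A callable selector receives the layer output and should return one scalar
# per example; here we aggregate neurons 0-2 by summation instead of
# selecting a single index.
attributions = dl.attribute(
    input, neuron_selector=lambda out: out[:, 0:3].sum(dim=1)
)
```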

captum/_utils/common.py (+27 −0)

@@ -417,6 +417,15 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor:
     raise AssertionError("Target type %r is not valid." % target)


+def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool:
+    if isinstance(target, tuple):
+        for index in target:
+            if isinstance(index, slice):
+                return True
+        return False
+    return isinstance(target, slice)
+
+
 def _verify_select_column(
     output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]]
 ) -> Tensor:
@@ -427,6 +436,24 @@ def _verify_select_column(
     return output[(slice(None), *target)]


+def _verify_select_neuron(
+    layer_output: Tuple[Tensor, ...],
+    selector: Union[int, Tuple[Union[int, slice], ...], Callable],
+) -> Tensor:
+    if callable(selector):
+        return selector(layer_output if len(layer_output) > 1 else layer_output[0])
+
+    assert len(layer_output) == 1, (
+        "Cannot select neuron index from layer with multiple tensors,"
+        " consider providing a neuron selector function instead."
+    )
+
+    selected_neurons = _verify_select_column(layer_output[0], selector)
+    if _contains_slice(selector):
+        return selected_neurons.reshape(selected_neurons.shape[0], -1).sum(1)
+    return selected_neurons
+
+
 def _extract_device(
     module: Module,
     hook_inputs: Union[None, Tensor, Tuple[Tensor, ...]],
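
A small illustration of the three selector forms handled by `_verify_select_neuron` above, with assumed shapes. These are private helpers, so this is only a sketch of their semantics, not public API:

```python
import torch
from captum._utils.common import _verify_select_neuron

layer_output = (torch.arange(12.0).reshape(3, 4),)  # single-tensor output, batch of 3

# Integer index: selects column 1, one value per example -> shape (3,)
by_index = _verify_select_neuron(layer_output, 1)

# Slice selector: columns 1:3 are selected, then summed per example -> shape (3,)
by_slice = _verify_select_neuron(layer_output, (slice(1, 3),))

# Callable: receives the tensor directly and returns a custom aggregate
by_fn = _verify_select_neuron(layer_output, lambda t: t.max(dim=1).values)
```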

captum/_utils/gradient.py (+28 −26)

@@ -9,7 +9,7 @@
 from torch import Tensor, device
 from torch.nn import Module
 
-from .common import _reduce_list, _run_forward, _sort_key_list, _verify_select_column
+from .common import _reduce_list, _run_forward, _sort_key_list, _verify_select_neuron
 from .typing import (
     Literal,
     ModuleOrModuleList,
@@ -125,22 +125,20 @@ def _neuron_gradients(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     saved_layer: Dict[device, Tuple[Tensor, ...]],
     key_list: List[device],
-    gradient_neuron_index: Union[int, Tuple[Union[int, slice], ...]],
+    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
 ) -> Tuple[Tensor, ...]:
     with torch.autograd.set_grad_enabled(True):
         gradient_tensors = []
         for key in key_list:
-            assert (
-                len(saved_layer[key]) == 1
-            ), "Cannot compute neuron gradients for layer with multiple tensors."
-            current_out_tensor = _verify_select_column(
-                saved_layer[key][0], gradient_neuron_index
+            current_out_tensor = _verify_select_neuron(
+                saved_layer[key], gradient_neuron_selector
             )
             gradient_tensors.append(
                 torch.autograd.grad(
-                    torch.unbind(current_out_tensor),
+                    torch.unbind(current_out_tensor)
+                    if current_out_tensor.numel() > 1
+                    else current_out_tensor,
                     inputs,
-                    grad_outputs=torch.unbind(torch.ones_like(current_out_tensor)),
                 )
             )
         _total_gradients = _reduce_list(gradient_tensors, sum)
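
The conditional around torch.unbind above exists because a callable selector may yield a tensor with only a single element, which cannot be unbound into per-example scalars. In the common multi-element case, unbinding a 1-D tensor produces scalar outputs, so autograd no longer needs explicit grad_outputs. A standalone sketch of that pattern (shapes illustrative):

```python
import torch

layer_output = torch.randn(4, 10, requires_grad=True)

# A selector-style aggregate: one scalar per example over neurons 2-4.
selected = layer_output[:, 2:5].sum(dim=1)  # shape (4,)

# Unbinding a 1-D tensor yields scalar tensors, so grad_outputs can be omitted.
grads = torch.autograd.grad(torch.unbind(selected), layer_output)
print(grads[0].shape)  # torch.Size([4, 10])
```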
@@ -187,7 +185,7 @@ def _forward_layer_eval(
         inputs,
         layer,
         additional_forward_args=additional_forward_args,
-        gradient_neuron_index=None,
+        gradient_neuron_selector=None,
         grad_enabled=grad_enabled,
         device_ids=device_ids,
         attribute_to_layer_input=attribute_to_layer_input,
@@ -369,7 +367,7 @@ def _forward_layer_eval_with_neuron_grads(
     layer: Module,
     additional_forward_args: Any = None,
     *,
-    gradient_neuron_index: Union[int, Tuple[Union[int, slice], ...]],
+    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -383,7 +381,7 @@ def _forward_layer_eval_with_neuron_grads(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     layer: Module,
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -397,7 +395,7 @@ def _forward_layer_eval_with_neuron_grads(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     layer: List[Module],
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -410,7 +408,9 @@ def _forward_layer_eval_with_neuron_grads(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     layer: ModuleOrModuleList,
     additional_forward_args: Any = None,
-    gradient_neuron_index: Union[None, int, Tuple[Union[int, slice], ...]] = None,
+    gradient_neuron_selector: Union[
+        None, int, Tuple[Union[int, slice], ...], Callable
+    ] = None,
     grad_enabled: bool = False,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
@@ -421,7 +421,7 @@ def _forward_layer_eval_with_neuron_grads(
 ]:
     """
     This method computes forward evaluation for a particular layer using a
-    forward hook. If a gradient_neuron_index is provided, then gradients with
+    forward hook. If a gradient_neuron_selector is provided, then gradients with
     respect to that neuron in the layer output are also returned.
 
     These functionalities are combined due to the behavior of DataParallel models
@@ -435,7 +435,7 @@ def _forward_layer_eval_with_neuron_grads(
     evals in a dictionary protected by a lock, analogous to the gather implementation
     for the core PyTorch DataParallel implementation.
     """
-    grad_enabled = True if gradient_neuron_index is not None or grad_enabled else False
+    grad_enabled = True if gradient_neuron_selector is not None else grad_enabled
 
     with torch.autograd.set_grad_enabled(grad_enabled):
         saved_layer = _forward_layer_distributed_eval(
@@ -450,12 +450,12 @@ def _forward_layer_eval_with_neuron_grads(
     # key_list is a list of devices in appropriate ordering for concatenation.
     # If only one key exists (standard model), key list simply has one element.
     key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids)
-    if gradient_neuron_index is not None:
+    if gradient_neuron_selector is not None:
         assert isinstance(
             layer, Module
         ), "Cannot compute neuron gradients for multiple layers simultaneously!"
         inp_grads = _neuron_gradients(
-            inputs, saved_layer[layer], key_list, gradient_neuron_index
+            inputs, saved_layer[layer], key_list, gradient_neuron_selector
         )
         return (
             _gather_distributed_tensors(saved_layer[layer], key_list=key_list),
@@ -479,7 +479,7 @@ def compute_layer_gradients_and_eval(
     target_ind: TargetType = None,
     additional_forward_args: Any = None,
     *,
-    gradient_neuron_index: Union[int, Tuple[int, ...]],
+    gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -494,7 +494,7 @@ def compute_layer_gradients_and_eval(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     target_ind: TargetType = None,
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -509,7 +509,7 @@ def compute_layer_gradients_and_eval(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     target_ind: TargetType = None,
     additional_forward_args: Any = None,
-    gradient_neuron_index: None = None,
+    gradient_neuron_selector: None = None,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -523,7 +523,9 @@ def compute_layer_gradients_and_eval(
     inputs: Union[Tensor, Tuple[Tensor, ...]],
     target_ind: TargetType = None,
    additional_forward_args: Any = None,
-    gradient_neuron_index: Union[None, int, Tuple[int, ...]] = None,
+    gradient_neuron_selector: Union[
+        None, int, Tuple[Union[int, slice], ...], Callable
+    ] = None,
     device_ids: Union[None, List[int]] = None,
     attribute_to_layer_input: bool = False,
     output_fn: Union[None, Callable] = None,
@@ -659,12 +661,12 @@ def compute_layer_gradients_and_eval(
         if isinstance(layer, Module):
             layer_grads = all_grads[0]
 
-        if gradient_neuron_index is not None:
+        if gradient_neuron_selector is not None:
             assert isinstance(
                 layer, Module
             ), "Cannot compute neuron gradients for multiple layers simultaneously!"
             inp_grads = _neuron_gradients(
-                inputs, saved_layer[layer], key_list, gradient_neuron_index
+                inputs, saved_layer[layer], key_list, gradient_neuron_selector
             )
             return (
                 cast(Tuple[Tensor, ...], layer_grads),
@@ -676,7 +678,7 @@ def compute_layer_gradients_and_eval(
 
 def construct_neuron_grad_fn(
     layer: Module,
-    neuron_index: Union[int, Tuple[Union[int, slice], ...]],
+    neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
     device_ids: Union[None, List[int]] = None,
     attribute_to_neuron_input: bool = False,
 ) -> Callable:
@@ -691,7 +693,7 @@ def grad_fn(
             inputs,
             layer,
             additional_forward_args,
-            gradient_neuron_index=neuron_index,
+            gradient_neuron_selector=neuron_selector,
             device_ids=device_ids,
             attribute_to_layer_input=attribute_to_neuron_input,
         )

captum/attr/_core/neuron/neuron_conductance.py (+15 −9)

@@ -20,7 +20,12 @@
 from ..._utils.approximation_methods import approximation_parameters
 from ..._utils.attribution import GradientAttribution, NeuronAttribution
 from ..._utils.batching import _batch_attribution
-from ..._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input
+from ..._utils.common import (
+    _format_input_baseline,
+    _reshape_and_sum,
+    _validate_input,
+    neuron_index_deprecation_decorator,
+)
 
 
 class NeuronConductance(NeuronAttribution, GradientAttribution):
@@ -46,7 +51,7 @@ def __init__(
                         modification of it
             layer (torch.nn.Module): Layer for which neuron attributions are computed.
                         Attributions for a particular neuron in the input or output
-                        of this layer are computed using the argument neuron_index
+                        of this layer are computed using the argument neuron_selector
                         in the attribute method.
                         Currently, only layers with a single tensor input or output
                         are supported.
@@ -85,10 +90,11 @@ def __init__(
         self._multiply_by_inputs = multiply_by_inputs
 
     @log_usage()
+    @neuron_index_deprecation_decorator
     def attribute(
         self,
         inputs: TensorOrTupleOfTensorsGeneric,
-        neuron_index: Union[int, Tuple[int, ...]],
+        neuron_selector: Union[int, Tuple[int, ...]],
         baselines: BaselineType = None,
         target: TargetType = None,
         additional_forward_args: Any = None,
@@ -108,7 +114,7 @@ def attribute(
                         that for all given input tensors, dimension 0 corresponds
                         to the number of examples, and if multiple input tensors
                         are provided, the examples must be aligned appropriately.
-            neuron_index (int or tuple): Index of neuron in output of given
+            neuron_selector (int or tuple): Index of neuron in output of given
                         layer for which attribution is desired. Length of
                         this tuple must be one less than the number of
                         dimensions in the output of the given layer (since
@@ -260,7 +266,7 @@ def attribute(
                 n_steps,
                 inputs=inputs,
                 baselines=baselines,
-                neuron_index=neuron_index,
+                neuron_selector=neuron_selector,
                 target=target,
                 additional_forward_args=additional_forward_args,
                 method=method,
@@ -269,7 +275,7 @@ def attribute(
         else:
             attrs = self._attribute(
                 inputs=inputs,
-                neuron_index=neuron_index,
+                neuron_selector=neuron_selector,
                 baselines=baselines,
                 target=target,
                 additional_forward_args=additional_forward_args,
@@ -282,7 +288,7 @@ def attribute(
     def _attribute(
         self,
         inputs: Tuple[Tensor, ...],
-        neuron_index: Union[int, Tuple[int, ...]],
+        neuron_selector: Union[int, Tuple[int, ...]],
         baselines: Tuple[Union[Tensor, int, float], ...],
         target: TargetType = None,
         additional_forward_args: Any = None,
@@ -333,7 +339,7 @@ def _attribute(
             inputs=scaled_features_tpl,
             target_ind=expanded_target,
             additional_forward_args=input_additional_args,
-            gradient_neuron_index=neuron_index,
+            gradient_neuron_selector=neuron_selector,
             device_ids=self.device_ids,
             attribute_to_layer_input=attribute_to_neuron_input,
         )
@@ -348,7 +354,7 @@ def _attribute(
         # Multiplies by appropriate gradient of output with respect to hidden neurons
         # mid_grads is a 1D Tensor of length num_steps*internal_batch_size,
         # containing mid layer gradient for each input step.
-        mid_grads = _verify_select_column(layer_gradients, neuron_index)
+        mid_grads = _verify_select_column(layer_gradients, neuron_selector)
 
         scaled_input_gradients = tuple(
             input_grad
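
As noted in the summary, the deprecation decorator lets existing callers keep using the old keyword while emitting a warning. A hedged sketch of both call forms (`model` and `input` are placeholders; the exact warning text is not shown in this diff):

```python
nc = NeuronConductance(model, model.lin1)

# New parameter name:
attributions = nc.attribute(input, neuron_selector=1, target=0)

# Old keyword, still accepted during the deprecation window but warns:
attributions = nc.attribute(input, neuron_index=1, target=0)
```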
