@@ -82,7 +82,7 @@ def attribute(
         additional_forward_args: Any = None,
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
-        split_channels: bool = False,
+        attr_dim_summation: bool = True,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         r"""
         Args:
@@ -150,10 +150,10 @@ def attribute(
                         otherwise, by default, both positive and negative
                         attributions are returned.
                         Default: False
-            split_channels (bool, optional): Indicates whether to
-                        keep attributions split per channel.
-                        The default (False) means to sum per channel.
-                        Default: False
+            attr_dim_summation (bool, optional): Indicates whether to
+                        sum attributions along dimension 1 (usually channel).
+                        The default (True) means to sum along dimension 1.
+                        Default: True

         Returns:
             *Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -213,16 +213,17 @@ def attribute(
                 for layer_grad in layer_gradients
             )

-        if split_channels:
+        if attr_dim_summation:
             scaled_acts = tuple(
-                summed_grad * layer_eval
+                torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
         else:
             scaled_acts = tuple(
-                torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
+                summed_grad * layer_eval
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
+
         if relu_attributions:
             scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
         return _format_output(len(scaled_acts) > 1, scaled_acts)
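For context, this hunk looks like the `attribute` method of Captum's `LayerGradCam`; the class itself is not shown in the diff, so the names in the sketch below are an assumption, and the toy model is purely hypothetical. A minimal usage sketch of the new `attr_dim_summation` flag and the output shapes it controls:

```python
import torch
import torch.nn as nn

from captum.attr import LayerGradCam

# Hypothetical toy CNN, used only to illustrate output shapes.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()

layer_gc = LayerGradCam(model, model[0])  # attribute w.r.t. the conv layer
inputs = torch.randn(2, 3, 32, 32)

# Default (attr_dim_summation=True): gradient-weighted activations are summed
# over dimension 1, giving one map per example -> shape (2, 1, 32, 32).
summed = layer_gc.attribute(inputs, target=0)

# attr_dim_summation=False keeps the per-channel products,
# one map per channel of the chosen layer -> shape (2, 8, 32, 32).
per_channel = layer_gc.attribute(inputs, target=0, attr_dim_summation=False)

print(summed.shape, per_channel.shape)
```

Note that for callers of the previous `split_channels` flag the default behavior is unchanged: summing along dimension 1 remains the default; only the flag's name and polarity are flipped.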