
Commit 0f37186

Rename split_channels into attr_dim_summation and invert the logic

Parent: cce95cd

2 files changed (+12 −11)

captum/attr/_core/layer/grad_cam.py (+9 −8)
@@ -82,7 +82,7 @@ def attribute(
         additional_forward_args: Any = None,
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
-        split_channels: bool = False,
+        attr_dim_summation: bool = True,
     ) -> Union[Tensor, Tuple[Tensor, ...]]:
         r"""
         Args:
@@ -150,10 +150,10 @@ def attribute(
                         otherwise, by default, both positive and negative
                         attributions are returned.
                         Default: False
-            split_channels (bool, optional): Indicates whether to
-                        keep attributions split per channel.
-                        The default (False) means to sum per channels.
-                        Default: False
+            attr_dim_summation (bool, optional): Indicates whether to
+                        sum attributions along dimension 1 (usually channel).
+                        The default (True) means to sum along dimension 1.
+                        Default: True

         Returns:
             *Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -213,16 +213,17 @@ def attribute(
             for layer_grad in layer_gradients
         )

-        if split_channels:
+        if attr_dim_summation:
             scaled_acts = tuple(
-                summed_grad * layer_eval
+                torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
         else:
             scaled_acts = tuple(
-                torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
+                summed_grad * layer_eval
                 for summed_grad, layer_eval in zip(summed_grads, layer_evals)
             )
+
         if relu_attributions:
             scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
         return _format_output(len(scaled_acts) > 1, scaled_acts)
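
Since the rename also inverts the default, callers migrating from split_channels=True must now pass attr_dim_summation=False. Below is a minimal usage sketch against the post-commit LayerGradCam API; the toy model, the layer choice, and the input sizes are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn
from captum.attr import LayerGradCam

# Illustrative toy network: conv features followed by a small classifier head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
model.eval()

layer_gc = LayerGradCam(model, model[0])
inp = torch.randn(1, 3, 16, 16)

# New default (attr_dim_summation=True): channel dim summed -> (1, 1, 16, 16).
summed = layer_gc.attribute(inp, target=3)

# Old split_channels=True behavior: keep per-channel maps -> (1, 8, 16, 16).
per_channel = layer_gc.attribute(inp, target=3, attr_dim_summation=False)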

tests/attr/layer/test_grad_cam.py (+3 −3)
@@ -47,7 +47,7 @@ def test_simple_input_conv_split_channels(self) -> None:
             net.conv1,
             inp,
             expected_activation=expected_result,
-            split_channels=True,
+            attr_dim_summation=False,
         )

     def test_simple_input_conv_no_grad(self) -> None:
@@ -117,7 +117,7 @@ def _grad_cam_test_assert(
         additional_input: Any = None,
         attribute_to_layer_input: bool = False,
         relu_attributions: bool = False,
-        split_channels: bool = False,
+        attr_dim_summation: bool = True,
     ):
         layer_gc = LayerGradCam(model, target_layer)
         self.assertFalse(layer_gc.multiplies_by_inputs)
@@ -127,7 +127,7 @@ def _grad_cam_test_assert(
             additional_forward_args=additional_input,
             attribute_to_layer_input=attribute_to_layer_input,
             relu_attributions=relu_attributions,
-            split_channels=split_channels,
+            attr_dim_summation=attr_dim_summation,
         )
         assertTensorTuplesAlmostEqual(
             self, attributions, expected_activation, delta=0.01
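
For reference, the operation the renamed flag toggles in both the implementation and these tests is a keepdim sum over dimension 1. A standalone sketch of the shape effect, with arbitrary example sizes:

import torch

# Scaled activations shaped (batch, channel, H, W), as in the GradCAM code above.
acts = torch.randn(2, 8, 7, 7)

# attr_dim_summation=True path: collapse dim 1 but keep it as size 1.
summed = torch.sum(acts, dim=1, keepdim=True)
assert summed.shape == (2, 1, 7, 7)

# attr_dim_summation=False path: attributions stay split per channel.
assert acts.shape == (2, 8, 7, 7)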
