Skip to content

Commit 7eeab0f

Browse files
committed
Rename split_channels into attr_dim_summation
1 parent 6572fb9 commit 7eeab0f

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

captum/attr/_core/layer/grad_cam.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def attribute(
8282
additional_forward_args: Any = None,
8383
attribute_to_layer_input: bool = False,
8484
relu_attributions: bool = False,
85-
split_channels: bool = False,
85+
attr_dim_summation: bool = False,
8686
) -> Union[Tensor, Tuple[Tensor, ...]]:
8787
r"""
8888
Args:
@@ -150,7 +150,7 @@ def attribute(
150150
otherwise, by default, both positive and negative
151151
attributions are returned.
152152
Default: False
153-
split_channels (bool, optional): Indicates whether to
153+
attr_dim_summation (bool, optional): Indicates whether to
154154
keep attributions split per channel.
155155
The default (False) means to sum per channels.
156156
Default: False
@@ -213,7 +213,7 @@ def attribute(
213213
for layer_grad in layer_gradients
214214
)
215215

216-
if split_channels:
216+
if attr_dim_summation:
217217
scaled_acts = tuple(
218218
summed_grad * layer_eval
219219
for summed_grad, layer_eval in zip(summed_grads, layer_evals)

tests/attr/layer/test_grad_cam.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_simple_input_conv_split_channels(self) -> None:
4747
net.conv1,
4848
inp,
4949
expected_activation=expected_result,
50-
split_channels=True,
50+
attr_dim_summation=True,
5151
)
5252

5353
def test_simple_input_conv_no_grad(self) -> None:
@@ -117,7 +117,7 @@ def _grad_cam_test_assert(
117117
additional_input: Any = None,
118118
attribute_to_layer_input: bool = False,
119119
relu_attributions: bool = False,
120-
split_channels: bool = False,
120+
attr_dim_summation: bool = False,
121121
):
122122
layer_gc = LayerGradCam(model, target_layer)
123123
self.assertFalse(layer_gc.multiplies_by_inputs)
@@ -127,7 +127,7 @@ def _grad_cam_test_assert(
127127
additional_forward_args=additional_input,
128128
attribute_to_layer_input=attribute_to_layer_input,
129129
relu_attributions=relu_attributions,
130-
split_channels=split_channels,
130+
attr_dim_summation=attr_dim_summation,
131131
)
132132
assertTensorTuplesAlmostEqual(
133133
self, attributions, expected_activation, delta=0.01

0 commit comments

Comments
 (0)