Skip to content

Commit 41b3195

Browse files
committed
Add split_channels parameter to LayerGradCam.attribute
This allows examination of each channel's contribution. It is useful when the channels are something other than standard RGB — for example, multi-spectral input, potentially with many spectral channels.
1 parent c076410 commit 41b3195

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

captum/attr/_core/layer/grad_cam.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def attribute(
8282
additional_forward_args: Any = None,
8383
attribute_to_layer_input: bool = False,
8484
relu_attributions: bool = False,
85+
split_channels: bool = False,
8586
) -> Union[Tensor, Tuple[Tensor, ...]]:
8687
r"""
8788
Args:
@@ -149,6 +150,10 @@ def attribute(
149150
otherwise, by default, both positive and negative
150151
attributions are returned.
151152
Default: False
153+
split_channels (bool, optional): Indicates whether to
154+
keep attributions split per channel.
155+
The default (False) means attributions are summed across channels.
156+
Default: False
152157
153158
Returns:
154159
*Tensor* or *tuple[Tensor, ...]* of **attributions**:
@@ -208,10 +213,16 @@ def attribute(
208213
for layer_grad in layer_gradients
209214
)
210215

211-
scaled_acts = tuple(
212-
torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
213-
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
214-
)
216+
if split_channels:
217+
scaled_acts = tuple(
218+
summed_grad * layer_eval
219+
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
220+
)
221+
else:
222+
scaled_acts = tuple(
223+
torch.sum(summed_grad * layer_eval, dim=1, keepdim=True)
224+
for summed_grad, layer_eval in zip(summed_grads, layer_evals)
225+
)
215226
if relu_attributions:
216227
scaled_acts = tuple(F.relu(scaled_act) for scaled_act in scaled_acts)
217228
return _format_output(len(scaled_acts) > 1, scaled_acts)

0 commit comments

Comments
 (0)