Skip to content

Commit 6572fb9

Browse files
committed
Add test for split_channels parameter to LayerGradCam.attribute
1 parent 2060485 commit 6572fb9

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/attr/layer/test_grad_cam.py

+19
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ def test_simple_input_conv(self) -> None:
3333
net, net.conv1, inp, [[[[11.25, 13.5], [20.25, 22.5]]]]
3434
)
3535

36+
def test_simple_input_conv_split_channels(self) -> None:
37+
net = BasicModel_ConvNet_One_Conv()
38+
inp = torch.arange(16).view(1, 1, 4, 4).float()
39+
expected_result = [
40+
[
41+
[[-3.7500, 3.0000], [23.2500, 30.0000]],
42+
[[15.0000, 10.5000], [-3.0000, -7.5000]],
43+
]
44+
]
45+
self._grad_cam_test_assert(
46+
net,
47+
net.conv1,
48+
inp,
49+
expected_activation=expected_result,
50+
split_channels=True,
51+
)
52+
3653
def test_simple_input_conv_no_grad(self) -> None:
3754
net = BasicModel_ConvNet_One_Conv()
3855

@@ -100,6 +117,7 @@ def _grad_cam_test_assert(
100117
additional_input: Any = None,
101118
attribute_to_layer_input: bool = False,
102119
relu_attributions: bool = False,
120+
split_channels: bool = False,
103121
):
104122
layer_gc = LayerGradCam(model, target_layer)
105123
self.assertFalse(layer_gc.multiplies_by_inputs)
@@ -109,6 +127,7 @@ def _grad_cam_test_assert(
109127
additional_forward_args=additional_input,
110128
attribute_to_layer_input=attribute_to_layer_input,
111129
relu_attributions=relu_attributions,
130+
split_channels=split_channels,
112131
)
113132
assertTensorTuplesAlmostEqual(
114133
self, attributions, expected_activation, delta=0.01

0 commit comments

Comments
 (0)