Skip to content

Commit 2de6a4a

Browse files
committed
Add test for split_channels parameter to LayerGradCam.attribute
1 parent 2060485 commit 2de6a4a

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

tests/attr/layer/test_grad_cam.py

+13
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ def test_simple_input_conv(self) -> None:
3333
net, net.conv1, inp, [[[[11.25, 13.5], [20.25, 22.5]]]]
3434
)
3535

36+
def test_simple_input_conv_split_channels(self) -> None:
37+
net = BasicModel_ConvNet_One_Conv()
38+
inp = torch.arange(16).view(1, 1, 4, 4).float()
39+
expected_result = [[[[-3.7500, 3.0000],
40+
[23.2500, 30.0000]],
41+
[[15.0000, 10.5000],
42+
[-3.0000, -7.5000]]]]
43+
self._grad_cam_test_assert(
44+
net, net.conv1, inp, expected_activation=expected_result, split_channels=True
45+
)
46+
3647
def test_simple_input_conv_no_grad(self) -> None:
3748
net = BasicModel_ConvNet_One_Conv()
3849

@@ -100,6 +111,7 @@ def _grad_cam_test_assert(
100111
additional_input: Any = None,
101112
attribute_to_layer_input: bool = False,
102113
relu_attributions: bool = False,
114+
split_channels: bool = False,
103115
):
104116
layer_gc = LayerGradCam(model, target_layer)
105117
self.assertFalse(layer_gc.multiplies_by_inputs)
@@ -109,6 +121,7 @@ def _grad_cam_test_assert(
109121
additional_forward_args=additional_input,
110122
attribute_to_layer_input=attribute_to_layer_input,
111123
relu_attributions=relu_attributions,
124+
split_channels=split_channels,
112125
)
113126
assertTensorTuplesAlmostEqual(
114127
self, attributions, expected_activation, delta=0.01

0 commit comments

Comments
 (0)