@@ -33,6 +33,17 @@ def test_simple_input_conv(self) -> None:
33
33
net , net .conv1 , inp , [[[[11.25 , 13.5 ], [20.25 , 22.5 ]]]]
34
34
)
35
35
36
+ def test_simple_input_conv_split_channels (self ) -> None :
37
+ net = BasicModel_ConvNet_One_Conv ()
38
+ inp = torch .arange (16 ).view (1 , 1 , 4 , 4 ).float ()
39
+ expected_result = [[[[- 3.7500 , 3.0000 ],
40
+ [23.2500 , 30.0000 ]],
41
+ [[15.0000 , 10.5000 ],
42
+ [- 3.0000 , - 7.5000 ]]]]
43
+ self ._grad_cam_test_assert (
44
+ net , net .conv1 , inp , expected_activation = expected_result , split_channels = True
45
+ )
46
+
36
47
def test_simple_input_conv_no_grad (self ) -> None :
37
48
net = BasicModel_ConvNet_One_Conv ()
38
49
@@ -100,6 +111,7 @@ def _grad_cam_test_assert(
100
111
additional_input : Any = None ,
101
112
attribute_to_layer_input : bool = False ,
102
113
relu_attributions : bool = False ,
114
+ split_channels : bool = False ,
103
115
):
104
116
layer_gc = LayerGradCam (model , target_layer )
105
117
self .assertFalse (layer_gc .multiplies_by_inputs )
@@ -109,6 +121,7 @@ def _grad_cam_test_assert(
109
121
additional_forward_args = additional_input ,
110
122
attribute_to_layer_input = attribute_to_layer_input ,
111
123
relu_attributions = relu_attributions ,
124
+ split_channels = split_channels ,
112
125
)
113
126
assertTensorTuplesAlmostEqual (
114
127
self , attributions , expected_activation , delta = 0.01
0 commit comments