@@ -33,6 +33,23 @@ def test_simple_input_conv(self) -> None:
33
33
net , net .conv1 , inp , [[[[11.25 , 13.5 ], [20.25 , 22.5 ]]]]
34
34
)
35
35
36
+ def test_simple_input_conv_split_channels (self ) -> None :
37
+ net = BasicModel_ConvNet_One_Conv ()
38
+ inp = torch .arange (16 ).view (1 , 1 , 4 , 4 ).float ()
39
+ expected_result = [
40
+ [
41
+ [[- 3.7500 , 3.0000 ], [23.2500 , 30.0000 ]],
42
+ [[15.0000 , 10.5000 ], [- 3.0000 , - 7.5000 ]],
43
+ ]
44
+ ]
45
+ self ._grad_cam_test_assert (
46
+ net ,
47
+ net .conv1 ,
48
+ inp ,
49
+ expected_activation = expected_result ,
50
+ split_channels = True ,
51
+ )
52
+
36
53
def test_simple_input_conv_no_grad (self ) -> None :
37
54
net = BasicModel_ConvNet_One_Conv ()
38
55
@@ -100,6 +117,7 @@ def _grad_cam_test_assert(
100
117
additional_input : Any = None ,
101
118
attribute_to_layer_input : bool = False ,
102
119
relu_attributions : bool = False ,
120
+ split_channels : bool = False ,
103
121
):
104
122
layer_gc = LayerGradCam (model , target_layer )
105
123
self .assertFalse (layer_gc .multiplies_by_inputs )
@@ -109,6 +127,7 @@ def _grad_cam_test_assert(
109
127
additional_forward_args = additional_input ,
110
128
attribute_to_layer_input = attribute_to_layer_input ,
111
129
relu_attributions = relu_attributions ,
130
+ split_channels = split_channels ,
112
131
)
113
132
assertTensorTuplesAlmostEqual (
114
133
self , attributions , expected_activation , delta = 0.01
0 commit comments