@@ -1,14 +1,15 @@
 from __future__ import print_function
-from typing import Callable, List, Optional, Union
 
-import torch
-import torch.nn as nn
+import unittest
+from typing import Callable, List, Optional, Union
 
 from captum.insights.api import AttributionVisualizer, Data
-from captum.insights.features import ImageFeature, BaseFeature, FeatureOutput
-
+from captum.insights.features import BaseFeature, FeatureOutput, ImageFeature
 from tests.attr.helpers.utils import BaseTest
 
+import torch
+import torch.nn as nn
+
 
 class RealFeature(BaseFeature):
     def __init__(
@@ -97,15 +98,15 @@ def forward(self, img, misc):
 
 
 def _labelled_img_data(num_samples=10, width=8, height=8, depth=3, num_labels=10):
-    for i in range(num_samples):
+    for _ in range(num_samples):
         yield torch.empty(depth, height, width).uniform_(0, 1), torch.randint(
             num_labels, (1,)
         )
 
 
 def _multi_modal_data(img_dataset, feature_size=256):
     def misc_data(length, feature_size=None):
-        for i in range(length):
+        for _ in range(length):
             yield torch.randn(feature_size)
 
     misc_dataset = misc_data(length=len(img_dataset), feature_size=feature_size)
@@ -165,11 +166,8 @@ def test_one_feature(self):
         outputs = visualizer.visualize()
 
         for output in outputs:
-            contribs = torch.stack(
-                [feature.contribution for feature in output.feature_outputs]
-            )
-            total_contrib = torch.sum(torch.abs(contribs))
-            self.assertAlmostEqual(total_contrib.item(), 1.0, places=6)
+            total_contrib = sum(abs(f.contribution) for f in output.feature_outputs)
+            self.assertAlmostEqual(total_contrib, 1.0, places=6)
 
     def test_multi_features(self):
         batch_size = 2
@@ -183,7 +181,7 @@ def test_multi_features(self):
             img_dataset=img_dataset, feature_size=misc_feature_size
         )
         # NOTE: using DataLoader to batch the inputs since
-        # AttributionVisualizer requires the input to be of size `B x ...`
+        # AttributionVisualizer requires the input to be of size `N x ...`
         data_loader = torch.utils.data.DataLoader(
             list(dataset), batch_size=batch_size, shuffle=False, num_workers=0
         )
@@ -211,13 +209,14 @@ def test_multi_features(self):
         outputs = visualizer.visualize()
 
         for output in outputs:
-            contribs = torch.stack(
-                [feature.contribution for feature in output.feature_outputs]
-            )
-            total_contrib = torch.sum(torch.abs(contribs))
-            self.assertAlmostEqual(total_contrib.item(), 1.0, places=6)
+            total_contrib = sum(abs(f.contribution) for f in output.feature_outputs)
+            self.assertAlmostEqual(total_contrib, 1.0, places=6)
 
 # TODO: add test for multiple models (related to TODO in captum/insights/api.py)
 #
 # TODO: add test to make the attribs == 0 -- error occurs
 # I know (through manual testing) that this breaks some existing code
+
+
+if __name__ == "__main__":
+    unittest.main()
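
For reference, the rewritten assertion only checks that the absolute per-feature contributions are normalized to 1.0, without the `torch.stack`/`.item()` round-trip the old version needed. Below is a minimal, self-contained sketch of that pattern, assuming each contribution is a plain Python float; `FakeOutput` is a hypothetical stand-in used only for illustration, not captum's `FeatureOutput`:

```python
import unittest
from collections import namedtuple

# Hypothetical stand-in for captum's FeatureOutput; only the field the
# assertion reads is modeled here.
FakeOutput = namedtuple("FakeOutput", ["contribution"])


class NormalizationCheck(unittest.TestCase):
    def test_contributions_sum_to_one(self):
        feature_outputs = [FakeOutput(0.7), FakeOutput(-0.2), FakeOutput(0.1)]
        # Same pattern as the updated test: the absolute contributions of all
        # features should sum to (approximately) 1.0 when normalized.
        total_contrib = sum(abs(f.contribution) for f in feature_outputs)
        self.assertAlmostEqual(total_contrib, 1.0, places=6)


if __name__ == "__main__":
    unittest.main()
```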