from torch import Tensor
from torch.nn import Module

-PredictionScore = namedtuple("PredictionScore", "score label")
+OutputScore = namedtuple("OutputScore", "score index label")
VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted"
+    "VisualizationOutput", "feature_outputs actual predicted active_index"
)
Contribution = namedtuple("Contribution", "name percent")
+SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")


class FilterConfig(NamedTuple):
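
For context, a minimal sketch of how the renamed and added records fit together. Only the namedtuple field lists come from the diff above; the tensors, class names, and values are made up for illustration.

from collections import namedtuple

import torch

OutputScore = namedtuple("OutputScore", "score index label")
SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")

# A top-1 prediction for a hypothetical 3-class model.
top1 = OutputScore(score=torch.tensor(0.92), index=torch.tensor(2), label="cat")

# Everything needed to recompute attributions for this sample later.
cache = SampleCache(
    inputs=(torch.rand(1, 3, 32, 32),),  # hypothetical image input
    additional_forward_args=None,
    label=torch.tensor([2]),
)

print(top1.label, cache.label)  # cat tensor([2])
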
@@ -44,6 +45,7 @@ def __init__(
        features: Union[List[BaseFeature], BaseFeature],
        dataset: Iterable[Data],
        score_func: Optional[Callable] = None,
+        use_label_for_attr: bool = True,
    ):
        if not isinstance(models, List):
            models = [models]
@@ -56,20 +58,34 @@ def __init__(
        self.features = features
        self.dataset = dataset
        self.score_func = score_func
+        self._outputs = []
        self._config = FilterConfig(steps=25, prediction="all", classes=[], count=4)
+        self._use_label_for_attr = use_label_for_attr
+
+    def _calculate_attribution_from_cache(
+        self, index: int, target: Optional[Tensor]
+    ) -> VisualizationOutput:
+        c = self._outputs[index][1]
+        return self._calculate_vis_output(
+            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
+        )

    def _calculate_attribution(
        self,
        net: Module,
        baselines: Optional[List[Tuple[Tensor, ...]]],
        data: Tuple[Tensor, ...],
        additional_forward_args: Optional[Tuple[Tensor, ...]],
-        label: Optional[Tensor],
+        label: Optional[Union[Tensor]],
    ) -> Tensor:
        ig = IntegratedGradients(net)
        # TODO support multiple baselines
        baseline = baselines[0] if len(baselines) > 0 else None
-        label = None if label is None or label.nelement() == 0 else label
+        label = (
+            None
+            if not self._use_label_for_attr or label is None or label.nelement() == 0
+            else label
+        )
        attr_ig = ig.attribute(
            data,
            baselines=baseline,
@@ -98,11 +114,11 @@ def render(self, blocking=False, debug=False):

    def _get_labels_from_scores(
        self, scores: Tensor, indices: Tensor
-    ) -> List[PredictionScore]:
+    ) -> List[OutputScore]:
        pred_scores = []
        for i in range(len(indices)):
-            score = scores[i].item()
-            pred_scores.append(PredictionScore(score, self.classes[indices[i]]))
+            score = scores[i]
+            pred_scores.append(OutputScore(score, indices[i], self.classes[indices[i]]))
        return pred_scores

    def _transform(
@@ -123,7 +139,7 @@ def _transform(
            transformed_inputs = transforms(transformed_inputs)

        if batch:
-            transformed_inputs.unsqueeze_(0)
+            transformed_inputs = transformed_inputs.unsqueeze(0)

        return transformed_inputs

@@ -141,22 +157,20 @@ def _calculate_net_contrib(self, attrs_per_input_feature: List[Tensor]):
        return net_contrib.tolist()

    def _predictions_matches_labels(
-        self,
-        predicted_scores: List[PredictionScore],
-        actual_labels: Union[str, List[str]],
+        self, predicted_scores: List[OutputScore], labels: Union[str, List[str]]
    ) -> bool:
        if len(predicted_scores) == 0:
            return False

        predicted_label = predicted_scores[0].label

-        if isinstance(actual_labels, List):
-            return predicted_label in actual_labels
+        if isinstance(labels, List):
+            return predicted_label in labels

-        return actual_labels == predicted_label
+        return labels == predicted_label

    def _should_keep_prediction(
-        self, predicted_scores: List[PredictionScore], actual_label: str
+        self, predicted_scores: List[OutputScore], actual_label: str
    ) -> bool:
        # filter by class
        if len(self._config.classes) != 0:
@@ -179,104 +193,117 @@ def _should_keep_prediction(

        return True

-    def _get_outputs(self) -> List[VisualizationOutput]:
-        batch_data = next(self.dataset)
+    def _calculate_vis_output(
+        self, inputs, additional_forward_args, label, target=None
+    ) -> Optional[VisualizationOutput]:
        net = self.models[0]  # TODO process multiple models
-        vis_outputs = []

-        for inputs, additional_forward_args, label in _batched_generator(
-            inputs=batch_data.inputs,
-            additional_forward_args=batch_data.additional_args,
-            target_ind=batch_data.labels,
-            internal_batch_size=1,  # should be 1 until we have batch label support
-        ):
-            # initialize baselines
-            baseline_transforms_len = len(self.features[0].baseline_transforms or [])
-            baselines = [
-                [None] * len(self.features) for _ in range(baseline_transforms_len)
-            ]
-            transformed_inputs = list(inputs)
-
-            for feature_i, feature in enumerate(self.features):
-                if feature.input_transforms is not None:
-                    transformed_inputs[feature_i] = self._transform(
-                        feature.input_transforms, transformed_inputs[feature_i], True
+        # initialize baselines
+        baseline_transforms_len = len(self.features[0].baseline_transforms or [])
+        baselines = [
+            [None] * len(self.features) for _ in range(baseline_transforms_len)
+        ]
+        transformed_inputs = list(inputs)
+
+        # transformed_inputs = list([i.clone() for i in inputs])
+        for feature_i, feature in enumerate(self.features):
+            if feature.input_transforms is not None:
+                transformed_inputs[feature_i] = self._transform(
+                    feature.input_transforms, transformed_inputs[feature_i], True
+                )
+            if feature.baseline_transforms is not None:
+                assert baseline_transforms_len == len(
+                    feature.baseline_transforms
+                ), "Must have same number of baselines across all features"
+
+                for baseline_i, baseline_transform in enumerate(
+                    feature.baseline_transforms
+                ):
+                    baselines[baseline_i][feature_i] = self._transform(
+                        baseline_transform, transformed_inputs[feature_i], True
                    )
-                if feature.baseline_transforms is not None:
-                    assert baseline_transforms_len == len(
-                        feature.baseline_transforms
-                    ), "Must have same number of baselines across all features"
-
-                    for baseline_i, baseline_transform in enumerate(
-                        feature.baseline_transforms
-                    ):
-                        baselines[baseline_i][feature_i] = self._transform(
-                            baseline_transform, transformed_inputs[feature_i], True
-                        )
-
-            outputs = _run_forward(
-                net, tuple(transformed_inputs), additional_forward_args
-            )

-            if self.score_func is not None:
-                outputs = self.score_func(outputs)
-
-            if outputs.nelement() == 1:
-                scores = outputs
-                predicted = scores.round().to(torch.int)
-            else:
-                scores, predicted = outputs.topk(min(4, outputs.shape[-1]))
-
-            scores = scores.cpu().squeeze(0)
-            predicted = predicted.cpu().squeeze(0)
-
-            actual_label = self.classes[label[0]] if label is not None else None
-            predicted_scores = self._get_labels_from_scores(scores, predicted)
-
-            # Filter based on UI configuration
-            if not self._should_keep_prediction(predicted_scores, actual_label):
-                continue
-
-            baselines = [tuple(b) for b in baselines]
-
-            # attributions are given per input*
-            # inputs given to the model are described via `self.features`
-            #
-            # *an input contains multiple features that represent it
-            # e.g. all the pixels that describe an image is an input
-            attrs_per_input_feature = self._calculate_attribution(
-                net,
-                baselines,
-                tuple(transformed_inputs),
-                additional_forward_args,
-                label,
+        outputs = _run_forward(net, tuple(transformed_inputs), additional_forward_args)
+
+        if self.score_func is not None:
+            outputs = self.score_func(outputs)
+
+        if outputs.nelement() == 1:
+            scores = outputs
+            predicted = scores.round().to(torch.int)
+        else:
+            scores, predicted = outputs.topk(min(4, outputs.shape[-1]))
+
+        scores = scores.cpu().squeeze(0)
+        predicted = predicted.cpu().squeeze(0)
+
+        if label is not None and len(label) > 0:
+            actual_label = OutputScore(
+                score=0, index=label[0], label=self.classes[label[0]]
            )
+        else:
+            actual_label = None

-            net_contrib = self._calculate_net_contrib(attrs_per_input_feature)
+        predicted_scores = self._get_labels_from_scores(scores, predicted)

-            # the features per input given
-            features_per_input = [
-                feature.visualize(attr, data, contrib)
-                for feature, attr, data, contrib in zip(
-                    self.features, attrs_per_input_feature, inputs, net_contrib
-                )
-            ]
+        # Filter based on UI configuration
+        if not self._should_keep_prediction(predicted_scores, actual_label):
+            return None
+
+        baselines = [tuple(b) for b in baselines]
+
+        if target is None:
+            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
+
+        # attributions are given per input*
+        # inputs given to the model are described via `self.features`
+        #
+        # *an input contains multiple features that represent it
+        # e.g. all the pixels that describe an image is an input
+
+        attrs_per_input_feature = self._calculate_attribution(
+            net, baselines, tuple(transformed_inputs), additional_forward_args, target
+        )
+
+        net_contrib = self._calculate_net_contrib(attrs_per_input_feature)

-            output = VisualizationOutput(
-                feature_outputs=features_per_input,
-                actual=actual_label,
-                predicted=predicted_scores,
+        # the features per input given
+        features_per_input = [
+            feature.visualize(attr, data, contrib)
+            for feature, attr, data, contrib in zip(
+                self.features, attrs_per_input_feature, inputs, net_contrib
            )
+        ]

-            vis_outputs.append(output)
+        return VisualizationOutput(
+            feature_outputs=features_per_input,
+            actual=actual_label,
+            predicted=predicted_scores,
+            active_index=target if target is not None else actual_label.index,
+        )
+
+    def _get_outputs(self) -> List[VisualizationOutput]:
+        batch_data = next(self.dataset)
+        vis_outputs = []
+
+        for inputs, additional_forward_args, label in _batched_generator(
+            inputs=batch_data.inputs,
+            additional_forward_args=batch_data.additional_args,
+            target_ind=batch_data.labels,
+            internal_batch_size=1,  # should be 1 until we have batch label support
+        ):
+            output = self._calculate_vis_output(inputs, additional_forward_args, label)
+            if output is not None:
+                cache = SampleCache(inputs, additional_forward_args, label)
+                vis_outputs.append((output, cache))

        return vis_outputs

    def visualize(self):
-        output_list = []
-        while len(output_list) < self._config.count:
+        self._outputs = []
+        while len(self._outputs) < self._config.count:
            try:
-                output_list.extend(self._get_outputs())
+                self._outputs.extend(self._get_outputs())
            except StopIteration:
                break
-        return output_list
+        return [o[0] for o in self._outputs]
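
The caching pattern introduced above (pairs of (VisualizationOutput, SampleCache) kept on self._outputs, with _calculate_attribution_from_cache recomputing a single cached sample against an explicit target) can be illustrated with a stripped-down stand-in. The ToyVisualizer class, its toy model, and the class names below are hypothetical and only mirror the structure of the diff; this is a sketch, not Captum's actual API.

from collections import namedtuple
from typing import List, Tuple

import torch
from captum.attr import IntegratedGradients

OutputScore = namedtuple("OutputScore", "score index label")
SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")


class ToyVisualizer:
    """Hypothetical, simplified analogue of the caching flow in the diff."""

    def __init__(self, model: torch.nn.Module, classes: List[str]):
        self.model = model
        self.classes = classes
        self._outputs: List[Tuple[OutputScore, SampleCache]] = []

    def _attribute(self, inputs: torch.Tensor, target: int) -> torch.Tensor:
        # Attribute the chosen output index back to the inputs.
        return IntegratedGradients(self.model).attribute(inputs, target=target)

    def visualize(self, batch: torch.Tensor, labels: torch.Tensor) -> List[OutputScore]:
        self._outputs = []
        for x, y in zip(batch, labels):
            x = x.unsqueeze(0)
            scores = self.model(x).softmax(-1)
            idx = int(scores.argmax())
            out = OutputScore(float(scores[0, idx]), idx, self.classes[idx])
            # Cache the raw sample so attributions can be recomputed later.
            self._outputs.append((out, SampleCache(x, None, y)))
        return [o for o, _ in self._outputs]

    def recompute_from_cache(self, index: int, target: int) -> torch.Tensor:
        # Analogue of _calculate_attribution_from_cache: reuse the cached
        # inputs, but attribute against a caller-chosen target class.
        cache = self._outputs[index][1]
        return self._attribute(cache.inputs, target)


model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(8, 3))
viz = ToyVisualizer(model, classes=["a", "b", "c"])
preds = viz.visualize(torch.rand(4, 8), torch.tensor([0, 1, 2, 0]))
attr_for_class_2 = viz.recompute_from_cache(index=0, target=2)
print(preds[0], attr_for_class_2.shape)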