@@ -300,21 +300,31 @@ def attribute(
300
300
)
301
301
attr_progress .update (0 )
302
302
303
- initial_eval = _run_forward (
303
+ initial_eval = self . _strict_run_forward (
304
304
self .forward_func , baselines , target , additional_forward_args
305
305
)
306
306
307
307
if show_progress :
308
308
attr_progress .update ()
309
309
310
310
agg_output_mode = _find_output_mode_and_verify (
311
- initial_eval , num_examples , perturbations_per_eval , feature_mask
311
+ initial_eval ,
312
+ num_examples ,
313
+ perturbations_per_eval ,
314
+ feature_mask ,
315
+ allow_multi_outputs = True ,
312
316
)
313
317
314
318
# Initialize attribution totals and counts
319
+ output_shape = initial_eval .shape
320
+ n_outputs = initial_eval .numel ()
321
+
322
+ # attr shape (*output_shape, *input_feature_shape)
315
323
total_attrib = [
316
- torch .zeros_like (
317
- input [0 :1 ] if agg_output_mode else input , dtype = torch .float
324
+ torch .zeros (
325
+ (* output_shape , * input .shape [1 :]),
326
+ dtype = torch .float ,
327
+ device = inputs [0 ].device ,
318
328
)
319
329
for input in inputs
320
330
]
@@ -349,7 +359,7 @@ def attribute(
349
359
)
350
360
# modified_eval dimensions: 1D tensor with length
351
361
# equal to #num_examples * #features in batch
352
- modified_eval = _run_forward (
362
+ modified_eval = self . _strict_run_forward (
353
363
self .forward_func ,
354
364
current_inputs ,
355
365
current_target ,
@@ -362,23 +372,35 @@ def attribute(
362
372
eval_diff = modified_eval - prev_results
363
373
prev_results = modified_eval
364
374
else :
375
+ # when perturb_per_eval > 1, every num_examples stands for
376
+ # one perturb. Since the perturbs are from a consecutive
377
+ # perumuation, each diff of a perturb is its eval minus
378
+ # the eval of the previous perturb
365
379
all_eval = torch .cat ((prev_results , modified_eval ), dim = 0 )
366
380
eval_diff = all_eval [num_examples :] - all_eval [:- num_examples ]
367
381
prev_results = all_eval [- num_examples :]
382
+
368
383
for j in range (len (total_attrib )):
369
- current_eval_diff = eval_diff
370
- if not agg_output_mode :
371
- # current_eval_diff dimensions:
372
- # (#features in batch, #num_examples, 1,.. 1)
373
- # (contains 1 more dimension than inputs). This adds extra
374
- # dimensions of 1 to make the tensor broadcastable with the
375
- # inputs tensor.
376
- current_eval_diff = current_eval_diff .reshape (
377
- (- 1 , num_examples ) + (len (inputs [j ].shape ) - 1 ) * (1 ,)
378
- )
379
- total_attrib [j ] += (
380
- current_eval_diff * current_masks [j ].float ()
381
- ).sum (dim = 0 )
384
+ # format eval_diff to shape
385
+ # (n_perturb, n_outputs, 1,.. 1)
386
+ # where n_perturb may not be perturb_per_eval
387
+ # Append n_input_feature dim of 1 to make the tensor
388
+ # have the same dim as the mask tensor.
389
+ formatted_eval_diff = eval_diff .reshape (
390
+ (- 1 , n_outputs ) + (len (inputs [j ].shape ) - 1 ) * (1 ,)
391
+ )
392
+
393
+ # mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
394
+ # aggregate n_perturb
395
+ cur_attr = (formatted_eval_diff * current_masks [j ].float ()).sum (
396
+ dim = 0
397
+ )
398
+
399
+ # (n_outputs, *input_feature_shape) ->
400
+ # (*output_shape, *input_feature_shape)
401
+ total_attrib [j ] += cur_attr .reshape (
402
+ (* output_shape , * cur_attr .shape [1 :])
403
+ )
382
404
383
405
if show_progress :
384
406
attr_progress .close ()
@@ -476,6 +498,31 @@ def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
476
498
"""return the total number of forward evaluations needed"""
477
499
return math .ceil (total_features / perturbations_per_eval ) * n_samples
478
500
501
def _strict_run_forward(self, *args, **kwargs) -> Tensor:
    """
    A temp wrapper for global _run_forward util to force forward output
    type assertion & conversion.
    Remove after the strict logic is supported by all attr classes
    """
    out = _run_forward(*args, **kwargs)

    if isinstance(out, Tensor):
        # A 0-dim (scalar) tensor is reshaped to (1,) so callers can
        # always assume a non-empty output_shape.
        return out.reshape(1) if not out.shape else out

    output_type = type(out)
    assert output_type is int or output_type is float, (
        "the return of forward_func must be a tensor, int, or float,"
        f" received: {out}"
    )

    # Python builtin types are usable as torch dtypes:
    # int -> torch.int64, float -> torch.float64
    # ref: https://github.com/pytorch/pytorch/pull/21215
    return torch.tensor([out], dtype=output_type)
479
526
480
527
class ShapleyValues (ShapleyValueSampling ):
481
528
"""
0 commit comments