Skip to content

Commit 61c3898

Browse files
aobo-yfacebook-github-bot
authored andcommitted
Enable multi-task attribution for Shapley (#1173)
Summary: Pull Request resolved: #1173 Support multi-task attribution in `ShapleyValues` and `ShapleyValueSampling`. Assuming the return of `forward_fun` is in (*output_shape), the attribution result will be in (*output_shape, *input_shape[1:]). Existing use cases becomes just special cases where output_shape is (1,) or (batch_size,) Reviewed By: vivekmig Differential Revision: D48696578 fbshipit-source-id: cc0f9275b20be6416abf0d8e72739a2c3ca421b6
1 parent fea4c7f commit 61c3898

File tree

3 files changed

+107
-21
lines changed

3 files changed

+107
-21
lines changed

captum/attr/_core/shapley_value.py

+65-18
Original file line numberDiff line numberDiff line change
@@ -300,21 +300,31 @@ def attribute(
300300
)
301301
attr_progress.update(0)
302302

303-
initial_eval = _run_forward(
303+
initial_eval = self._strict_run_forward(
304304
self.forward_func, baselines, target, additional_forward_args
305305
)
306306

307307
if show_progress:
308308
attr_progress.update()
309309

310310
agg_output_mode = _find_output_mode_and_verify(
311-
initial_eval, num_examples, perturbations_per_eval, feature_mask
311+
initial_eval,
312+
num_examples,
313+
perturbations_per_eval,
314+
feature_mask,
315+
allow_multi_outputs=True,
312316
)
313317

314318
# Initialize attribution totals and counts
319+
output_shape = initial_eval.shape
320+
n_outputs = initial_eval.numel()
321+
322+
# attr shape (*output_shape, *input_feature_shape)
315323
total_attrib = [
316-
torch.zeros_like(
317-
input[0:1] if agg_output_mode else input, dtype=torch.float
324+
torch.zeros(
325+
(*output_shape, *input.shape[1:]),
326+
dtype=torch.float,
327+
device=inputs[0].device,
318328
)
319329
for input in inputs
320330
]
@@ -349,7 +359,7 @@ def attribute(
349359
)
350360
# modified_eval dimensions: 1D tensor with length
351361
# equal to #num_examples * #features in batch
352-
modified_eval = _run_forward(
362+
modified_eval = self._strict_run_forward(
353363
self.forward_func,
354364
current_inputs,
355365
current_target,
@@ -362,23 +372,35 @@ def attribute(
362372
eval_diff = modified_eval - prev_results
363373
prev_results = modified_eval
364374
else:
375+
# when perturb_per_eval > 1, every num_examples stands for
376+
# one perturb. Since the perturbs are from a consecutive
377+
# perumuation, each diff of a perturb is its eval minus
378+
# the eval of the previous perturb
365379
all_eval = torch.cat((prev_results, modified_eval), dim=0)
366380
eval_diff = all_eval[num_examples:] - all_eval[:-num_examples]
367381
prev_results = all_eval[-num_examples:]
382+
368383
for j in range(len(total_attrib)):
369-
current_eval_diff = eval_diff
370-
if not agg_output_mode:
371-
# current_eval_diff dimensions:
372-
# (#features in batch, #num_examples, 1,.. 1)
373-
# (contains 1 more dimension than inputs). This adds extra
374-
# dimensions of 1 to make the tensor broadcastable with the
375-
# inputs tensor.
376-
current_eval_diff = current_eval_diff.reshape(
377-
(-1, num_examples) + (len(inputs[j].shape) - 1) * (1,)
378-
)
379-
total_attrib[j] += (
380-
current_eval_diff * current_masks[j].float()
381-
).sum(dim=0)
384+
# format eval_diff to shape
385+
# (n_perturb, n_outputs, 1,.. 1)
386+
# where n_perturb may not be perturb_per_eval
387+
# Append n_input_feature dim of 1 to make the tensor
388+
# have the same dim as the mask tensor.
389+
formatted_eval_diff = eval_diff.reshape(
390+
(-1, n_outputs) + (len(inputs[j].shape) - 1) * (1,)
391+
)
392+
393+
# mask in shape (n_perturb, *mask_shape_broadcastable_to_input)
394+
# aggregate n_perturb
395+
cur_attr = (formatted_eval_diff * current_masks[j].float()).sum(
396+
dim=0
397+
)
398+
399+
# (n_outputs, *input_feature_shape) ->
400+
# (*output_shape, *input_feature_shape)
401+
total_attrib[j] += cur_attr.reshape(
402+
(*output_shape, *cur_attr.shape[1:])
403+
)
382404

383405
if show_progress:
384406
attr_progress.close()
@@ -476,6 +498,31 @@ def _get_n_evaluations(self, total_features, n_samples, perturbations_per_eval):
476498
"""return the total number of forward evaluations needed"""
477499
return math.ceil(total_features / perturbations_per_eval) * n_samples
478500

501+
def _strict_run_forward(self, *args, **kwargs) -> Tensor:
502+
"""
503+
A temp wrapper for global _run_forward util to force forward output
504+
type assertion & conversion.
505+
Remove after the strict logic is supported by all attr classes
506+
"""
507+
forward_output = _run_forward(*args, **kwargs)
508+
if isinstance(forward_output, Tensor):
509+
# format scalar to shape (1) so we can always assume non-empty output_shape
510+
if not forward_output.shape:
511+
forward_output = forward_output.reshape(1)
512+
513+
return forward_output
514+
515+
output_type = type(forward_output)
516+
assert output_type is int or output_type is float, (
517+
"the return of forward_func must be a tensor, int, or float,"
518+
f" received: {forward_output}"
519+
)
520+
521+
# using python built-in type as torch dtype
522+
# int -> torch.int64, float -> torch.float64
523+
# ref: https://github.com/pytorch/pytorch/pull/21215
524+
return torch.tensor([forward_output], dtype=output_type)
525+
479526

480527
class ShapleyValues(ShapleyValueSampling):
481528
"""

captum/attr/_utils/common.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def _find_output_mode_and_verify(
318318
num_examples: int,
319319
perturbations_per_eval: int,
320320
feature_mask: Union[None, TensorOrTupleOfTensorsGeneric],
321+
allow_multi_outputs: bool = False,
321322
) -> bool:
322323
"""
323324
This method identifies whether the model outputs a single output for a batch
@@ -346,9 +347,10 @@ def _find_output_mode_and_verify(
346347
)
347348
else:
348349
agg_output_mode = False
349-
assert (
350-
isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
351-
), "Target should identify a single element in the model output."
350+
if not allow_multi_outputs:
351+
assert (
352+
isinstance(initial_eval, torch.Tensor) and initial_eval[0].numel() == 1
353+
), "Target should identify a single element in the model output."
352354
return agg_output_mode
353355

354356

tests/attr/test_shapley.py

+37
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,43 @@ def test_multi_input_shapley_sampling_with_mask(self) -> None:
151151
perturbations_per_eval=(1, 2, 3),
152152
)
153153

154+
def test_shapley_sampling_multi_task_output(self) -> None:
155+
# return shape (batch size, 2)
156+
net1 = BasicModel_MultiLayer()
157+
158+
# return shape (batch size, 4)
159+
def forward_func(*args, **kwargs):
160+
net_output = net1(*args, **kwargs)
161+
batch_size = net_output.size(0)
162+
constant = torch.ones(batch_size, 2)
163+
output = torch.cat(
164+
[
165+
net_output,
166+
constant,
167+
],
168+
dim=-1,
169+
)
170+
return output
171+
172+
inp = torch.tensor([[20.0, 50.0, 30.0]], requires_grad=True)
173+
174+
self._shapley_test_assert(
175+
forward_func,
176+
inp,
177+
[
178+
[
179+
[76.66666, 196.66666, 116.66666],
180+
[76.66666, 196.66666, 116.66666],
181+
[0, 0, 0],
182+
[0, 0, 0],
183+
]
184+
],
185+
target=None, # no target, multi-task output for all classes
186+
perturbations_per_eval=(1, 2, 3),
187+
n_samples=150,
188+
test_true_shapley=True,
189+
)
190+
154191
# Remaining tests are for cases where forward function returns a scalar
155192
# per batch, as either a float, integer, 0d tensor or 1d tensor.
156193
def test_single_shapley_batch_scalar_float(self) -> None:

0 commit comments

Comments
 (0)