
Commit c1bd8c6

Author: Carlos Araya
Committed: Oct 9, 2019

    Merge remote-tracking branch 'upstream/master'

2 parents: adcbb0e + ea20256
24 files changed: +1288 -422 lines
 

README.md  (+26 -27)

@@ -120,16 +120,15 @@ class ToyModel(nn.Module):
         self.lin1 = nn.Linear(3, 3)
         self.relu = nn.ReLU()
         self.lin2 = nn.Linear(3, 2)
-        self.sigmoid = nn.Sigmoid()
 
         # initialize weights and biases
-        self.lin1.weight = nn.Parameter(torch.arange(0.0, 9.0).view(3, 3))
+        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
         self.lin1.bias = nn.Parameter(torch.zeros(1,3))
-        self.lin2.weight = nn.Parameter(torch.arange(0.0, 6.0).view(2, 3))
+        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
         self.lin2.bias = nn.Parameter(torch.ones(1,2))
 
     def forward(self, input):
-        return self.sigmoid(self.lin2(self.relu(self.lin1(input))))
+        return self.lin2(self.relu(self.lin1(input)))
 ```
 
 Let's create an instance of our model and set it to eval mode.
@@ -176,9 +175,9 @@ print('IG Attributions: ', attributions, ' Convergence Delta: ', delta)
 ```
 Output:
 ```
-IG Attributions: tensor([[0.0628, 0.1314, 0.0747],
-                         [0.0930, 0.0120, 0.1639]])
-Convergence Delta: tensor([0., 0.])
+IG Attributions: tensor([[-0.5922, -1.5497, -1.0067],
+                         [ 0.0000, -0.2219, -5.1991]])
+Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
 ```
 The algorithm outputs an attribution score for each input element and a
 convergence delta. The lower the absolute value of the convergence delta the better
@@ -217,9 +216,9 @@ print('GradientShap Attributions: ', attributions, ' Convergence Delta: ', delta
 ```
 Output
 ```
-GradientShap Attributions: tensor([[ 0.0008, 0.0019, 0.0009],
-                                   [ 0.1892, -0.0045, 0.2445]])
-Convergence Delta: tensor([-0.2681, -0.2633, -0.2607, -0.2655, -0.2689, -0.2689, 1.4493, -0.2688])
+GradientShap Attributions: tensor([[-0.1542, -1.6229, -1.5835],
+                                   [-0.3916, -0.2836, -4.6851]])
+Convergence Delta: tensor([ 0.0000, -0.0005, -0.0029, -0.0084, -0.0087, -0.0405, 0.0000, -0.0084])
 
 ```
 Deltas are computed for each `n_samples * input.shape[0]` example. The user can,
@@ -243,8 +242,8 @@ print('DeepLift Attributions: ', attributions, ' Convergence Delta: ', delta)
 ```
 Output
 ```
-DeepLift Attributions: tensor([[0.0628, 0.1314, 0.0747],
-                               [0.0930, 0.0120, 0.1639]])
+DeepLift Attributions: tensor([[-0.5922, -1.5497, -1.0067],
+                               [ 0.0000, -0.2219, -5.1991]])
 Convergence Delta: tensor([0., 0.])
 ```
 DeepLift assigns similar attribution scores as Integrated Gradients to inputs,
@@ -269,12 +268,12 @@ print('DeepLiftSHAP Attributions: ', attributions, ' Convergence Delta: ', delta
 ```
 Output
 ```
-DeepLiftShap Attributions: tensor([[0.0627, 0.1313, 0.0747],
-                                   [0.0929, 0.0120, 0.1637]], grad_fn=<MeanBackward1>)
-Convergence Delta: tensor([-2.9802e-08, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
-                            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.9802e-08,
-                            0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
-                            0.0000e+00, 0.0000e+00, 2.9802e-08, 0.0000e+00, 2.9802e-08])
+DeepLiftShap Attributions: tensor([[-5.9169e-01, -1.5491e+00, -1.0076e+00],
+                                   [-4.7101e-03, -2.2300e-01, -5.1926e+00]], grad_fn=<MeanBackward1>)
+Convergence Delta: tensor([-4.6120e-03, -1.6267e-03, -5.1045e-04, -1.4184e-03, -6.8886e-03,
+                           -2.2224e-02, 0.0000e+00, -2.8790e-02, -4.1285e-03, -2.7295e-02,
+                           -3.2349e-03, -1.6265e-03, -4.7684e-07, -1.4191e-03, -6.8889e-03,
+                           -2.2224e-02, 0.0000e+00, -2.4792e-02, -4.1289e-03, -2.7296e-02])
 ```
 `DeepLiftShap` uses `DeepLift` to compute attribution score for each
 input-baseline pair and averages it for each input across all baselines.
@@ -303,10 +302,10 @@ print('IG + SmoothGrad Attributions: ', attributions, ' Convergence Delta: ', de
 ```
 Output
 ```
-IG + SmoothGrad Attributions: tensor([[0.0631, 0.1335, 0.0723],
-                                      [0.0911, 0.0142, 0.1636]])
-Convergence Delta: tensor([ 1.4901e-07, -8.9407e-08, 1.1921e-07,
-                            1.4901e-07, 1.1921e-07, -1.7881e-07, -5.9605e-08, 5.9605e-08])
+IG + SmoothGrad Attributions: tensor([[-0.4574, -1.5493, -1.0893],
+                                      [ 0.0000, -0.2647, -5.1619]])
+Convergence Delta: tensor([ 0.0000e+00, 2.3842e-07, 0.0000e+00, -2.3842e-07, 0.0000e+00,
+                           -4.7684e-07, 0.0000e+00, -4.7684e-07])
 
 ```
 The number of elements in the `delta` tensor is equal to: `n_samples * input.shape[0]`
@@ -334,8 +333,8 @@ print('Neuron Attributions: ', attributions)
 ```
 Output
 ```
-Neuron Attributions: tensor([[0.0106, 0.0247, 0.0150],
-                             [0.0144, 0.0021, 0.0301]])
+Neuron Attributions: tensor([[ 0.0000, 0.0000, 0.0000],
+                             [ 1.3358, 0.0000, -1.6811]])
 ```
 
 Layer conductance shows the importance of neurons for a layer and given input.
@@ -351,9 +350,9 @@ print('Layer Attributions: ', attributions, ' Convergence Delta: ', delta)
 ```
 Outputs
 ```
-Layer Attributions: tensor([[0.0000, 0.0515, 0.1811],
-                            [0.0000, 0.0477, 0.1652]], grad_fn=<SumBackward1>)
-Convergence Delta: tensor([-0.0363, -0.0560])
+Layer Attributions: tensor([[ 0.0000, 0.0000, -3.0856],
+                            [ 0.0000, -0.3488, -4.9638]], grad_fn=<SumBackward1>)
+Convergence Delta: tensor([0.0630, 0.1084])
 ```
 
 Similar to other attribution algorithms that return convergence delta, LayerConductance
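For readers reproducing the updated README numbers, here is a minimal end-to-end sketch stitching the snippets above together. The model follows the new README exactly; the random seed and the `target=0` choice are illustrative assumptions, so printed values may differ from the diff.

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(3, 3)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(3, 2)

        # initialize weights and biases; the new values straddle zero,
        # so attributions (and ReLU gating) can be negative
        self.lin1.weight = nn.Parameter(torch.arange(-4.0, 5.0).view(3, 3))
        self.lin1.bias = nn.Parameter(torch.zeros(1, 3))
        self.lin2.weight = nn.Parameter(torch.arange(-3.0, 3.0).view(2, 3))
        self.lin2.bias = nn.Parameter(torch.ones(1, 2))

    def forward(self, input):
        return self.lin2(self.relu(self.lin1(input)))

model = ToyModel()
model.eval()

torch.manual_seed(123)   # assumed seed, for repeatability only
input = torch.rand(2, 3)

ig = IntegratedGradients(model)
attributions, delta = ig.attribute(input, target=0, return_convergence_delta=True)
print('IG Attributions:', attributions, 'Convergence Delta:', delta)
```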

captum/attr/_core/deep_lift.py  (+48 -8)

@@ -101,11 +101,31 @@ def attribute(
                     If inputs is a tuple of tensors, baselines must also be
                     a tuple of tensors, with matching dimensions to inputs.
                     Default: zero tensor for each input tensor
-        target (int, optional): Output index for which gradient is computed
-                    (for classification cases, this is the target class).
+        target (int, tuple, tensor or list, optional): Output indices for
+                    which gradients are computed (for classification cases,
+                    this is usually the target class).
                     If the network returns a scalar value per example,
-                    no target index is necessary. (Note: Tuples for multi
-                    -dimensional output indices will be supported soon.)
+                    no target index is necessary.
+                    For general 2D outputs, targets can be either:
+
+                    - a single integer or a tensor containing a single
+                      integer, which is applied to all input examples
+
+                    - a list of integers or a 1D tensor, with length matching
+                      the number of examples in inputs (dim 0). Each integer
+                      is applied as the target for the corresponding example.
+
+                    For outputs with > 2 dimensions, targets can be either:
+
+                    - A single tuple, which contains #output_dims - 1
+                      elements. This target index is applied to all examples.
+
+                    - A list of tuples with length equal to the number of
+                      examples in inputs (dim 0), and each tuple containing
+                      #output_dims - 1 elements. Each tuple is applied as the
+                      target for the corresponding example.
+
+                    Default: None
         additional_forward_args (tuple, optional): If the forward function
                     requires additional arguments other than the inputs for
                     which attributions should not be computed, this argument

@@ -372,11 +392,31 @@ def attribute(
The identical `target` replacement is applied to the docstring of the second attribute() method in this file; its surrounding context ("first dimension. It is recommended that the number of samples in the baselines' tensors is larger than one. Default: zero tensor for each input tensor" before, the additional_forward_args entry after) is unchanged.
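Since every method below receives this same expanded `target` contract, a short hedged sketch of the accepted forms may help; the model, batch shapes, and index values here are illustrative assumptions, not part of the commit.

```python
import torch
import torch.nn as nn
from captum.attr import DeepLift

model = nn.Sequential(nn.Linear(3, 5))   # 2D output: (batch, 5)
dl = DeepLift(model)
inputs = torch.rand(4, 3)

# A single integer: class 2 is the target for every example in the batch.
attr = dl.attribute(inputs, target=2)

# A list of integers (or a 1D tensor) of length batch_size:
# one target class per example.
attr = dl.attribute(inputs, target=[0, 4, 1, 3])
attr = dl.attribute(inputs, target=torch.tensor([0, 4, 1, 3]))

# For outputs with more than 2 dimensions, e.g. (batch, d1, d2), the target
# indexes the non-batch dimensions instead:
#   dl.attribute(inputs, target=(1, 2))                            # one tuple for all
#   dl.attribute(inputs, target=[(0, 1), (1, 2), (0, 0), (2, 3)])  # per example
```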

captum/attr/_core/gradient_shap.py  (+24 -4)

@@ -98,11 +98,31 @@ def attribute(
The same expanded `target` docstring replacement as in deep_lift.py above, including the added "Default: None" line; the preceding baselines context ("corresponds to the input with the same index in the inputs tuple. Default: 0.0") and the following additional_forward_args entry are unchanged.

captum/attr/_core/input_x_gradient.py  (+24 -4)

@@ -31,11 +31,31 @@ def attribute(self, inputs, target=None, additional_forward_args=None):
The same expanded `target` docstring replacement as in deep_lift.py above, including the added "Default: None" line; the surrounding inputs and additional_forward_args context is unchanged.

captum/attr/_core/integrated_gradients.py  (+23 -4)

@@ -61,11 +61,30 @@ def attribute(
The same expanded `target` docstring replacement as in deep_lift.py above; here the "Default: None" line already existed as context, so only the description above it changes.

captum/attr/_core/internal_influence.py  (+23 -4)

@@ -74,11 +74,30 @@ def attribute(
The same expanded `target` docstring replacement as in deep_lift.py above; "Default: None" already existed as context.

captum/attr/_core/layer_conductance.py  (+23 -4)

@@ -80,11 +80,30 @@ def attribute(
The same expanded `target` docstring replacement as in deep_lift.py above; "Default: None" already existed as context.

captum/attr/_core/layer_gradient_x_activation.py  (+23 -4)

@@ -40,11 +40,30 @@ def attribute(self, inputs, target=None, additional_forward_args=None):
The same expanded `target` docstring replacement as in deep_lift.py above; "Default: None" already existed as context.

captum/attr/_core/neuron_conductance.py  (+23 -4)

@@ -79,11 +79,30 @@ def attribute(
The same expanded `target` docstring replacement as in deep_lift.py above; "Default: None" already existed as context.

captum/attr/_core/saliency.py  (+23 -4)

@@ -37,11 +37,30 @@ def attribute(self, inputs, target=None, abs=True, additional_forward_args=None)
The same expanded `target` docstring replacement as in deep_lift.py above; "Default: None" already existed as context, followed by the unchanged abs entry ("abs (bool, optional): Returns absolute value of gradients if set to True, otherwise returns the (signed) gradients if ...").

captum/attr/_utils/gradient.py  (+8 -4)

@@ -83,8 +83,10 @@ def compute_gradients(
     with torch.autograd.set_grad_enabled(True):
         # runs forward pass
         output = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
-        assert output[0].numel() == 1, "Target not provided when necessary, cannot"
-        "take gradient with respect to multiple outputs."
+        assert output[0].numel() == 1, (
+            "Target not provided when necessary, cannot"
+            " take gradient with respect to multiple outputs."
+        )
         # torch.unbind(forward_out) is a list of scalar tensor tuples and
         # contains batch_size * #steps elements
         grads = torch.autograd.grad(torch.unbind(output), inputs)
@@ -262,8 +264,10 @@ def forward_hook(module, inp, out):
 
     hook = layer.register_forward_hook(forward_hook)
     output = _run_forward(forward_fn, inputs, target_ind, additional_forward_args)
-    assert output[0].numel() == 1, "Target not provided when necessary, cannot"
-    "take gradient with respect to multiple outputs."
+    assert output[0].numel() == 1, (
+        "Target not provided when necessary, cannot"
+        " take gradient with respect to multiple outputs."
+    )
     # Remove unnecessary forward hook.
     hook.remove()
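The assert fix above addresses a classic Python pitfall: in the old code the second string literal sat on its own line, so it was a separate no-op expression and the assert message ended after "cannot". A standalone illustration (not Captum code):

```python
# Buggy: the assert statement ends at the newline, so the literal below is
# an unused expression and never becomes part of the message.
assert 2 + 2 == 4, "Target not provided when necessary, cannot"
"take gradient with respect to multiple outputs."  # dead code

# Fixed: parentheses make the two literals one implicitly concatenated
# message, and the leading space keeps "cannot take" from running together.
assert 2 + 2 == 4, (
    "Target not provided when necessary, cannot"
    " take gradient with respect to multiple outputs."
)
```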

captum/insights/api.py  (+126 -99)

@@ -10,11 +10,12 @@
 from torch import Tensor
 from torch.nn import Module
 
-PredictionScore = namedtuple("PredictionScore", "score label")
+OutputScore = namedtuple("OutputScore", "score index label")
 VisualizationOutput = namedtuple(
-    "VisualizationOutput", "feature_outputs actual predicted"
+    "VisualizationOutput", "feature_outputs actual predicted active_index"
 )
 Contribution = namedtuple("Contribution", "name percent")
+SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")
 
 
 class FilterConfig(NamedTuple):
@@ -44,6 +45,7 @@ def __init__(
         features: Union[List[BaseFeature], BaseFeature],
         dataset: Iterable[Data],
         score_func: Optional[Callable] = None,
+        use_label_for_attr: bool = True,
     ):
         if not isinstance(models, List):
             models = [models]
@@ -56,20 +58,34 @@ def __init__(
         self.features = features
         self.dataset = dataset
         self.score_func = score_func
+        self._outputs = []
         self._config = FilterConfig(steps=25, prediction="all", classes=[], count=4)
+        self._use_label_for_attr = use_label_for_attr
+
+    def _calculate_attribution_from_cache(
+        self, index: int, target: Optional[Tensor]
+    ) -> VisualizationOutput:
+        c = self._outputs[index][1]
+        return self._calculate_vis_output(
+            c.inputs, c.additional_forward_args, c.label, torch.tensor(target)
+        )
 
     def _calculate_attribution(
         self,
         net: Module,
         baselines: Optional[List[Tuple[Tensor, ...]]],
         data: Tuple[Tensor, ...],
         additional_forward_args: Optional[Tuple[Tensor, ...]],
-        label: Optional[Tensor],
+        label: Optional[Union[Tensor]],
     ) -> Tensor:
         ig = IntegratedGradients(net)
         # TODO support multiple baselines
         baseline = baselines[0] if len(baselines) > 0 else None
-        label = None if label is None or label.nelement() == 0 else label
+        label = (
+            None
+            if not self._use_label_for_attr or label is None or label.nelement() == 0
+            else label
+        )
         attr_ig = ig.attribute(
             data,
             baselines=baseline,
@@ -98,11 +114,11 @@ def render(self, blocking=False, debug=False):
 
     def _get_labels_from_scores(
         self, scores: Tensor, indices: Tensor
-    ) -> List[PredictionScore]:
+    ) -> List[OutputScore]:
         pred_scores = []
         for i in range(len(indices)):
-            score = scores[i].item()
-            pred_scores.append(PredictionScore(score, self.classes[indices[i]]))
+            score = scores[i]
+            pred_scores.append(OutputScore(score, indices[i], self.classes[indices[i]]))
         return pred_scores
 
     def _transform(
@@ -123,7 +139,7 @@ def _transform(
         transformed_inputs = transforms(transformed_inputs)
 
         if batch:
-            transformed_inputs.unsqueeze_(0)
+            transformed_inputs = transformed_inputs.unsqueeze(0)
 
         return transformed_inputs
 
@@ -141,22 +157,20 @@ def _calculate_net_contrib(self, attrs_per_input_feature: List[Tensor]):
         return net_contrib.tolist()
 
     def _predictions_matches_labels(
-        self,
-        predicted_scores: List[PredictionScore],
-        actual_labels: Union[str, List[str]],
+        self, predicted_scores: List[OutputScore], labels: Union[str, List[str]]
     ) -> bool:
         if len(predicted_scores) == 0:
             return False
 
         predicted_label = predicted_scores[0].label
 
-        if isinstance(actual_labels, List):
-            return predicted_label in actual_labels
+        if isinstance(labels, List):
+            return predicted_label in labels
 
-        return actual_labels == predicted_label
+        return labels == predicted_label
 
     def _should_keep_prediction(
-        self, predicted_scores: List[PredictionScore], actual_label: str
+        self, predicted_scores: List[OutputScore], actual_label: str
    ) -> bool:
         # filter by class
         if len(self._config.classes) != 0:
@@ -179,104 +193,117 @@ def _should_keep_prediction(
 
         return True
 
-    def _get_outputs(self) -> List[VisualizationOutput]:
-        batch_data = next(self.dataset)
+    def _calculate_vis_output(
+        self, inputs, additional_forward_args, label, target=None
+    ) -> Optional[VisualizationOutput]:
         net = self.models[0]  # TODO process multiple models
-        vis_outputs = []
 
-        for inputs, additional_forward_args, label in _batched_generator(
-            inputs=batch_data.inputs,
-            additional_forward_args=batch_data.additional_args,
-            target_ind=batch_data.labels,
-            internal_batch_size=1,  # should be 1 until we have batch label support
-        ):
-            # initialize baselines
-            baseline_transforms_len = len(self.features[0].baseline_transforms or [])
-            baselines = [
-                [None] * len(self.features) for _ in range(baseline_transforms_len)
-            ]
-            transformed_inputs = list(inputs)
-
-            for feature_i, feature in enumerate(self.features):
-                if feature.input_transforms is not None:
-                    transformed_inputs[feature_i] = self._transform(
-                        feature.input_transforms, transformed_inputs[feature_i], True
-                    )
-                if feature.baseline_transforms is not None:
-                    assert baseline_transforms_len == len(
-                        feature.baseline_transforms
-                    ), "Must have same number of baselines across all features"
-
-                    for baseline_i, baseline_transform in enumerate(
-                        feature.baseline_transforms
-                    ):
-                        baselines[baseline_i][feature_i] = self._transform(
-                            baseline_transform, transformed_inputs[feature_i], True
-                        )
-
-            outputs = _run_forward(
-                net, tuple(transformed_inputs), additional_forward_args
-            )
+        # initialize baselines
+        baseline_transforms_len = len(self.features[0].baseline_transforms or [])
+        baselines = [
+            [None] * len(self.features) for _ in range(baseline_transforms_len)
+        ]
+        transformed_inputs = list(inputs)
+
+        # transformed_inputs = list([i.clone() for i in inputs])
+        for feature_i, feature in enumerate(self.features):
+            if feature.input_transforms is not None:
+                transformed_inputs[feature_i] = self._transform(
+                    feature.input_transforms, transformed_inputs[feature_i], True
+                )
+            if feature.baseline_transforms is not None:
+                assert baseline_transforms_len == len(
+                    feature.baseline_transforms
+                ), "Must have same number of baselines across all features"
+
+                for baseline_i, baseline_transform in enumerate(
+                    feature.baseline_transforms
+                ):
+                    baselines[baseline_i][feature_i] = self._transform(
+                        baseline_transform, transformed_inputs[feature_i], True
+                    )
 
-            if self.score_func is not None:
-                outputs = self.score_func(outputs)
-
-            if outputs.nelement() == 1:
-                scores = outputs
-                predicted = scores.round().to(torch.int)
-            else:
-                scores, predicted = outputs.topk(min(4, outputs.shape[-1]))
-
-            scores = scores.cpu().squeeze(0)
-            predicted = predicted.cpu().squeeze(0)
-
-            actual_label = self.classes[label[0]] if label is not None else None
-            predicted_scores = self._get_labels_from_scores(scores, predicted)
-
-            # Filter based on UI configuration
-            if not self._should_keep_prediction(predicted_scores, actual_label):
-                continue
-
-            baselines = [tuple(b) for b in baselines]
-
-            # attributions are given per input*
-            # inputs given to the model are described via `self.features`
-            #
-            # *an input contains multiple features that represent it
-            # e.g. all the pixels that describe an image is an input
-            attrs_per_input_feature = self._calculate_attribution(
-                net,
-                baselines,
-                tuple(transformed_inputs),
-                additional_forward_args,
-                label,
+        outputs = _run_forward(net, tuple(transformed_inputs), additional_forward_args)
+
+        if self.score_func is not None:
+            outputs = self.score_func(outputs)
+
+        if outputs.nelement() == 1:
+            scores = outputs
+            predicted = scores.round().to(torch.int)
+        else:
+            scores, predicted = outputs.topk(min(4, outputs.shape[-1]))
+
+        scores = scores.cpu().squeeze(0)
+        predicted = predicted.cpu().squeeze(0)
+
+        if label is not None and len(label) > 0:
+            actual_label = OutputScore(
+                score=0, index=label[0], label=self.classes[label[0]]
             )
+        else:
+            actual_label = None
 
-            net_contrib = self._calculate_net_contrib(attrs_per_input_feature)
+        predicted_scores = self._get_labels_from_scores(scores, predicted)
 
-            # the features per input given
-            features_per_input = [
-                feature.visualize(attr, data, contrib)
-                for feature, attr, data, contrib in zip(
-                    self.features, attrs_per_input_feature, inputs, net_contrib
-                )
-            ]
+        # Filter based on UI configuration
+        if not self._should_keep_prediction(predicted_scores, actual_label):
+            return None
+
+        baselines = [tuple(b) for b in baselines]
+
+        if target is None:
+            target = predicted_scores[0].index if len(predicted_scores) > 0 else None
+
+        # attributions are given per input*
+        # inputs given to the model are described via `self.features`
+        #
+        # *an input contains multiple features that represent it
+        # e.g. all the pixels that describe an image is an input
+
+        attrs_per_input_feature = self._calculate_attribution(
+            net, baselines, tuple(transformed_inputs), additional_forward_args, target
+        )
+
+        net_contrib = self._calculate_net_contrib(attrs_per_input_feature)
 
-            output = VisualizationOutput(
-                feature_outputs=features_per_input,
-                actual=actual_label,
-                predicted=predicted_scores,
+        # the features per input given
+        features_per_input = [
+            feature.visualize(attr, data, contrib)
+            for feature, attr, data, contrib in zip(
+                self.features, attrs_per_input_feature, inputs, net_contrib
             )
+        ]
 
-            vis_outputs.append(output)
+        return VisualizationOutput(
+            feature_outputs=features_per_input,
+            actual=actual_label,
+            predicted=predicted_scores,
+            active_index=target if target is not None else actual_label.index,
+        )
+
+    def _get_outputs(self) -> List[VisualizationOutput]:
+        batch_data = next(self.dataset)
+        vis_outputs = []
+
+        for inputs, additional_forward_args, label in _batched_generator(
+            inputs=batch_data.inputs,
+            additional_forward_args=batch_data.additional_args,
+            target_ind=batch_data.labels,
+            internal_batch_size=1,  # should be 1 until we have batch label support
+        ):
+            output = self._calculate_vis_output(inputs, additional_forward_args, label)
+            if output is not None:
+                cache = SampleCache(inputs, additional_forward_args, label)
+                vis_outputs.append((output, cache))
 
         return vis_outputs
 
     def visualize(self):
-        output_list = []
-        while len(output_list) < self._config.count:
+        self._outputs = []
+        while len(self._outputs) < self._config.count:
             try:
-                output_list.extend(self._get_outputs())
+                self._outputs.extend(self._get_outputs())
             except StopIteration:
                 break
-        return output_list
+        return [o[0] for o in self._outputs]
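Taken together, the api.py changes cache each rendered sample next to its output so that a later click on a different label can recompute attributions without re-reading the dataset. A small runnable sketch of the new records and how they are used; the sample values ("deer", the shapes) are illustrative assumptions:

```python
from collections import namedtuple
import torch

# The new records introduced in this commit (mirrored here for illustration):
OutputScore = namedtuple("OutputScore", "score index label")
SampleCache = namedtuple("SampleCache", "inputs additional_forward_args label")

# _get_outputs() stores one (VisualizationOutput, SampleCache) pair per sample;
# visualize() returns [o[0] for o in self._outputs], i.e. only the outputs.
cache = SampleCache(inputs=(torch.rand(1, 3),), additional_forward_args=None,
                    label=torch.tensor([2]))

# A clicked label arrives as an index; _calculate_attribution_from_cache(i, t)
# looks up self._outputs[i][1] and replays _calculate_vis_output with target=t.
clicked = OutputScore(score=torch.tensor(0.83), index=4, label="deer")
print(cache.label, clicked.index)
```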

captum/insights/features.py  (+6 -6)

@@ -60,8 +60,8 @@ def visualization_type(self) -> str:
         return "image"
 
     def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
-        attribution.squeeze_()
-        data.squeeze_()
+        attribution = attribution.squeeze()
+        data = data.squeeze()
         data_t = np.transpose(data.cpu().detach().numpy(), (1, 2, 0))
         attribution_t = np.transpose(
             attribution.squeeze().cpu().detach().numpy(), (1, 2, 0)
@@ -111,8 +111,8 @@ def visualization_type(self) -> str:
     def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
         text = self.visualization_transform(data)
 
-        attribution.squeeze_(0)
-        data.squeeze_(0)
+        attribution = attribution.squeeze(0)
+        data = data.squeeze(0)
         attribution = attribution.sum(dim=1)
 
         # L-Infinity norm
@@ -142,8 +142,8 @@ def visualization_type(self) -> str:
         return "general"
 
     def visualize(self, attribution, data, contribution_frac) -> FeatureOutput:
-        attribution.squeeze_(0)
-        data.squeeze_(0)
+        attribution = attribution.squeeze(0)
+        data = data.squeeze(0)
 
         # L-2 norm
         normalized_attribution = attribution / attribution.norm()
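The switch from `squeeze_` to `squeeze` matters beyond style: the trailing underscore is PyTorch's in-place convention, so `attribution.squeeze_()` mutated the caller's tensor, which is unsafe now that inputs are cached and reused by the new `/attribute` flow. A quick demonstration:

```python
import torch

t = torch.zeros(1, 3)

# Out-of-place: returns a view with the size-1 dim removed; t keeps its shape.
s = t.squeeze()
print(s.shape, t.shape)  # torch.Size([3]) torch.Size([1, 3])

# In-place: mutates t itself, surprising any other code holding a reference.
t.squeeze_()
print(t.shape)  # torch.Size([3])
```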

captum/insights/frontend/public/index.html  (+1 -1)

@@ -3,7 +3,7 @@
 <head>
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1" />
-  <title>Captum Visualization</title>
+  <title>Captum Insights</title>
 </head>
 <body>
   <noscript>You need to enable JavaScript to run this app.</noscript>

captum/insights/frontend/src/App.css  (+16 -0)

@@ -118,6 +118,7 @@ button {
   border: solid 1px #ee4c2c;
   text-align: center;
   font-weight: 600;
+  font-size: 1em;
   border-radius: 4px;
   padding: 6px 8px;
   display: inline-block;
@@ -150,13 +151,28 @@ button {
   display: block;
 }
 
+.loading {
+  margin-top: 150px;
+  position: absolute;
+  width: 100%;
+  align-items: center;
+  justify-content: center;
+  display: flex;
+}
+
 .panel {
   margin: 16px;
   padding: 24px;
   background: white;
   border-radius: 8px;
   display: flex;
   box-shadow: 0px 3px 6px 0px rgba(0, 0, 0, 0.18);
+  transition: opacity 0.2s; /* for loading */
+}
+
+.panel--loading {
+  opacity: 0.5;
+  pointer-events: none; /* disables all interactions inside panel */
 }
 
 .panel--center {

captum/insights/frontend/src/App.js  (+114 -35)

@@ -312,49 +312,104 @@ class Contributions extends React.Component {
   }
 }
 
+class LabelButton extends React.Component {
+  onClick = e => {
+    e.preventDefault();
+    this.props.onTargetClick(this.props.labelIndex, this.props.instance);
+  };
+
+  render() {
+    return (
+      <button
+        onClick={this.onClick}
+        className={cx({
+          btn: true,
+          "btn--solid": this.props.active,
+          "btn--outline": !this.props.active
+        })}
+      >
+        {this.props.children}
+      </button>
+    );
+  }
+}
+
 class Visualization extends React.Component {
+  constructor(props) {
+    super(props);
+    this.state = {
+      loading: false
+    };
+  }
+
+  onTargetClick = (labelIndex, instance) => {
+    this.setState({ loading: true });
+    this.props.onTargetClick(labelIndex, instance, () =>
+      this.setState({ loading: false })
+    );
+  };
+
   render() {
     const data = this.props.data;
     const features = data.feature_outputs.map(f => <Feature data={f} />);
 
     return (
-      <div className="panel panel--long">
-        <div className="panel__column">
-          <div className="panel__column__title">Predicted</div>
-          <div className="panel__column__body">
-            {data.predicted.map((p, i) => (
+      <>
+        {this.state.loading && (
+          <div className="loading">
+            <Spinner />
+          </div>
+        )}
+        <div
+          className={cx({
+            panel: true,
+            "panel--long": true,
+            "panel--loading": this.state.loading
+          })}
+        >
+          <div className="panel__column">
+            <div className="panel__column__title">Predicted</div>
+            <div className="panel__column__body">
+              {data.predicted.map(p => (
+                <div className="row row--padding">
+                  <LabelButton
+                    onTargetClick={this.onTargetClick}
+                    labelIndex={p.index}
+                    instance={this.props.instance}
+                    active={p.index === data.active_index}
+                  >
+                    {p.label} ({p.score.toFixed(3)})
+                  </LabelButton>
+                </div>
+              ))}
+            </div>
+          </div>
+          <div className="panel__column">
+            <div className="panel__column__title">Label</div>
+            <div className="panel__column__body">
               <div className="row row--padding">
-                <div
-                  className={cx({
-                    btn: true,
-                    "btn--solid": i === 0,
-                    "btn--outline": i !== 0
-                  })}
+                <LabelButton
+                  onTargetClick={this.onTargetClick}
+                  labelIndex={data.actual.index}
+                  instance={this.props.instance}
+                  active={data.actual.index === data.active_index}
                 >
-                  {p.label} ({p.score.toFixed(3)})
-                </div>
+                  {data.actual.label}
+                </LabelButton>
               </div>
-            ))}
-          </div>
-        </div>
-        <div className="panel__column">
-          <div className="panel__column__title">Label</div>
-          <div className="panel__column__body">
-            <div className="row row--padding">
-              <div className="btn btn--outline">{data.actual}</div>
             </div>
           </div>
-        </div>
-        <div className="panel__column">
-          <div className="panel__column__title">Contribution</div>
-          <div className="panel__column__body">
-            <div className="bar-chart">
-              <Contributions feature_outputs={data.feature_outputs} />
+          <div className="panel__column">
+            <div className="panel__column__title">Contribution</div>
+            <div className="panel__column__body">
+              <div className="bar-chart">
+                <Contributions feature_outputs={data.feature_outputs} />
+              </div>
             </div>
           </div>
+          <div className="panel__column panel__column--stretch">{features}</div>
         </div>
-        <div className="panel__column panel__column--stretch">{features}</div>
-      </div>
+      </>
     );
   }
 }
@@ -385,7 +440,12 @@ function Visualizations(props) {
   return (
     <div className="viz">
      {props.data.map((v, i) => (
-        <Visualization data={v} key={i} />
+        <Visualization
+          data={v}
+          instance={i}
+          key={i}
+          onTargetClick={props.onTargetClick}
+        />
       ))}
     </div>
   );
@@ -404,10 +464,8 @@ class App extends React.Component {
 
   _fetchInit = () => {
     fetch("/init")
-      .then(response => response.json())
-      .then(response => {
-        this.setState({ config: response });
-      });
+      .then(r => r.json())
+      .then(r => this.setState({ config: r }));
   };
 
   fetchData = filter_config => {
@@ -423,6 +481,23 @@
       .then(response => this.setState({ data: response, loading: false }));
   };
 
+  onTargetClick = (labelIndex, instance, callback) => {
+    fetch("/attribute", {
+      method: "POST",
+      headers: {
+        "Content-Type": "application/json"
+      },
+      body: JSON.stringify({ labelIndex, instance })
+    })
+      .then(response => response.json())
+      .then(response => {
+        const data = Object.assign([], this.state.data);
+        data[instance] = response;
+        this.setState({ data });
+        callback();
+      });
+  };
+
   render() {
     return (
       <div className="app">
@@ -432,7 +507,11 @@
           config={this.state.config}
           key={this.state.config}
         />
-        <Visualizations data={this.state.data} loading={this.state.loading} />
+        <Visualizations
+          data={this.state.data}
+          loading={this.state.loading}
+          onTargetClick={this.onTargetClick}
+        />
       </div>
     );
   }

captum/insights/server.py  (+13 -0)

@@ -4,6 +4,7 @@
 from time import sleep
 from typing import Optional
 
+from torch import Tensor
 from flask import Flask, jsonify, render_template, request
 
 app = Flask(
@@ -14,6 +15,8 @@
 
 
 def namedtuple_to_dict(obj):
+    if isinstance(obj, Tensor):
+        return obj.item()
     if hasattr(obj, "_asdict"):  # detect namedtuple
         return dict(zip(obj._fields, (namedtuple_to_dict(item) for item in obj)))
     elif isinstance(obj, str):  # iterables - strings
@@ -28,6 +31,16 @@ def namedtuple_to_dict(obj):
     return obj
 
 
+@app.route("/attribute", methods=["POST"])
+def attribute():
+    r = request.json
+    return jsonify(
+        namedtuple_to_dict(
+            visualizer._calculate_attribution_from_cache(r["instance"], r["labelIndex"])
+        )
+    )
+
+
 @app.route("/fetch", methods=["POST"])
 def fetch():
     visualizer._update_config(request.json)
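The new `/attribute` endpoint can also be exercised directly. A hedged example using the `requests` library; the host and port are assumed Flask defaults and may differ in a real Insights deployment:

```python
import requests

# Recompute attributions for cached sample 0, using class 3 as the target.
# The JSON keys mirror what App.js sends from onTargetClick.
resp = requests.post(
    "http://localhost:5000/attribute",
    json={"labelIndex": 3, "instance": 0},
)
print(resp.json())  # a serialized VisualizationOutput for the new target
```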

docs/algorithms.md  (+143 -15)
Large diffs are not rendered by default.

tutorials/CIFAR_TorchVision_Interpret.ipynb  (+48 -17)
Large diffs are not rendered by default.

tutorials/IMDB_TorchText_Interpret.ipynb  (+57 -92)
Large diffs are not rendered by default.

tutorials/Multimodal_VQA_Interpret.ipynb  (+484 -66)
Large diffs are not rendered by default.

tutorials/Resnet_TorchVision_Interpret.ipynb  (+12 -20)
Large diffs are not rendered by default.
Two binary files changed (25.1 KB and 41.3 KB); previews not rendered.
