Commit 9ce7714

vivekmig authored and facebook-github-bot committed
Type Hints Completion (#295)
Summary: This PR completes type hints for the remaining methods, adds type hints to utility methods, and makes the type signatures of methods with the return_convergence_delta flag clearer using Literal overloads.

Pull Request resolved: #295
Reviewed By: NarineK
Differential Revision: D20182905
Pulled By: vivekmig
fbshipit-source-id: 6889322297eff193c5614dec5fd6972ea41b2afe
1 parent a6e3a2e commit 9ce7714
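
For readers unfamiliar with the pattern this commit applies, here is a minimal, self-contained sketch of how typing.overload plus Literal lets a type checker pick the return type from the value of return_convergence_delta. The ToyAttribution class and its computation are purely illustrative assumptions; only the overload shape mirrors the changes in the diff below.

import typing
from typing import Tuple, Union

import torch
from torch import Tensor

try:
    from typing import Literal  # available in Python >= 3.8
except ImportError:
    # On older Pythons a backport is needed; the diff imports Literal from
    # Captum's own .._utils.typing module, presumably for this reason.
    from typing_extensions import Literal


class ToyAttribution:
    # Overload 1: the flag is (or defaults to) False -> only attributions come back.
    @typing.overload
    def attribute(
        self, inputs: Tensor, return_convergence_delta: Literal[False] = False
    ) -> Tensor:
        ...

    # Overload 2: the flag is literally True (keyword-only) -> an (attributions, delta) pair.
    @typing.overload
    def attribute(
        self, inputs: Tensor, *, return_convergence_delta: Literal[True]
    ) -> Tuple[Tensor, Tensor]:
        ...

    # The implementation keeps the plain bool parameter and the Union return type.
    def attribute(
        self, inputs: Tensor, return_convergence_delta: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        attributions = inputs * 2.0  # placeholder "attribution" computation
        if return_convergence_delta:
            delta = torch.zeros(inputs.shape[0])  # placeholder delta
            return attributions, delta
        return attributions


# The checker infers a plain Tensor here...
attr = ToyAttribution().attribute(torch.rand(3, 4))
# ...and a (Tensor, Tensor) pair here, with no cast needed at the call site.
attr, delta = ToyAttribution().attribute(torch.rand(3, 4), return_convergence_delta=True)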


62 files changed: +1438 -994 lines (only a subset of the changed files is shown below)

captum/attr/_core/deep_lift.py

+75 -63
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import typing
-from typing import Optional, Tuple, Union, Any, List, Callable, cast
+from typing import Tuple, Union, Any, List, Callable, cast
 
 import warnings
 import torch
@@ -13,6 +13,7 @@
 import numpy as np
 
 from .._utils.common import (
+    _is_tuple,
     _format_input,
     _format_baseline,
     _format_callable_baseline,
@@ -30,7 +31,12 @@
 )
 from .._utils.attribution import GradientAttribution
 from .._utils.gradient import apply_gradient_requirements, undo_gradient_requirements
-from .._utils.typing import TensorOrTupleOfTensors
+from .._utils.typing import (
+    TensorOrTupleOfTensorsGeneric,
+    Literal,
+    TargetType,
+    BaselineType,
+)
 
 
 # Check if module backward hook can safely be used for the module that produced
@@ -76,43 +82,39 @@ def __init__(self, model: Module) -> None:
     @typing.overload
     def attribute(
         self,
-        inputs: TensorOrTupleOfTensors,
-        baselines: Union[
-            Tensor, int, float, Tuple[Union[Tensor, int, float], ...]
-        ] = None,
-        target: Optional[
-            Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
-        ] = None,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: BaselineType = None,
+        target: TargetType = None,
         additional_forward_args: Any = None,
-        custom_attribution_func: Callable[..., Tuple[Tensor, ...]] = None,
-    ) -> TensorOrTupleOfTensors:
+        return_convergence_delta: Literal[False] = False,
+        custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+    ) -> TensorOrTupleOfTensorsGeneric:
         ...
 
     @typing.overload
     def attribute(
         self,
-        inputs: TensorOrTupleOfTensors,
-        baselines: Optional[
-            Union[Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
-        ] = None,
-        target: Optional[
-            Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
-        ] = None,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: BaselineType = None,
+        target: TargetType = None,
         additional_forward_args: Any = None,
-        return_convergence_delta: bool = False,
-        custom_attribution_func: Callable[..., Tuple[Tensor, ...]] = None,
-    ) -> Union[TensorOrTupleOfTensors, Tuple[TensorOrTupleOfTensors, Tensor]]:
+        *,
+        return_convergence_delta: Literal[True],
+        custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+    ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
         ...
 
-    def attribute(
+    def attribute(  # type: ignore
         self,
-        inputs,
-        baselines=None,
-        target=None,
-        additional_forward_args=None,
-        return_convergence_delta=False,
-        custom_attribution_func=None,
-    ):
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: BaselineType = None,
+        target: TargetType = None,
+        additional_forward_args: Any = None,
+        return_convergence_delta: bool = False,
+        custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+    ) -> Union[
+        TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
+    ]:
         r""""
         Implements DeepLIFT algorithm based on the following paper:
         Learning Important Features Through Propagating Activation Differences,
@@ -278,7 +280,7 @@ def attribute(
 
         # Keeps track whether original input is a tuple or not before
        # converting it into a tuple.
-        is_inputs_tuple = isinstance(inputs, tuple)
+        is_inputs_tuple = _is_tuple(inputs)
 
         inputs = _format_input(inputs)
         baselines = _format_baseline(baselines, inputs)
@@ -341,10 +343,8 @@ def attribute(
     def _construct_forward_func(
         self,
         forward_func: Callable,
-        inputs: TensorOrTupleOfTensors,
-        target: Optional[
-            Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
-        ] = None,
+        inputs: Tuple,
+        target: TargetType = None,
         additional_forward_args: Any = None,
     ) -> Callable:
         def forward_fn():
@@ -533,39 +533,45 @@ def __init__(self, model: Module) -> None:
     @typing.overload  # type: ignore
     def attribute(
         self,
-        inputs: TensorOrTupleOfTensors,
-        baselines: Union[TensorOrTupleOfTensors, Callable[..., TensorOrTupleOfTensors]],
-        target: Optional[
-            Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
-        ] = None,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: Union[
+            TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
+        ],
+        target: TargetType = None,
         additional_forward_args: Any = None,
-        custom_attribution_func: Callable[..., Tuple[Tensor, ...]] = None,
-    ) -> TensorOrTupleOfTensors:
+        return_convergence_delta: Literal[False] = False,
+        custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+    ) -> TensorOrTupleOfTensorsGeneric:
         ...
 
     @typing.overload
     def attribute(
         self,
-        inputs: TensorOrTupleOfTensors,
-        baselines: Union[TensorOrTupleOfTensors, Callable[..., TensorOrTupleOfTensors]],
-        target: Optional[
-            Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
-        ] = None,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: Union[
+            TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
+        ],
+        target: TargetType = None,
         additional_forward_args: Any = None,
-        return_convergence_delta: bool = False,
-        custom_attribution_func: Callable[..., Tuple[Tensor, ...]] = None,
-    ) -> Union[TensorOrTupleOfTensors, Tuple[TensorOrTupleOfTensors, Tensor]]:
+        *,
+        return_convergence_delta: Literal[True],
+        custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+    ) -> Tuple[TensorOrTupleOfTensorsGeneric, Tensor]:
         ...
 
-    def attribute(
+    def attribute(  # type: ignore
         self,
-        inputs,
-        baselines,
-        target=None,
-        additional_forward_args=None,
-        return_convergence_delta=False,
-        custom_attribution_func=None,
-    ):
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: Union[
+            TensorOrTupleOfTensorsGeneric, Callable[..., TensorOrTupleOfTensorsGeneric]
+        ],
+        target: TargetType = None,
+        additional_forward_args: Any = None,
+        return_convergence_delta: bool = False,
+        custom_attribution_func: Union[None, Callable[..., Tuple[Tensor, ...]]] = None,
+    ) -> Union[
+        TensorOrTupleOfTensorsGeneric, Tuple[TensorOrTupleOfTensorsGeneric, Tensor]
+    ]:
         r"""
         Extends DeepLift algorithm and approximates SHAP values using Deeplift.
         For each input sample it computes DeepLift attribution with respect to
@@ -724,7 +730,7 @@ def attribute(
 
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        is_inputs_tuple = isinstance(inputs, tuple)
+        is_inputs_tuple = _is_tuple(inputs)
 
         inputs = _format_input(inputs)
 
@@ -745,14 +751,18 @@ def attribute(
             exp_base,
             target=exp_tgt,
             additional_forward_args=exp_addit_args,
-            return_convergence_delta=return_convergence_delta,
+            return_convergence_delta=cast(
+                Literal[True, False], return_convergence_delta
+            ),
             custom_attribution_func=custom_attribution_func,
         )
         if return_convergence_delta:
-            attributions, delta = attributions
+            attributions, delta = cast(Tuple[Tuple[Tensor, ...], Tensor], attributions)
 
         attributions = tuple(
-            self._compute_mean_across_baselines(inp_bsz, base_bsz, attribution)
+            self._compute_mean_across_baselines(
+                inp_bsz, base_bsz, cast(Tensor, attribution)
+            )
             for attribution in attributions
         )
 
@@ -765,9 +775,11 @@ def _expand_inputs_baselines_targets(
         self,
         baselines: Tuple[Tensor, ...],
         inputs: Tuple[Tensor, ...],
-        target: Optional[Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]],
+        target: TargetType,
        additional_forward_args: Any,
-    ):
+    ) -> Tuple[
+        Tuple[Tensor, ...], Tuple[Tensor, ...], TargetType, Any,
+    ]:
         inp_bsz = inputs[0].shape[0]
         base_bsz = baselines[0].shape[0]
 
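A note on the isinstance(inputs, tuple) -> _is_tuple(inputs) swap above: a bare isinstance check yields an unannotated bool, so downstream code that branches on it loses the connection between the flag and the input type. A small helper can carry that information through Literal overloads. The sketch below is one plausible way to type such a helper, shown as an illustration rather than as Captum's actual _is_tuple implementation.

import typing
from typing import Tuple, Union

from torch import Tensor
from typing_extensions import Literal  # or typing.Literal on Python >= 3.8


@typing.overload
def _is_tuple(inputs: Tensor) -> Literal[False]:
    ...


@typing.overload
def _is_tuple(inputs: Tuple[Tensor, ...]) -> Literal[True]:
    ...


def _is_tuple(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> bool:
    # Same runtime behavior as isinstance(inputs, tuple); the overloads only
    # give the type checker a Literal[True]/Literal[False] result to reason with.
    return isinstance(inputs, tuple)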
captum/attr/_core/feature_ablation.py

+9 -13
@@ -4,20 +4,21 @@
 
 from torch import Tensor, dtype
 
-from typing import Callable, List, Optional, Tuple, Union, Any, cast
+from typing import Any, Callable, Tuple, Union, cast
 
 from .._utils.common import (
     _find_output_mode_and_verify,
     _format_attributions,
     _format_input,
     _format_input_baseline,
+    _is_tuple,
     _run_forward,
     _expand_additional_forward_args,
     _expand_target,
     _format_additional_forward_args,
 )
 from .._utils.attribution import PerturbationAttribution
-from .._utils.typing import TensorOrTupleOfTensors
+from .._utils.typing import TensorOrTupleOfTensorsGeneric, TargetType, BaselineType
 
 
 class FeatureAblation(PerturbationAttribution):
@@ -32,20 +33,15 @@ def __init__(self, forward_func: Callable) -> None:
         self.use_weights = False
 
     def attribute(
-        # type:ignore
         self,
-        inputs: TensorOrTupleOfTensors,
-        baselines: Optional[
-            Union[Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]
-        ] = None,
-        target: Optional[
-            Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
-        ] = None,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        baselines: BaselineType = None,
+        target: TargetType = None,
         additional_forward_args: Any = None,
-        feature_mask: Optional[TensorOrTupleOfTensors] = None,
+        feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
         perturbations_per_eval: int = 1,
         **kwargs: Any
-    ) -> TensorOrTupleOfTensors:
+    ) -> TensorOrTupleOfTensorsGeneric:
         r""""
         A perturbation based approach to computing attribution, involving
         replacing each input feature with a given baseline / reference, and
@@ -228,7 +224,7 @@ def attribute(
         """
         # Keeps track whether original input is a tuple or not before
         # converting it into a tuple.
-        is_inputs_tuple = isinstance(inputs, tuple)
+        is_inputs_tuple = _is_tuple(inputs)
         inputs, baselines = _format_input_baseline(inputs, baselines)
         additional_forward_args = _format_additional_forward_args(
             additional_forward_args

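The repeated inline unions that this file (and deep_lift.py above) drops are collapsed into aliases from captum/attr/_utils/typing.py. Their exact definitions are not part of this diff; inferring them from the annotations they replace, they would look roughly like the sketch below, with TensorOrTupleOfTensorsGeneric as a value-constrained TypeVar so the return type tracks whether a single Tensor or a tuple of Tensors was passed in. Treat this as an assumption, not a copy of the real module.

from typing import List, Tuple, TypeVar, Union

from torch import Tensor

# Old inline form: Union[Tensor, int, float, Tuple[Union[Tensor, int, float], ...]] = None
BaselineType = Union[None, Tensor, int, float, Tuple[Union[Tensor, int, float], ...]]

# Old inline form: Optional[Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]]
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]

# Constrained TypeVar: a Tensor in gives a Tensor out, a tuple in gives a tuple out.
TensorOrTupleOfTensorsGeneric = TypeVar(
    "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)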
captum/attr/_core/feature_permutation.py

+12 -14
@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Tuple, Union
 
 import torch
 from torch import Tensor
 
 from .feature_ablation import FeatureAblation
-from .._utils.typing import TensorOrTupleOfTensors
+from .._utils.typing import TargetType, TensorOrTupleOfTensorsGeneric
 
 
 def _permute_feature(x: Tensor, feature_mask: Tensor) -> Tensor:
@@ -67,15 +67,13 @@ def __init__(self, forward_func: Callable, perm_func: Callable = _permute_featur
     # signature to the parent
     def attribute(  # type: ignore
         self,
-        inputs: TensorOrTupleOfTensors,
-        target: Optional[
-            Union[int, Tuple[int, ...], Tensor, List[Tuple[int, ...]]]
-        ] = None,
+        inputs: TensorOrTupleOfTensorsGeneric,
+        target: TargetType = None,
         additional_forward_args: Any = None,
-        feature_mask: Optional[TensorOrTupleOfTensors] = None,
-        ablations_per_eval: int = 1,
+        feature_mask: Union[None, TensorOrTupleOfTensorsGeneric] = None,
+        perturbations_per_eval: int = 1,
         **kwargs: Any,
-    ) -> TensorOrTupleOfTensors:
+    ) -> TensorOrTupleOfTensorsGeneric:
         r"""
         This function is almost equivalent to `FeatureAblation.attribute`. The
         main difference is the way ablated examples are generated. Specifically
@@ -154,17 +152,17 @@ def attribute(  # type: ignore
                        each scalar within a tensor as a separate feature, which
                        is permuted independently.
                        Default: None
-            ablations_per_eval (int, optional): Allows permutations (ablations)
+            perturbations_per_eval (int, optional): Allows permutations
                        of multiple features to be processed simultaneously
                        in one call to forward_fn. Each forward pass will
-                        contain a maximum of ablations_per_eval * #examples
+                        contain a maximum of perturbations_per_eval * #examples
                        samples. For DataParallel models, each batch is
                        split among the available devices, so evaluations on
                        each available device contain at most
-                        (ablations_per_eval * #examples) / num_devices
+                        (perturbations_per_eval * #examples) / num_devices
                        samples.
                        If the forward function returns a single scalar per batch,
-                        ablations_per_eval must be set to 1.
+                        perturbations_per_eval must be set to 1.
                        Default: 1
            **kwargs (Any, optional): Any additional arguments used by child
                        classes of FeatureAblation (such as Occlusion) to construct
@@ -179,7 +177,7 @@ def attribute(  # type: ignore
             target=target,
             additional_forward_args=additional_forward_args,
             feature_mask=feature_mask,
-            ablations_per_eval=ablations_per_eval,
+            perturbations_per_eval=perturbations_per_eval,
             **kwargs,
         )
 
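Finally, the rename from ablations_per_eval to perturbations_per_eval in FeaturePermutation.attribute brings the keyword in line with FeatureAblation. A minimal usage sketch follows; the model, shapes, and target are made up for illustration, and it assumes a Captum release that exports FeaturePermutation from captum.attr.

import torch
from captum.attr import FeaturePermutation

# Toy setup purely for illustration.
model = torch.nn.Linear(3, 2)
inputs = torch.randn(4, 3)  # batch of 4, so there is something to permute within

fp = FeaturePermutation(model)
# The keyword is now perturbations_per_eval (previously ablations_per_eval):
# up to 2 feature permutations are batched into a single forward pass.
attr = fp.attribute(inputs, target=0, perturbations_per_eval=2)
print(attr.shape)  # torch.Size([4, 3])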