1
1
#!/usr/bin/env python3
2
2
import typing
3
- from typing import Optional , Tuple , Union , Any , List , Callable , cast
3
+ from typing import Tuple , Union , Any , List , Callable , cast
4
4
5
5
import warnings
6
6
import torch
13
13
import numpy as np
14
14
15
15
from .._utils .common import (
16
+ _is_tuple ,
16
17
_format_input ,
17
18
_format_baseline ,
18
19
_format_callable_baseline ,
30
31
)
31
32
from .._utils .attribution import GradientAttribution
32
33
from .._utils .gradient import apply_gradient_requirements , undo_gradient_requirements
33
- from .._utils .typing import TensorOrTupleOfTensors
34
+ from .._utils .typing import (
35
+ TensorOrTupleOfTensorsGeneric ,
36
+ Literal ,
37
+ TargetType ,
38
+ BaselineType ,
39
+ )
34
40
35
41
36
42
# Check if module backward hook can safely be used for the module that produced
@@ -76,43 +82,39 @@ def __init__(self, model: Module) -> None:
76
82
@typing .overload
77
83
def attribute (
78
84
self ,
79
- inputs : TensorOrTupleOfTensors ,
80
- baselines : Union [
81
- Tensor , int , float , Tuple [Union [Tensor , int , float ], ...]
82
- ] = None ,
83
- target : Optional [
84
- Union [int , Tuple [int , ...], Tensor , List [Tuple [int , ...]]]
85
- ] = None ,
85
+ inputs : TensorOrTupleOfTensorsGeneric ,
86
+ baselines : BaselineType = None ,
87
+ target : TargetType = None ,
86
88
additional_forward_args : Any = None ,
87
- custom_attribution_func : Callable [..., Tuple [Tensor , ...]] = None ,
88
- ) -> TensorOrTupleOfTensors :
89
+ return_convergence_delta : Literal [False ] = False ,
90
+ custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
91
+ ) -> TensorOrTupleOfTensorsGeneric :
89
92
...
90
93
91
94
@typing .overload
92
95
def attribute (
93
96
self ,
94
- inputs : TensorOrTupleOfTensors ,
95
- baselines : Optional [
96
- Union [Tensor , int , float , Tuple [Union [Tensor , int , float ], ...]]
97
- ] = None ,
98
- target : Optional [
99
- Union [int , Tuple [int , ...], Tensor , List [Tuple [int , ...]]]
100
- ] = None ,
97
+ inputs : TensorOrTupleOfTensorsGeneric ,
98
+ baselines : BaselineType = None ,
99
+ target : TargetType = None ,
101
100
additional_forward_args : Any = None ,
102
- return_convergence_delta : bool = False ,
103
- custom_attribution_func : Callable [..., Tuple [Tensor , ...]] = None ,
104
- ) -> Union [TensorOrTupleOfTensors , Tuple [TensorOrTupleOfTensors , Tensor ]]:
101
+ * ,
102
+ return_convergence_delta : Literal [True ],
103
+ custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
104
+ ) -> Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]:
105
105
...
106
106
107
- def attribute (
107
+ def attribute ( # type: ignore
108
108
self ,
109
- inputs ,
110
- baselines = None ,
111
- target = None ,
112
- additional_forward_args = None ,
113
- return_convergence_delta = False ,
114
- custom_attribution_func = None ,
115
- ):
109
+ inputs : TensorOrTupleOfTensorsGeneric ,
110
+ baselines : BaselineType = None ,
111
+ target : TargetType = None ,
112
+ additional_forward_args : Any = None ,
113
+ return_convergence_delta : bool = False ,
114
+ custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
115
+ ) -> Union [
116
+ TensorOrTupleOfTensorsGeneric , Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]
117
+ ]:
116
118
r""""
117
119
Implements DeepLIFT algorithm based on the following paper:
118
120
Learning Important Features Through Propagating Activation Differences,
@@ -278,7 +280,7 @@ def attribute(
278
280
279
281
# Keeps track whether original input is a tuple or not before
280
282
# converting it into a tuple.
281
- is_inputs_tuple = isinstance (inputs , tuple )
283
+ is_inputs_tuple = _is_tuple (inputs )
282
284
283
285
inputs = _format_input (inputs )
284
286
baselines = _format_baseline (baselines , inputs )
@@ -341,10 +343,8 @@ def attribute(
341
343
def _construct_forward_func (
342
344
self ,
343
345
forward_func : Callable ,
344
- inputs : TensorOrTupleOfTensors ,
345
- target : Optional [
346
- Union [int , Tuple [int , ...], Tensor , List [Tuple [int , ...]]]
347
- ] = None ,
346
+ inputs : Tuple ,
347
+ target : TargetType = None ,
348
348
additional_forward_args : Any = None ,
349
349
) -> Callable :
350
350
def forward_fn ():
@@ -533,39 +533,45 @@ def __init__(self, model: Module) -> None:
533
533
@typing .overload # type: ignore
534
534
def attribute (
535
535
self ,
536
- inputs : TensorOrTupleOfTensors ,
537
- baselines : Union [TensorOrTupleOfTensors , Callable [..., TensorOrTupleOfTensors ]],
538
- target : Optional [
539
- Union [ int , Tuple [ int , ...], Tensor , List [ Tuple [ int , ...]]]
540
- ] = None ,
536
+ inputs : TensorOrTupleOfTensorsGeneric ,
537
+ baselines : Union [
538
+ TensorOrTupleOfTensorsGeneric , Callable [..., TensorOrTupleOfTensorsGeneric ]
539
+ ],
540
+ target : TargetType = None ,
541
541
additional_forward_args : Any = None ,
542
- custom_attribution_func : Callable [..., Tuple [Tensor , ...]] = None ,
543
- ) -> TensorOrTupleOfTensors :
542
+ return_convergence_delta : Literal [False ] = False ,
543
+ custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
544
+ ) -> TensorOrTupleOfTensorsGeneric :
544
545
...
545
546
546
547
@typing .overload
547
548
def attribute (
548
549
self ,
549
- inputs : TensorOrTupleOfTensors ,
550
- baselines : Union [TensorOrTupleOfTensors , Callable [..., TensorOrTupleOfTensors ]],
551
- target : Optional [
552
- Union [ int , Tuple [ int , ...], Tensor , List [ Tuple [ int , ...]]]
553
- ] = None ,
550
+ inputs : TensorOrTupleOfTensorsGeneric ,
551
+ baselines : Union [
552
+ TensorOrTupleOfTensorsGeneric , Callable [..., TensorOrTupleOfTensorsGeneric ]
553
+ ],
554
+ target : TargetType = None ,
554
555
additional_forward_args : Any = None ,
555
- return_convergence_delta : bool = False ,
556
- custom_attribution_func : Callable [..., Tuple [Tensor , ...]] = None ,
557
- ) -> Union [TensorOrTupleOfTensors , Tuple [TensorOrTupleOfTensors , Tensor ]]:
556
+ * ,
557
+ return_convergence_delta : Literal [True ],
558
+ custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
559
+ ) -> Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]:
558
560
...
559
561
560
- def attribute (
562
+ def attribute ( # type: ignore
561
563
self ,
562
- inputs ,
563
- baselines ,
564
- target = None ,
565
- additional_forward_args = None ,
566
- return_convergence_delta = False ,
567
- custom_attribution_func = None ,
568
- ):
564
+ inputs : TensorOrTupleOfTensorsGeneric ,
565
+ baselines : Union [
566
+ TensorOrTupleOfTensorsGeneric , Callable [..., TensorOrTupleOfTensorsGeneric ]
567
+ ],
568
+ target : TargetType = None ,
569
+ additional_forward_args : Any = None ,
570
+ return_convergence_delta : bool = False ,
571
+ custom_attribution_func : Union [None , Callable [..., Tuple [Tensor , ...]]] = None ,
572
+ ) -> Union [
573
+ TensorOrTupleOfTensorsGeneric , Tuple [TensorOrTupleOfTensorsGeneric , Tensor ]
574
+ ]:
569
575
r"""
570
576
Extends DeepLift algorithm and approximates SHAP values using Deeplift.
571
577
For each input sample it computes DeepLift attribution with respect to
@@ -724,7 +730,7 @@ def attribute(
724
730
725
731
# Keeps track whether original input is a tuple or not before
726
732
# converting it into a tuple.
727
- is_inputs_tuple = isinstance (inputs , tuple )
733
+ is_inputs_tuple = _is_tuple (inputs )
728
734
729
735
inputs = _format_input (inputs )
730
736
@@ -745,14 +751,18 @@ def attribute(
745
751
exp_base ,
746
752
target = exp_tgt ,
747
753
additional_forward_args = exp_addit_args ,
748
- return_convergence_delta = return_convergence_delta ,
754
+ return_convergence_delta = cast (
755
+ Literal [True , False ], return_convergence_delta
756
+ ),
749
757
custom_attribution_func = custom_attribution_func ,
750
758
)
751
759
if return_convergence_delta :
752
- attributions , delta = attributions
760
+ attributions , delta = cast ( Tuple [ Tuple [ Tensor , ...], Tensor ], attributions )
753
761
754
762
attributions = tuple (
755
- self ._compute_mean_across_baselines (inp_bsz , base_bsz , attribution )
763
+ self ._compute_mean_across_baselines (
764
+ inp_bsz , base_bsz , cast (Tensor , attribution )
765
+ )
756
766
for attribution in attributions
757
767
)
758
768
@@ -765,9 +775,11 @@ def _expand_inputs_baselines_targets(
765
775
self ,
766
776
baselines : Tuple [Tensor , ...],
767
777
inputs : Tuple [Tensor , ...],
768
- target : Optional [ Union [ int , Tuple [ int , ...], Tensor , List [ Tuple [ int , ...]]]] ,
778
+ target : TargetType ,
769
779
additional_forward_args : Any ,
770
- ):
780
+ ) -> Tuple [
781
+ Tuple [Tensor , ...], Tuple [Tensor , ...], TargetType , Any ,
782
+ ]:
771
783
inp_bsz = inputs [0 ].shape [0 ]
772
784
base_bsz = baselines [0 ].shape [0 ]
773
785
0 commit comments