 # pyre-strict
 from copy import copy
 
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
 
 import torch
+from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
-from captum.attr._utils.attribution import Attribution
+from captum.attr._utils.attribution import (
+    Attribution,
+    GradientAttribution,
+    PerturbationAttribution,
+)
 from captum.attr._utils.interpretable_input import (
     InterpretableInput,
     TextTemplateInput,
@@ -44,11 +49,12 @@ def __init__(
         self.output_tokens = output_tokens
 
     @property
-    def seq_attr_dict(self) -> Dict[str, Any]:
+    def seq_attr_dict(self) -> Dict[str, float]:
         return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)}
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def plot_token_attr(self, show: bool = False):
+    def plot_token_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output tokens.
@@ -58,7 +64,11 @@ def plot_token_attr(self, show: bool = False):
                 Default: False
         """
 
-        # pyre-fixme[16]: `Optional` has no attribute `cpu`.
+        if self.token_attr is None:
+            raise ValueError(
+                "token_attr is None (no token-level attribution was performed), please "
+                "use plot_seq_attr instead for the sequence-level attribution plot"
+            )
         token_attr = self.token_attr.cpu()  # type: ignore
 
         # maximum absolute attribution value
@@ -83,7 +93,7 @@ def plot_token_attr(self, show: bool = False):
         )
 
         # Create colorbar
-        cbar = ax.figure.colorbar(im, ax=ax)  # type: ignore
+        cbar = fig.colorbar(im, ax=ax)  # type: ignore
         cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom")
 
         # Show all ticks and label them with the respective list entries.
@@ -113,11 +123,13 @@ def plot_token_attr(self, show: bool = False):
 
         if show:
             plt.show()
+            return None  # mypy wants this
         else:
             return fig, ax
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def plot_seq_attr(self, show: bool = False):
+    def plot_seq_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output sequence.
@@ -150,6 +162,7 @@ def plot_seq_attr(self, show: bool = False):
 
         if show:
             plt.show()
+            return None  # mypy wants this
         else:
             return fig, ax
 
@@ -181,9 +194,8 @@ class LLMAttribution(Attribution):
 
     def __init__(
         self,
-        attr_method: Attribution,
-        # pyre-fixme[2]: Parameter must be annotated.
-        tokenizer,
+        attr_method: PerturbationAttribution,
+        tokenizer: TokenizerLike,
         attr_target: str = "log_prob",  # TODO: support callable attr_target
     ) -> None:
         """
@@ -208,24 +220,19 @@ class created with the llm model that follows huggingface style
         super().__init__(attr_method.forward_func)
 
         # shallow copy is enough to avoid modifying original instance
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.attr_method = copy(attr_method)
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.include_per_token_attr = isinstance(
+        self.attr_method: PerturbationAttribution = copy(attr_method)
+        self.include_per_token_attr: bool = isinstance(
             attr_method, self.SUPPORTED_PER_TOKEN_ATTR_METHODS
         )
 
         self.attr_method.forward_func = self._forward_func
 
         # alias, we really need a model and don't support wrapper functions
         # coz we need call model.forward, model.generate, etc.
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.model = cast(nn.Module, self.forward_func)
+        self.model: nn.Module = cast(nn.Module, self.forward_func)
 
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.tokenizer = tokenizer
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.device = (
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
             else next(self.model.parameters()).device
@@ -239,15 +246,12 @@ class created with the llm model that follows huggingface style
 
     def _forward_func(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        perturbed_tensor,
-        # pyre-fixme[2]: Parameter must be annotated.
-        inp,
-        # pyre-fixme[2]: Parameter must be annotated.
-        target_tokens,
+        perturbed_tensor: Union[None, Tensor],
+        inp: InterpretableInput,
+        target_tokens: Tensor,
         use_cached_outputs: bool = False,
-        _inspect_forward=None,
-    ) -> Union[int, Tensor]:
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+    ) -> Tensor:
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input
 
@@ -279,7 +283,9 @@ def _forward_func(
                 (model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
             )
 
-        total_log_prob = sum(log_prob_list)
+        # pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
+        # Tensor
+        total_log_prob: Tensor = sum(log_prob_list)  # type: ignore
         # 1st element is the total prob, rest are the target tokens
         # add a leading dim for batch even we only support single instance for now
         if self.include_per_token_attr:
@@ -288,8 +294,6 @@ def _forward_func(
             ).unsqueeze(0)
         else:
             target_log_probs = total_log_prob  # type: ignore
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int,
-        #  Tensor]`.
         target_probs = torch.exp(target_log_probs)
 
         if _inspect_forward:
@@ -301,35 +305,31 @@ def _forward_func(
 
         return target_probs if self.attr_target != "log_prob" else target_log_probs
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def _format_model_input(self, model_input: Union[str, Tensor]):
+    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
         """
         Convert str to tokenized tensor
         to make LLMAttribution work with model inputs of both
         raw text and text token tensors
         """
         # return tensor(1, n_tokens)
         if isinstance(model_input, str):
-            return self.tokenizer.encode(model_input, return_tensors="pt").to(
-                self.device
-            )
+            # pyre-ignore[9] pyre/mypy thinks return type may be List, but it will be
+            # Tensor
+            return self.tokenizer.encode(  # type: ignore
+                model_input, return_tensors="pt"
+            ).to(self.device)
         return model_input.to(self.device)
 
     def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
-        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
-        #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
-        #  errors.
-        gen_args: Optional[Dict] = None,
+        gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
         # internal callback hook can be used for logging
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        _inspect_forward: Optional[Callable] = None,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+        **kwargs: Any,
     ) -> LLMAttributionResult:
         """
         Args:
@@ -380,10 +380,14 @@ def attribute(
             target_tokens = torch.tensor(target_tokens)
         elif type(target) is torch.Tensor:
             target_tokens = target
+        else:
+            raise TypeError(
+                "target must either be str or Tensor, but the type of target is "
+                "{}".format(type(target))
+            )
 
         attr = torch.zeros(
             [
-                # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
                 1 + len(target_tokens) if self.include_per_token_attr else 1,
                 inp.n_itp_features,
             ],
@@ -398,8 +402,6 @@ def attribute(
                 attr_input,
                 additional_forward_args=(
                     inp,
-                    # pyre-fixme[61]: `target_tokens` is undefined, or not always
-                    #  defined.
                     target_tokens,
                     use_cached_outputs,
                     _inspect_forward,
@@ -424,7 +426,6 @@ def attribute(
                 attr[1:] if self.include_per_token_attr else None
             ),  # shape(n_output_token, n_input_features)
             inp.values,
-            # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
             self.tokenizer.convert_ids_to_tokens(target_tokens),
         )
 
@@ -454,14 +455,11 @@ class LLMGradientAttribution(Attribution):
     SUPPORTED_METHODS = (LayerIntegratedGradients,)
     SUPPORTED_INPUTS = (TextTokenInput,)
 
-    # pyre-fixme[3]: Return type must be annotated.
     def __init__(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        attr_method,
-        # pyre-fixme[2]: Parameter must be annotated.
-        tokenizer,
-    ):
+        attr_method: GradientAttribution,
+        tokenizer: TokenizerLike,
+    ) -> None:
         """
         Args:
             attr_method (Attribution): instance of a supported perturbation attribution
@@ -476,19 +474,15 @@ class created with the llm model that follows huggingface style
         super().__init__(attr_method.forward_func)
 
         # shallow copy is enough to avoid modifying original instance
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.attr_method = copy(attr_method)
+        self.attr_method: GradientAttribution = copy(attr_method)
         self.attr_method.forward_func = self._forward_func
 
         # alias, we really need a model and don't support wrapper functions
         # coz we need call model.forward, model.generate, etc.
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.model = cast(nn.Module, self.forward_func)
+        self.model: nn.Module = cast(nn.Module, self.forward_func)
 
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.tokenizer = tokenizer
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.device = (
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
             else next(self.model.parameters()).device
@@ -526,9 +520,7 @@ def _forward_func(
         # the attribution target is limited to the log probability
         return token_log_probs
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _format_model_input(self, model_input):
+    def _format_model_input(self, model_input: Tensor) -> Tensor:
         """
         Convert str to tokenized tensor
         """
@@ -538,12 +530,8 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
-        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
-        #  `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
-        #  errors.
-        gen_args: Optional[Dict] = None,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        gen_args: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> LLMAttributionResult:
         """
         Args:
@@ -590,19 +578,21 @@ def attribute(
             target_tokens = torch.tensor(target_tokens)
         elif type(target) is torch.Tensor:
             target_tokens = target
+        else:
+            raise TypeError(
+                "target must either be str or Tensor, but the type of target is "
+                "{}".format(type(target))
+            )
 
         attr_inp = inp.to_tensor().to(self.device)
 
         attr_list = []
-        # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
         for cur_target_idx, _ in enumerate(target_tokens):
             # attr in shape(batch_size, input+output_len, emb_dim)
             attr = self.attr_method.attribute(
                 attr_inp,
                 additional_forward_args=(
                     inp,
-                    # pyre-fixme[61]: `target_tokens` is undefined, or not always
-                    #  defined.
                     target_tokens,
                     cur_target_idx,
                 ),
@@ -629,7 +619,7 @@ def attribute(
         # it attributes to all the elements of the output of the specified layer
         # so we need special handling for the inp type which don't care all the elements
         if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
-            itp_mask = inp.itp_mask.to(self.device)
+            itp_mask = inp.itp_mask.to(attr.device)
             itp_mask = itp_mask.expand_as(attr)
             attr = attr[itp_mask].view(attr.size(0), -1)
 
@@ -642,7 +632,6 @@ def attribute(
             seq_attr,
             attr,  # shape(n_output_token, n_input_features)
             inp.values,
-            # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
             self.tokenizer.convert_ids_to_tokens(target_tokens),
         )
 
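For orientation, below is a minimal usage sketch of how the API reads after these annotation changes. It is not part of the commit: the model checkpoint name, prompt template, values, and target string are placeholder assumptions, following the pattern of Captum's LLM attribution tutorial.

from captum.attr import FeatureAblation, LLMAttribution, TextTemplateInput
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any HuggingFace-style causal LM and its tokenizer should do.
model = AutoModelForCausalLM.from_pretrained("your-org/your-causal-lm")
tokenizer = AutoTokenizer.from_pretrained("your-org/your-causal-lm")

# FeatureAblation is a PerturbationAttribution, so it satisfies the new
# `attr_method: PerturbationAttribution` annotation on LLMAttribution.__init__.
llm_attr = LLMAttribution(FeatureAblation(model), tokenizer)

inp = TextTemplateInput(
    template="{} lives in {} and works as a {}.",
    values=["Dave", "Palm Coast", "lawyer"],
)

# `target` must be a str or a Tensor; any other type now raises the new TypeError.
res = llm_attr.attribute(inp, target="He spends his weekends fishing.")

print(res.seq_attr_dict)         # now typed as Dict[str, float]
fig, ax = res.plot_token_attr()  # returns (Figure, Axes) when show=False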