
Commit e01c07b

craymichael authored and facebook-github-bot committed
Gradient-based LLM attribution in tutorial + LLM attribution type annotations (#1333)
Summary:
Pull Request resolved: #1333

Add gradient-based LLM attribution to the tutorial notebook. Addresses #1237. Additionally, add more type annotations to llm_attr.py.

Reviewed By: vivekmig

Differential Revision: D61461521

fbshipit-source-id: 8ae68773cfd32506e797941698a9a08316676279
1 parent 0665ea5 commit e01c07b

File tree

5 files changed: +240 -126 lines


captum/_utils/typing.py

+12-1
@@ -2,7 +2,7 @@
 
 # pyre-strict
 
-from typing import List, Tuple, TYPE_CHECKING, TypeVar, Union
+from typing import List, Optional, Protocol, Tuple, TYPE_CHECKING, TypeVar, Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -33,3 +33,14 @@
     TensorLikeList4D,
     TensorLikeList5D,
 ]
+
+
+class TokenizerLike(Protocol):
+    """A protocol for tokenizer-like objects that can be used with Captum
+    LLM attribution methods."""
+
+    def encode(
+        self, text: str, return_tensors: Optional[str] = None
+    ) -> Union[List[int], Tensor]: ...
+    def decode(self, token_ids: Tensor) -> str: ...
+    def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]: ...
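Note: the snippet below is a minimal sketch (not part of this commit) of a tokenizer that structurally satisfies the new TokenizerLike protocol. Because TokenizerLike is a typing.Protocol, no inheritance is required; Hugging Face tokenizers already expose these three methods. The WhitespaceTokenizer class and its behavior are illustrative assumptions.

from typing import Dict, List, Optional, Union

import torch
from torch import Tensor


class WhitespaceTokenizer:
    """Hypothetical toy tokenizer used only to illustrate the protocol."""

    def __init__(self) -> None:
        self.vocab: Dict[str, int] = {}      # token -> id
        self.inv_vocab: Dict[int, str] = {}  # id -> token

    def encode(
        self, text: str, return_tensors: Optional[str] = None
    ) -> Union[List[int], Tensor]:
        ids = []
        for tok in text.split():
            if tok not in self.vocab:
                # assign the next free id to an unseen token
                self.inv_vocab[len(self.vocab)] = tok
                self.vocab[tok] = len(self.vocab)
            ids.append(self.vocab[tok])
        # Mirror the Hugging Face convention: a (1, n_tokens) tensor for
        # return_tensors="pt", otherwise a plain list of ids.
        return torch.tensor([ids]) if return_tensors == "pt" else ids

    def decode(self, token_ids: Tensor) -> str:
        return " ".join(self.inv_vocab[int(i)] for i in token_ids.flatten())

    def convert_ids_to_tokens(self, token_ids: Tensor) -> List[str]:
        return [self.inv_vocab[int(i)] for i in token_ids.flatten()]

Since Protocol uses structural subtyping, a type checker accepts both this toy object and a transformers tokenizer wherever TokenizerLike is expected.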

captum/attr/_core/llm_attr.py

+67-78
@@ -1,18 +1,23 @@
 # pyre-strict
 from copy import copy
 
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
 
 import matplotlib.pyplot as plt
 import numpy as np
 
 import torch
+from captum._utils.typing import TokenizerLike
 from captum.attr._core.feature_ablation import FeatureAblation
 from captum.attr._core.kernel_shap import KernelShap
 from captum.attr._core.layer.layer_integrated_gradients import LayerIntegratedGradients
 from captum.attr._core.lime import Lime
 from captum.attr._core.shapley_value import ShapleyValues, ShapleyValueSampling
-from captum.attr._utils.attribution import Attribution
+from captum.attr._utils.attribution import (
+    Attribution,
+    GradientAttribution,
+    PerturbationAttribution,
+)
 from captum.attr._utils.interpretable_input import (
     InterpretableInput,
     TextTemplateInput,
@@ -44,11 +49,12 @@ def __init__(
         self.output_tokens = output_tokens
 
     @property
-    def seq_attr_dict(self) -> Dict[str, Any]:
+    def seq_attr_dict(self) -> Dict[str, float]:
         return {k: v for v, k in zip(self.seq_attr.cpu().tolist(), self.input_tokens)}
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def plot_token_attr(self, show: bool = False):
+    def plot_token_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output tokens.
@@ -58,7 +64,11 @@ def plot_token_attr(self, show: bool = False):
                 Default: False
         """
 
-        # pyre-fixme[16]: `Optional` has no attribute `cpu`.
+        if self.token_attr is None:
+            raise ValueError(
+                "token_attr is None (no token-level attribution was performed), please "
+                "use plot_seq_attr instead for the sequence-level attribution plot"
+            )
         token_attr = self.token_attr.cpu()  # type: ignore
 
         # maximum absolute attribution value
@@ -83,7 +93,7 @@
         )
 
         # Create colorbar
-        cbar = ax.figure.colorbar(im, ax=ax)  # type: ignore
+        cbar = fig.colorbar(im, ax=ax)  # type: ignore
         cbar.ax.set_ylabel("Token Attribuiton", rotation=-90, va="bottom")
 
         # Show all ticks and label them with the respective list entries.
@@ -113,11 +123,13 @@ def plot_token_attr(self, show: bool = False):
 
         if show:
             plt.show()
+            return None  # mypy wants this
         else:
             return fig, ax
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def plot_seq_attr(self, show: bool = False):
+    def plot_seq_attr(
+        self, show: bool = False
+    ) -> Union[None, Tuple[plt.Figure, plt.Axes]]:
         """
         Generate a matplotlib plot for visualising the attribution
         of the output sequence.
@@ -150,6 +162,7 @@ def plot_seq_attr(self, show: bool = False):
 
         if show:
             plt.show()
+            return None  # mypy wants this
         else:
             return fig, ax
 
@@ -181,9 +194,8 @@ class LLMAttribution(Attribution):
 
     def __init__(
         self,
-        attr_method: Attribution,
-        # pyre-fixme[2]: Parameter must be annotated.
-        tokenizer,
+        attr_method: PerturbationAttribution,
+        tokenizer: TokenizerLike,
         attr_target: str = "log_prob",  # TODO: support callable attr_target
     ) -> None:
         """
@@ -208,24 +220,19 @@ class created with the llm model that follows huggingface style
        super().__init__(attr_method.forward_func)
 
         # shallow copy is enough to avoid modifying original instance
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.attr_method = copy(attr_method)
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.include_per_token_attr = isinstance(
+        self.attr_method: PerturbationAttribution = copy(attr_method)
+        self.include_per_token_attr: bool = isinstance(
            attr_method, self.SUPPORTED_PER_TOKEN_ATTR_METHODS
         )
 
         self.attr_method.forward_func = self._forward_func
 
         # alias, we really need a model and don't support wrapper functions
         # coz we need call model.forward, model.generate, etc.
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.model = cast(nn.Module, self.forward_func)
+        self.model: nn.Module = cast(nn.Module, self.forward_func)
 
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.tokenizer = tokenizer
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.device = (
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
             else next(self.model.parameters()).device
@@ -239,15 +246,12 @@ class created with the llm model that follows huggingface style
 
     def _forward_func(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        perturbed_tensor,
-        # pyre-fixme[2]: Parameter must be annotated.
-        inp,
-        # pyre-fixme[2]: Parameter must be annotated.
-        target_tokens,
+        perturbed_tensor: Union[None, Tensor],
+        inp: InterpretableInput,
+        target_tokens: Tensor,
         use_cached_outputs: bool = False,
-        _inspect_forward=None,
-    ) -> Union[int, Tensor]:
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+    ) -> Tensor:
         perturbed_input = self._format_model_input(inp.to_model_input(perturbed_tensor))
         init_model_inp = perturbed_input
 
@@ -279,7 +283,9 @@ def _forward_func(
                 (model_inp, torch.tensor([[target_token]]).to(self.device)), dim=1
             )
 
-        total_log_prob = sum(log_prob_list)
+        # pyre-ignore[9] pyre/mypy thinks sum returns int here, but it will return
+        # Tensor
+        total_log_prob: Tensor = sum(log_prob_list)  # type: ignore
         # 1st element is the total prob, rest are the target tokens
         # add a leading dim for batch even we only support single instance for now
         if self.include_per_token_attr:
@@ -288,8 +294,6 @@ def _forward_func(
             ).unsqueeze(0)
         else:
             target_log_probs = total_log_prob  # type: ignore
-        # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[int,
-        # Tensor]`.
         target_probs = torch.exp(target_log_probs)
 
         if _inspect_forward:
@@ -301,35 +305,31 @@ def _forward_func(
 
         return target_probs if self.attr_target != "log_prob" else target_log_probs
 
-    # pyre-fixme[3]: Return type must be annotated.
-    def _format_model_input(self, model_input: Union[str, Tensor]):
+    def _format_model_input(self, model_input: Union[str, Tensor]) -> Tensor:
         """
         Convert str to tokenized tensor
         to make LLMAttribution work with model inputs of both
         raw text and text token tensors
         """
         # return tensor(1, n_tokens)
         if isinstance(model_input, str):
-            return self.tokenizer.encode(model_input, return_tensors="pt").to(
-                self.device
-            )
+            # pyre-ignore[9] pyre/mypy thinks return type may be List, but it will be
+            # Tensor
+            return self.tokenizer.encode(  # type: ignore
+                model_input, return_tensors="pt"
+            ).to(self.device)
         return model_input.to(self.device)
 
     def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
         num_trials: int = 1,
-        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
-        # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
-        # errors.
-        gen_args: Optional[Dict] = None,
+        gen_args: Optional[Dict[str, Any]] = None,
         use_cached_outputs: bool = True,
         # internal callback hook can be used for logging
-        # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
-        _inspect_forward: Optional[Callable] = None,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        _inspect_forward: Optional[Callable[[str, str, List[float]], None]] = None,
+        **kwargs: Any,
     ) -> LLMAttributionResult:
         """
         Args:
@@ -380,10 +380,14 @@ def attribute(
             target_tokens = torch.tensor(target_tokens)
         elif type(target) is torch.Tensor:
             target_tokens = target
+        else:
+            raise TypeError(
+                "target must either be str or Tensor, but the type of target is "
+                "{}".format(type(target))
+            )
 
         attr = torch.zeros(
             [
-                # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
                 1 + len(target_tokens) if self.include_per_token_attr else 1,
                 inp.n_itp_features,
             ],
@@ -398,8 +402,6 @@ def attribute(
                 attr_input,
                 additional_forward_args=(
                     inp,
-                    # pyre-fixme[61]: `target_tokens` is undefined, or not always
-                    # defined.
                     target_tokens,
                     use_cached_outputs,
                     _inspect_forward,
@@ -424,7 +426,6 @@ def attribute(
                attr[1:] if self.include_per_token_attr else None
            ),  # shape(n_output_token, n_input_features)
            inp.values,
-           # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
            self.tokenizer.convert_ids_to_tokens(target_tokens),
        )
 
@@ -454,14 +455,11 @@ class LLMGradientAttribution(Attribution):
     SUPPORTED_METHODS = (LayerIntegratedGradients,)
     SUPPORTED_INPUTS = (TextTokenInput,)
 
-    # pyre-fixme[3]: Return type must be annotated.
     def __init__(
         self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        attr_method,
-        # pyre-fixme[2]: Parameter must be annotated.
-        tokenizer,
-    ):
+        attr_method: GradientAttribution,
+        tokenizer: TokenizerLike,
+    ) -> None:
         """
         Args:
             attr_method (Attribution): instance of a supported perturbation attribution
@@ -476,19 +474,15 @@ class created with the llm model that follows huggingface style
         super().__init__(attr_method.forward_func)
 
         # shallow copy is enough to avoid modifying original instance
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.attr_method = copy(attr_method)
+        self.attr_method: GradientAttribution = copy(attr_method)
         self.attr_method.forward_func = self._forward_func
 
         # alias, we really need a model and don't support wrapper functions
         # coz we need call model.forward, model.generate, etc.
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.model = cast(nn.Module, self.forward_func)
+        self.model: nn.Module = cast(nn.Module, self.forward_func)
 
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.tokenizer = tokenizer
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.device = (
+        self.tokenizer: TokenizerLike = tokenizer
+        self.device: torch.device = (
             cast(torch.device, self.model.device)
             if hasattr(self.model, "device")
             else next(self.model.parameters()).device
@@ -526,9 +520,7 @@ def _forward_func(
         # the attribution target is limited to the log probability
         return token_log_probs
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def _format_model_input(self, model_input):
+    def _format_model_input(self, model_input: Tensor) -> Tensor:
         """
         Convert str to tokenized tensor
         """
@@ -538,12 +530,8 @@ def attribute(
         self,
         inp: InterpretableInput,
         target: Union[str, torch.Tensor, None] = None,
-        # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use
-        # `typing.Dict[<key type>, <value type>]` to avoid runtime subscripting
-        # errors.
-        gen_args: Optional[Dict] = None,
-        # pyre-fixme[2]: Parameter must be annotated.
-        **kwargs,
+        gen_args: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
     ) -> LLMAttributionResult:
         """
         Args:
@@ -590,19 +578,21 @@ def attribute(
             target_tokens = torch.tensor(target_tokens)
         elif type(target) is torch.Tensor:
             target_tokens = target
+        else:
+            raise TypeError(
+                "target must either be str or Tensor, but the type of target is "
+                "{}".format(type(target))
+            )
 
         attr_inp = inp.to_tensor().to(self.device)
 
         attr_list = []
-        # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
         for cur_target_idx, _ in enumerate(target_tokens):
             # attr in shape(batch_size, input+output_len, emb_dim)
             attr = self.attr_method.attribute(
                 attr_inp,
                 additional_forward_args=(
                     inp,
-                    # pyre-fixme[61]: `target_tokens` is undefined, or not always
-                    # defined.
                     target_tokens,
                     cur_target_idx,
                 ),
@@ -629,7 +619,7 @@ def attribute(
         # it attributes to all the elements of the output of the specified layer
         # so we need special handling for the inp type which don't care all the elements
         if isinstance(inp, TextTokenInput) and inp.itp_mask is not None:
-            itp_mask = inp.itp_mask.to(self.device)
+            itp_mask = inp.itp_mask.to(attr.device)
             itp_mask = itp_mask.expand_as(attr)
             attr = attr[itp_mask].view(attr.size(0), -1)
 
@@ -642,7 +632,6 @@ def attribute(
             seq_attr,
             attr,  # shape(n_output_token, n_input_features)
             inp.values,
-            # pyre-fixme[61]: `target_tokens` is undefined, or not always defined.
             self.tokenizer.convert_ids_to_tokens(target_tokens),
         )
 
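Note: to connect the new annotations to how these classes are used, here is a rough usage sketch in the spirit of the tutorial addition this commit describes. It is not taken from the diff: the model name ("gpt2"), the embedding-layer path (model.transformer.wte), and the prompt/target strings are placeholder assumptions, and the embedding-layer attribute differs per architecture (e.g. model.model.embed_tokens for Llama).

from captum.attr import (
    FeatureAblation,
    LayerIntegratedGradients,
    LLMAttribution,
    LLMGradientAttribution,
    TextTemplateInput,
    TextTokenInput,
)
from transformers import AutoModelForCausalLM, AutoTokenizer  # assumed available

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Perturbation-based attribution: attr_method is a PerturbationAttribution,
# and tokenizer only needs to satisfy the TokenizerLike protocol.
fa = FeatureAblation(model)
llm_attr = LLMAttribution(fa, tokenizer)
inp = TextTemplateInput(
    template="{} lives in {}, {} and is a {}. Personal interests include",
    values=["Dave", "Palm Coast", "FL", "lawyer"],
)
res = llm_attr.attribute(inp, target="playing golf, hiking, and cooking.")
print(res.seq_attr_dict)        # Dict[str, float] per the new annotation
res.plot_token_attr(show=True)  # returns None when show=True

# Gradient-based attribution (the tutorial addition): LayerIntegratedGradients
# over the token embedding layer, wrapped in LLMGradientAttribution.
lig = LayerIntegratedGradients(model, model.transformer.wte)
llm_attr_grad = LLMGradientAttribution(lig, tokenizer)
inp_tok = TextTokenInput(
    "Dave lives in Palm Coast, FL and is a lawyer.", tokenizer
)
res_grad = llm_attr_grad.attribute(
    inp_tok, target="playing golf, hiking, and cooking."
)
res_grad.plot_seq_attr(show=True)

With the annotated signatures, a type checker can follow both the FeatureAblation and the LayerIntegratedGradients paths end to end, which is the practical payoff of removing the pyre-fixme suppressions.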
