Skip to content

Commit 0113aae

Browse files
authored
Add implementation of typical sampling (#15504)
* typical decoding
* changing arg name
* add test config params
* forgotten arg rename
* fix edge case where scores are same
* test for typical logits warper
* code quality fixes
1 parent f588cf4 commit 0113aae

5 files changed

+94
-3
lines changed

src/transformers/configuration_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ def __init__(self, **kwargs):
282282
self.temperature = kwargs.pop("temperature", 1.0)
283283
self.top_k = kwargs.pop("top_k", 50)
284284
self.top_p = kwargs.pop("top_p", 1.0)
285+
self.typical_p = kwargs.pop("typical_p", 1.0)
285286
self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
286287
self.length_penalty = kwargs.pop("length_penalty", 1.0)
287288
self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)

src/transformers/generation_logits_process.py

+33
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,39 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
239239
return scores
240240

241241

242+
class TypicalLogitsWarper(LogitsWarper):
    r"""
    [`LogitsWarper`] implementing typical decoding (typical sampling, arXiv:2202.00666): keeps the
    smallest set of tokens whose surprisal (negative log-probability) is closest to the entropy of
    the predicted distribution, such that the set's cumulative probability mass is at least `mass`.

    Args:
        mass (`float`, *optional*, defaults to 0.9):
            Cumulative probability mass of the "typical" set to keep. Must satisfy 0 < mass < 1.
        filter_value (`float`, *optional*, defaults to `-inf`):
            Value assigned to the logits of filtered-out tokens.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that can never be filtered out.
    """

    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        mass = float(mass)
        if not 0 < mass < 1:
            raise ValueError(f"`mass` has to be a float > 0 and < 1, but is {mass}")

        self.filter_value = filter_value
        self.mass = mass
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Entropy of the next-token distribution, per batch row.
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # Tokens whose surprisal (-log p) is closest to the entropy are the most "typical";
        # sort ascending by that distance so typical tokens come first.
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Index of the last token to keep. Clamp to the last valid index: if every cumulative
        # probability is below `mass` (possible through float rounding), the boolean sum equals
        # the vocab size and the `gather` below would otherwise index out of range. (A lower
        # bound is unnecessary — a sum of booleans is never negative.)
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        last_ind.clamp_(max=sorted_scores.shape[-1] - 1)
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        if self.min_tokens_to_keep > 1:
            # Unconditionally keep the `min_tokens_to_keep` most typical tokens.
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        # Map the removal mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
242275
def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
243276
generated_ngrams = [{} for _ in range(num_hypos)]
244277
for idx in range(num_hypos):

src/transformers/generation_utils.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
TemperatureLogitsWarper,
4141
TopKLogitsWarper,
4242
TopPLogitsWarper,
43+
TypicalLogitsWarper,
4344
)
4445
from .generation_stopping_criteria import (
4546
MaxLengthCriteria,
@@ -620,7 +621,12 @@ def _reorder_cache(self, past, beam_idx):
620621
)
621622

622623
def _get_logits_warper(
623-
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
624+
self,
625+
top_k: int = None,
626+
top_p: float = None,
627+
typical_p: float = None,
628+
temperature: float = None,
629+
num_beams: int = None,
624630
) -> LogitsProcessorList:
625631
"""
626632
This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
@@ -630,6 +636,7 @@ def _get_logits_warper(
630636
# init warp parameters
631637
top_k = top_k if top_k is not None else self.config.top_k
632638
top_p = top_p if top_p is not None else self.config.top_p
639+
typical_p = typical_p if typical_p is not None else self.config.typical_p
633640
temperature = temperature if temperature is not None else self.config.temperature
634641
# instantiate warpers list
635642
warpers = LogitsProcessorList()
@@ -642,6 +649,8 @@ def _get_logits_warper(
642649
warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
643650
if top_p is not None and top_p < 1.0:
644651
warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
652+
if typical_p is not None and typical_p < 1.0:
653+
warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))
645654
return warpers
646655

647656
def _get_logits_processor(
@@ -811,6 +820,7 @@ def generate(
811820
temperature: Optional[float] = None,
812821
top_k: Optional[int] = None,
813822
top_p: Optional[float] = None,
823+
typical_p: Optional[float] = None,
814824
repetition_penalty: Optional[float] = None,
815825
bad_words_ids: Optional[Iterable[int]] = None,
816826
bos_token_id: Optional[int] = None,
@@ -1191,7 +1201,7 @@ def generate(
11911201
elif is_sample_gen_mode:
11921202
# 10. prepare logits warper
11931203
logits_warper = self._get_logits_warper(
1194-
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
1204+
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
11951205
)
11961206

11971207
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1253,7 +1263,7 @@ def generate(
12531263
elif is_beam_sample_gen_mode:
12541264
# 10. prepare logits warper
12551265
logits_warper = self._get_logits_warper(
1256-
top_k=top_k, top_p=top_p, temperature=temperature, num_beams=num_beams
1266+
top_k=top_k, top_p=top_p, typical_p=typical_p, temperature=temperature, num_beams=num_beams
12571267
)
12581268

12591269
if stopping_criteria.max_length is None:

tests/test_configuration_common.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
"temperature": 2.0,
5959
"top_k": 10,
6060
"top_p": 0.7,
61+
"typical_p": 0.2,
6162
"repetition_penalty": 0.8,
6263
"length_penalty": 0.8,
6364
"no_repeat_ngram_size": 5,

tests/test_generation_logits_process.py

+46
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
TemperatureLogitsWarper,
4242
TopKLogitsWarper,
4343
TopPLogitsWarper,
44+
TypicalLogitsWarper,
4445
)
4546

4647

@@ -191,6 +192,51 @@ def test_top_p_dist_warper(self):
191192
# first batch should keep three tokens, second batch would keep only 1, but due to `min_tokens_to_keep=2` keeps 2.
192193
self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [3, 2])
193194

195+
def test_typical_dist_warper(self):
    """Checks that `TypicalLogitsWarper` keeps the typical set and honors its safety knobs."""
    input_ids = None
    vocab_size = 10
    batch_size = 2

    # Build log-probabilities (log is the inverse of the softmax applied inside the warper).
    dist = torch.log(
        torch.tensor([[0.97, 0.01, 0.01, 0.01], [0.4, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float)
    )

    typical_warp = TypicalLogitsWarper(0.5)
    filtered_dist = torch.exp(typical_warp(input_ids, dist))

    # With mass=0.5 only the most typical tokens survive; exp(-inf) maps removed entries to 0.
    expected_filtered_dist = torch.tensor(
        [[0.97, 0.0, 0.0, 0.0], [0.0, 0.2, 0.2, 0.2]], device=torch_device, dtype=torch.float
    )
    self.assertTrue(torch.allclose(filtered_dist, expected_filtered_dist, atol=1e-3))

    # Special case: a uniform distribution must pass through unchanged.
    length = 5

    logits = self._get_uniform_logits(batch_size=batch_size, length=length)
    typical_warp_safety_check = TypicalLogitsWarper(mass=0.5, filter_value=0.0, min_tokens_to_keep=3)

    scores = typical_warp_safety_check(input_ids, logits)
    self.assertListEqual((scores == 0.0).to(torch.long).sum(dim=-1).tolist(), [0, 0])

    # Edge cases: ramp logits spanning negative and extreme positive values.
    ramp_logits = torch.arange(vocab_size, device=torch_device, dtype=torch.float).unsqueeze(0).repeat(
        batch_size, 1
    ) - (vocab_size // 2)

    # Second batch row gets far more peaked logits.
    ramp_logits[1] = ramp_logits[1] * 100.0

    typical_warp = TypicalLogitsWarper(0.7, min_tokens_to_keep=2, filter_value=0.0)
    filtered_dist = typical_warp(input_ids, ramp_logits)

    # First row keeps two tokens on its own; second row would keep only 1,
    # but `min_tokens_to_keep=2` forces a second token to survive.
    self.assertListEqual((filtered_dist != 0.0).to(torch.long).sum(dim=-1).tolist(), [2, 2])
194240
def test_no_repeat_ngram_dist_processor(self):
195241
vocab_size = 3
196242
batch_size = 2

0 commit comments

Comments
 (0)