40
40
TemperatureLogitsWarper ,
41
41
TopKLogitsWarper ,
42
42
TopPLogitsWarper ,
43
+ TypicalLogitsWarper ,
43
44
)
44
45
from .generation_stopping_criteria import (
45
46
MaxLengthCriteria ,
@@ -620,7 +621,12 @@ def _reorder_cache(self, past, beam_idx):
620
621
)
621
622
622
623
def _get_logits_warper (
623
- self , top_k : int = None , top_p : float = None , temperature : float = None , num_beams : int = None
624
+ self ,
625
+ top_k : int = None ,
626
+ top_p : float = None ,
627
+ typical_p : float = None ,
628
+ temperature : float = None ,
629
+ num_beams : int = None ,
624
630
) -> LogitsProcessorList :
625
631
"""
626
632
This method returns a [`LogitsProcessorList`] object that contains all relevant [`LogitsWarper`] instances
@@ -630,6 +636,7 @@ def _get_logits_warper(
630
636
# init warp parameters
631
637
top_k = top_k if top_k is not None else self .config .top_k
632
638
top_p = top_p if top_p is not None else self .config .top_p
639
+ typical_p = typical_p if typical_p is not None else self .config .typical_p
633
640
temperature = temperature if temperature is not None else self .config .temperature
634
641
# instantiate warpers list
635
642
warpers = LogitsProcessorList ()
@@ -642,6 +649,8 @@ def _get_logits_warper(
642
649
warpers .append (TopKLogitsWarper (top_k = top_k , min_tokens_to_keep = (2 if num_beams > 1 else 1 )))
643
650
if top_p is not None and top_p < 1.0 :
644
651
warpers .append (TopPLogitsWarper (top_p = top_p , min_tokens_to_keep = (2 if num_beams > 1 else 1 )))
652
+ if typical_p is not None and typical_p < 1.0 :
653
+ warpers .append (TypicalLogitsWarper (mass = typical_p , min_tokens_to_keep = (2 if num_beams > 1 else 1 )))
645
654
return warpers
646
655
647
656
def _get_logits_processor (
@@ -811,6 +820,7 @@ def generate(
811
820
temperature : Optional [float ] = None ,
812
821
top_k : Optional [int ] = None ,
813
822
top_p : Optional [float ] = None ,
823
+ typical_p : Optional [float ] = None ,
814
824
repetition_penalty : Optional [float ] = None ,
815
825
bad_words_ids : Optional [Iterable [int ]] = None ,
816
826
bos_token_id : Optional [int ] = None ,
@@ -1191,7 +1201,7 @@ def generate(
1191
1201
elif is_sample_gen_mode :
1192
1202
# 10. prepare logits warper
1193
1203
logits_warper = self ._get_logits_warper (
1194
- top_k = top_k , top_p = top_p , temperature = temperature , num_beams = num_beams
1204
+ top_k = top_k , top_p = top_p , typical_p = typical_p , temperature = temperature , num_beams = num_beams
1195
1205
)
1196
1206
1197
1207
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
@@ -1253,7 +1263,7 @@ def generate(
1253
1263
elif is_beam_sample_gen_mode :
1254
1264
# 10. prepare logits warper
1255
1265
logits_warper = self ._get_logits_warper (
1256
- top_k = top_k , top_p = top_p , temperature = temperature , num_beams = num_beams
1266
+ top_k = top_k , top_p = top_p , typical_p = typical_p , temperature = temperature , num_beams = num_beams
1257
1267
)
1258
1268
1259
1269
if stopping_criteria .max_length is None :
0 commit comments