@@ -95,11 +95,14 @@ def get_max_input_length(worker_config: inference.WorkerConfig, plugin_used: bool
     return max_input_length
 
 
-def get_tokens_until(tokens: list[int], target: int | list[int]) -> list[int]:
-    if isinstance(target, int):
-        return tokens[: tokens.index(target)]
-    else:
-        return next((i for i in range(len(tokens) - len(target) + 1) if tokens[i : i + len(target)] == target))
+def get_tokens_until(tokens: list[int], target: list[int]) -> list[int]:
+    if len(target) == 1:
+        return tokens[: tokens.index(target[0])]
+
+    for i in range(len(tokens) - len(target)):
+        if tokens[i : i + len(target)] == target:
+            break
+    return tokens[:i]
 
 
 def truncate_prompt(
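
As a quick illustration, a minimal sketch of how the reworked get_tokens_until behaves (hypothetical token IDs, not from the commit):

def demo_get_tokens_until() -> None:
    # Single-token target: cut at the first occurrence of that ID.
    assert get_tokens_until([5, 8, 13, 21], [13]) == [5, 8]
    # Multi-token target: cut at the start of the first matching window.
    # The loop scans windows starting at 0 .. len(tokens) - len(target) - 1,
    # so a target that only matches at the very end of tokens is not found.
    assert get_tokens_until([5, 8, 13, 21, 34], [13, 21]) == [5, 8]
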
@@ -118,8 +121,8 @@ def truncate_prompt(
     """
     with shared_tokenizer_lock:
         ids = tokenizer.encode(prompt)
-        # prompter_prefix_ids could be int or list of ints
-        prompter_prefix_ids = tokenizer.convert_tokens_to_ids(special_tokens["prompter"])
+        # list of int IDs
+        prompter_prefix_ids = tokenizer.encode(special_tokens["prompter"])
 
     system_prompt: str | None = None
     system_tokens: list[int] | None = None
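
The switch to encode matters when the prompter prefix is not a single vocabulary token. A hedged sketch, assuming a Hugging Face tokenizer and using "gpt2" and "<|prompter|>" purely as illustrative values:

from transformers import AutoTokenizer

# Illustrative checkpoint and prefix string; the worker's actual model and
# special_tokens mapping come from its config.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
prefix = "<|prompter|>"

# convert_tokens_to_ids treats the whole string as one vocabulary entry and
# returns a single int (the unk ID when the string is not in the vocab).
single_id = tokenizer.convert_tokens_to_ids(prefix)

# encode runs full tokenization and always returns a list of ints, which is
# what the truncation logic can safely search for as a subsequence.
id_list = tokenizer.encode(prefix, add_special_tokens=False)

print(type(single_id).__name__, type(id_list).__name__)  # int list
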
@@ -134,7 +137,9 @@ def truncate_prompt(
 
     num_system_tokens = len(system_tokens) if system_tokens else 0
     # Maximum token allowed for the conversation, ex system prompt
-    max_conversation_length = max_input_length - num_system_tokens
+    # We incorporate a buffer to allow for final inference tokenization differing from ours
+    # This is a slightly hacky workaround and it would be better to find a cleaner solution
+    max_conversation_length = max_input_length - num_system_tokens - int(0.01 * max_input_length)
     ids = ids[-(max_conversation_length - 1) :]
 
     with shared_tokenizer_lock:
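
Worked numbers for the new 1% buffer (hypothetical lengths, not from the commit):

max_input_length = 1024
num_system_tokens = 50
# Reserve ~1% of the input budget for tokenization drift between this worker
# and the final inference server: int(0.01 * 1024) == 10 tokens.
max_conversation_length = max_input_length - num_system_tokens - int(0.01 * max_input_length)
print(max_conversation_length)  # 1024 - 50 - 10 == 964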