
Commit adcf8dc

Fixes for worker prompt truncation in ChatML case (LAION-AI#3673)
Co-authored-by: Andreas Koepf <andreas.koepf@provisio.com>
1 parent d613c81 commit adcf8dc

4 files changed, +20 -10 lines changed


inference/worker/__main__.py

+1 -1
@@ -34,7 +34,7 @@ def main():
         tokenizer = None
     else:
         tokenizer: transformers.PreTrainedTokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
-        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {tokenizer.vocab_size}")
+        logger.warning(f"Tokenizer {tokenizer.name_or_path} vocab size: {len(tokenizer)}")

     inference_http = utils.HttpClient(
         base_url=settings.inference_server_url,
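
The switch from tokenizer.vocab_size to len(tokenizer) matters because vocab_size reports only the base vocabulary, while special tokens registered afterwards (such as ChatML markers) are counted only by len(tokenizer). A minimal sketch, not part of this commit, using gpt2 purely as an example checkpoint:

import transformers

# Illustrative only: gpt2 stands in for any Hugging Face checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.vocab_size)  # 50257: base vocabulary only
tokenizer.add_special_tokens({"additional_special_tokens": ["<|im_start|>", "<|im_end|>"]})
print(tokenizer.vocab_size)  # still 50257: added special tokens are not included
print(len(tokenizer))        # 50259: base vocabulary plus the two added tokens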

inference/worker/basic_hf_server.py

+1 -1
@@ -138,7 +138,7 @@ def load_models():
     hf_config = transformers.AutoConfig.from_pretrained(model_config.model_id)
     logger.warning(f"Loading model {model_config.model_id}...")
     tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.model_id)
-    logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {tokenizer.vocab_size}")
+    logger.warning(f"tokenizer {tokenizer.name_or_path} has vocab size {len(tokenizer)}")

     # see `decode_token` method, taken from HF text-generation-inference
     tokenizer.add_special_tokens({"additional_special_tokens": ["<decode-token>"]})

inference/worker/utils.py

+13 -8
@@ -95,11 +95,14 @@ def get_max_input_length(worker_config: inference.WorkerConfig, plugin_used: boo
     return max_input_length


-def get_tokens_until(tokens: list[int], target: int | list[int]) -> list[int]:
-    if isinstance(target, int):
-        return tokens[: tokens.index(target)]
-    else:
-        return next((i for i in range(len(tokens) - len(target) + 1) if tokens[i : i + len(target)] == target))
+def get_tokens_until(tokens: list[int], target: list[int]) -> list[int]:
+    if len(target) == 1:
+        return tokens[: tokens.index(target[0])]
+
+    for i in range(len(tokens) - len(target)):
+        if tokens[i : i + len(target)] == target:
+            break
+    return tokens[:i]


 def truncate_prompt(
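
In the ChatML case the prompter prefix encodes to several token IDs, so the helper now always takes a list; the old multi-token branch also returned the match index produced by next(...) instead of the truncated token list. A standalone copy of the new helper with a toy usage example (the token IDs are made up for illustration):

def get_tokens_until(tokens: list[int], target: list[int]) -> list[int]:
    if len(target) == 1:
        return tokens[: tokens.index(target[0])]

    for i in range(len(tokens) - len(target)):
        if tokens[i : i + len(target)] == target:
            break
    return tokens[:i]


print(get_tokens_until([5, 6, 7, 8, 9], [7]))     # [5, 6]
print(get_tokens_until([5, 6, 7, 8, 9], [7, 8]))  # [5, 6]
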
@@ -118,8 +121,8 @@ def truncate_prompt(
     """
     with shared_tokenizer_lock:
         ids = tokenizer.encode(prompt)
-        # prompter_prefix_ids could be int or list of ints
-        prompter_prefix_ids = tokenizer.convert_tokens_to_ids(special_tokens["prompter"])
+        # list of int IDs
+        prompter_prefix_ids = tokenizer.encode(special_tokens["prompter"])

     system_prompt: str | None = None
     system_tokens: list[int] | None = None
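
Switching from convert_tokens_to_ids to encode keeps the prefix IDs as a list in all cases: convert_tokens_to_ids performs a single-token vocabulary lookup and returns one ID (the unknown-token ID when the string is not a single vocabulary entry), whereas encode tokenizes the whole string and always returns a list, which is what the reworked get_tokens_until expects. A hedged comparison sketch with example strings, not the worker's actual configuration:

import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # example checkpoint
prompter_prefix = "<|im_start|>user"  # assumed ChatML-style prompter prefix

# Single-token lookup: one ID (here the unk ID, since the string is not one vocabulary entry).
print(tokenizer.convert_tokens_to_ids(prompter_prefix))
# Full tokenization: a list of IDs covering the whole prefix.
print(tokenizer.encode(prompter_prefix))
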
@@ -134,7 +137,9 @@ def truncate_prompt(

     num_system_tokens = len(system_tokens) if system_tokens else 0
     # Maximum token allowed for the conversation, ex system prompt
-    max_conversation_length = max_input_length - num_system_tokens
+    # We incorporate a buffer to allow for final inference tokenization differing from ours
+    # This is a slightly hacky workaround and it would be better to find a cleaner solution
+    max_conversation_length = max_input_length - num_system_tokens - int(0.01 * max_input_length)
     ids = ids[-(max_conversation_length - 1) :]

     with shared_tokenizer_lock:
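
The extra subtraction reserves roughly 1% of the input budget in case the inference server tokenizes the rendered prompt slightly differently from the worker. A worked example with assumed values; only the 1% formula comes from the diff:

max_input_length = 3072                 # e.g. a 4096-context model config
num_system_tokens = 200                 # assumed length of the system prompt in tokens
buffer = int(0.01 * max_input_length)   # 30 tokens of slack for re-tokenization drift
max_conversation_length = max_input_length - num_system_tokens - buffer
print(max_conversation_length)          # 2842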

oasst-shared/oasst_shared/model_configs.py

+5
@@ -150,4 +150,9 @@ def compat_hash(self) -> str:
         max_input_length=3072,
         max_total_length=4096,
     ),
+    "OA_SFT_CodeLlama_13B_10": ModelConfig(
+        model_id="OpenAssistant/codellama-13b-oasst-sft-v10",
+        max_input_length=8192,
+        max_total_length=12288,
+    ),
 }
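
Assuming max_total_length bounds prompt plus completion together and that ModelConfig exposes its constructor keywords as attributes, the new entry leaves 4096 tokens of headroom for generated output. An illustrative sketch, not part of the commit:

from oasst_shared.model_configs import ModelConfig  # assumed import path

cfg = ModelConfig(
    model_id="OpenAssistant/codellama-13b-oasst-sft-v10",
    max_input_length=8192,
    max_total_length=12288,
)
print(cfg.max_total_length - cfg.max_input_length)  # 4096 tokens available for the completion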
