@@ -588,15 +588,21 @@ def load_correct_tokenizer(
588588def _fix_chat_template (chat_template ):
589589 endfor = "{% endfor %}"
590590 where = chat_template .find (endfor )
591- if where == - 1 : return chat_template
591+ if where == - 1 :
592+ endfor = "{%- endfor %}"
593+ where = chat_template .find (endfor )
594+ if where == - 1 :
595+ return chat_template
592596
593597 after_endfor = chat_template [where + len (endfor ):]
594598
595- if "{% if" not in after_endfor and "{% set " not in after_endfor and \
599+ dash = "-" if endfor .startswith ("{%-" ) else ""
600+
601+ if "{%" + dash + " if" not in after_endfor and "{%" + dash + " set " not in after_endfor and \
596602 after_endfor .startswith ("{{" ) and after_endfor .endswith ("}}" ) and \
597603 after_endfor .count ("{{" ) == 1 and after_endfor .count ("}}" ) == 1 :
598604
599- after_endfor = "{% if add_generation_prompt %}" + after_endfor + "{% endif %}"
605+ after_endfor = "{%" + dash + " if add_generation_prompt %}" + after_endfor + endfor
600606
601607 chat_template = chat_template [:where + len (endfor )] + after_endfor
602608 pass
@@ -643,10 +649,12 @@ def fix_chat_template(tokenizer):
643649
644650 if no == yes :
645651 # SAME?! That's not good! We check for add_generation_prompt
646- if "{% if add_generation_prompt %}" not in chat_template :
652+ if "{% if add_generation_prompt %}" not in chat_template and \
653+ "{%- if add_generation_prompt %}" not in chat_template :
647654 # Try fixing it by adding it
648655 new_chat_template = _fix_chat_template (chat_template )
649- if "{% if add_generation_prompt %}" not in new_chat_template :
656+ if "{% if add_generation_prompt %}" not in new_chat_template and \
657+ "{%- if add_generation_prompt %}" not in new_chat_template :
650658 raise RuntimeError (
651659 f"Unsloth: The tokenizer `{ tokenizer .name_or_path } `\n " \
652660 "does not have a {% if add_generation_prompt %} for generation purposes.\n " \
@@ -1001,13 +1009,14 @@ def patch_sft_trainer_tokenizer():
10011009 # Also DPO weirdly tokenizes non numeric columns? Delete them!
10021010 check_text += \
10031011 "\n " \
1004- "column_names = set(self.train_dataset.column_names)\n " \
1005- "check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n " \
1006- " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n " \
1007- " 'prompt_input_ids', 'prompt_attention_mask']\n " \
1008- "if all(x in column_names for x in check):\n " \
1009- " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n " \
1010- "del check, column_names\n " \
1012+ "if hasattr(self.train_dataset, 'column_names'):\n " \
1013+ " column_names = set(self.train_dataset.column_names)\n " \
1014+ " check = ['chosen', 'rejected', 'prompt', 'chosen_input_ids', 'chosen_attention_mask',\n " \
1015+ " 'chosen_labels', 'rejected_input_ids', 'rejected_attention_mask', 'rejected_labels',\n " \
1016+ " 'prompt_input_ids', 'prompt_attention_mask']\n " \
1017+ " if all(x in column_names for x in check):\n " \
1018+ " self.train_dataset = self.train_dataset.remove_columns(['chosen', 'rejected', 'prompt'])\n " \
1019+ " del check, column_names\n " \
10111020 "\n "
10121021
10131022 check_text = check_text .split ("\n " )
0 commit comments