@@ -831,6 +831,7 @@ def check_nvidia():
 PRE_CHECK = check_nvidia()
 
 
+import inspect
 from inspect import getsource
 import trl.trainer.sft_trainer
 from trl.trainer.sft_trainer import *
@@ -869,6 +870,35 @@ def neftune_post_forward_hook(module, input, output):
 pass
 
 
+def patch_trl_tokenizer_processing_class(trainer_name):
+    # New TRL removes tokenizer!
+    # We return it back!
+    exec(f"from trl import {trainer_name}", globals())
+    if str(eval(f"{trainer_name}").__name__).startswith("Unsloth"): return None
+    parameters = eval(f"inspect.signature({trainer_name}).parameters")
+    if "tokenizer" in parameters: return None
+
+    args = {
+        key : \
+        value.default \
+        if type(value.default) is not str else \
+        f"'{value.default}'" \
+        for key, value in parameters.items()
+    }
+    args["tokenizer"] = None
+    new_args = args.copy()
+    del new_args["tokenizer"]
+    del new_args["processing_class"]
+    new_args = ",\n".join(f"{' ' * 12}{key} = {key}" for key in new_args) + \
+        f",\n{' ' * 12}processing_class = tokenizer if tokenizer else processing_class"
+    args = ",\n".join(f"{' ' * 8}{key} = {value}" for key, value in args.items())
+    args = f"def __init__(\n" + f"{' ' * 8}self,\n" + args + "):"
+    args += f"\n{' ' * 8}\n{' ' * 8}super().__init__(\n{new_args}\n{' ' * 8})"
+    new_class = f"""class Unsloth{trainer_name}({trainer_name}):\n{' ' * 4}{args}\n"""
+    return new_class
+pass
+
+
 def patch_sft_trainer_tokenizer():
     """
         Patches the trainer with changes
@@ -884,7 +914,8 @@ def patch_sft_trainer_tokenizer():
 
     check_text = \
     "\n"\
-    "test_text = dataset[0][dataset_text_field] if (formatting_func is None or not use_formatting_func) else formatting_func(dataset[0])[0]\n"\
+    "if 'tokenizer' not in locals(): tokenizer = processing_class\n"\
918+ "test_text = dataset[0][dataset_text_field] if (formatting_func is not None and dataset_text_field is None) else formatting_func(dataset[0])[0]\n " \
     "chat_template = getattr(tokenizer, 'chat_template', None)\n"\
     "chat_template = '' if chat_template is None else chat_template\n"\
     "has_bos_token_already = (test_text.startswith(tokenizer.bos_token) or tokenizer.bos_token in chat_template) "\
@@ -941,7 +972,8 @@ def patch_sft_trainer_tokenizer():
     "    from transformers import __version__ as transformers_version\n"\
     "    from packaging.version import Version\n"\
     "    if Version(transformers_version) <= Version('4.45.2'):\n"\
-    "        print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers and Unsloth!')\n"\
+    "        print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\\n'\\\n"\
+    "              '`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`')\n"\
     "except:\n"\
     "    pass\n"\
     "\n\n"
@@ -981,4 +1013,13 @@ def patch_sft_trainer_tokenizer():
     pass
 pass
 
+# Fix TRL trainers with removed tokenizer args (got replaced with processing_class)
+for trainer_name in ("SFTTrainer", "DPOTrainer", "KTOTrainer"):
+    trainer_text = patch_trl_tokenizer_processing_class(trainer_name)
+    if trainer_text is None: continue
+    exec(trainer_text, globals())
+    exec(f"trl.trainer.{trainer_name} = Unsloth{trainer_name}", globals())
+pass
+
+# Finally patch TRL tokenizer things
 patch_sft_trainer_tokenizer()
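
For illustration, a rough sketch of what one of the generated wrapper classes looks like once the loop above exec's the source returned by patch_trl_tokenizer_processing_class("SFTTrainer"). This is hand-written and abbreviated, not the exact generated code: the real __init__ is built from inspect.signature(SFTTrainer) and spells out every keyword argument of the installed TRL version, which the sketch collapses into **kwargs.

    from trl import SFTTrainer

    class UnslothSFTTrainer(SFTTrainer):
        # Restores the `tokenizer` keyword that newer TRL releases replaced
        # with `processing_class`.
        def __init__(self, tokenizer = None, processing_class = None, **kwargs):
            super().__init__(
                # Forward the old argument into the new slot so existing
                # SFTTrainer(..., tokenizer = tokenizer) call sites keep working.
                processing_class = tokenizer if tokenizer else processing_class,
                **kwargs,
            )

Generating the class at runtime from the live signature (instead of hard-coding it as in this sketch) keeps the wrapper in sync with whatever parameters the installed TRL release actually exposes.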