diff --git a/unsloth/kernels/cross_entropy_loss.py b/unsloth/kernels/cross_entropy_loss.py
index 5abed6a3e..1c8f8c8d9 100644
--- a/unsloth/kernels/cross_entropy_loss.py
+++ b/unsloth/kernels/cross_entropy_loss.py
@@ -411,7 +411,7 @@ def fast_cross_entropy_loss(
             labels = shift_labels,
             logit_softcapping = logit_softcapping,
             logit_scaling = logit_scaling,
-            n_items = kwargs.get("n_items", None),
+            n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
         )
     else:
         if logit_scaling != 0:
diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py
index aa7a69c94..7a13f0bbd 100644
--- a/unsloth/models/_utils.py
+++ b/unsloth/models/_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "2024.10.0"
+__version__ = "2024.10.1"
 
 __all__ = [
     "prepare_model_for_kbit_training",
@@ -43,6 +43,7 @@
     "accelerate_new_send_to_device",
     "patch_gradient_checkpointing",
     "unpatch_gradient_checkpointing",
+    "patch_gradient_accumulation_fix",
 ]
 
 import torch
@@ -1138,3 +1139,63 @@ def test_mask_creation():
         assert(torch.all(correct_mask == our_mask))
     pass
 pass
+
+
+def _unsloth_get_batch_samples(self, epoch_iterator, num_batches):
+    batch_samples = []
+    num_items_in_batch = None
+    for _ in range(num_batches):
+        try:
+            batch_samples += [next(epoch_iterator)]
+        except StopIteration:
+            break
+    if len(batch_samples) > 0 and "labels" in batch_samples[0]:
+        try:
+            num_items_in_batch = sum(
+                [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples]
+            )
+        except TypeError:
+            pass
+    return batch_samples, num_items_in_batch
+pass
+
+
+def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
+    if "num_items_in_batch" in kwargs:
+        if "num_items_in_batch" not in inputs:
+            inputs["num_items_in_batch"] = kwargs["num_items_in_batch"]
+        pass
+    pass
+    return self._old_compute_loss(model, inputs, *args, **kwargs)
+pass
+
+
+def patch_gradient_accumulation_fix(Trainer):
+    # Fixes gradient accumulation
+    if hasattr(Trainer, "get_batch_samples"):
+        from inspect import getsource
+        if \
+            not getsource(Trainer.get_batch_samples).strip()\
+            .endswith("return batch_samples, num_items_in_batch"):
+
+            raise NotImplementedError("Unsloth: Please make a Github issue immediately!!")
+        else:
+            if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples":
+                Trainer.get_batch_samples = _unsloth_get_batch_samples
+            pass
+
+            # Also fix passing in num_items_in_batch
+            if not hasattr(Trainer, "_old_compute_loss"):
+                Trainer._old_compute_loss = Trainer.compute_loss
+                Trainer.compute_loss = _unsloth_pre_compute_loss
+            pass
+        pass
+    else:
+        logger.warning_once(
+            "Unsloth: We fixed a gradient accumulation bug, "\
+            "but it seems like you don't have the latest transformers version!\n"\
+            "Please update transformers via:\n"\
+            '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
+        )
+    pass
+pass
diff --git a/unsloth/models/llama.py b/unsloth/models/llama.py
index 4cd512a98..f0437207b 100644
--- a/unsloth/models/llama.py
+++ b/unsloth/models/llama.py
@@ -982,7 +982,7 @@ def _CausalLM_fast_forward(
                 labels = shift_labels,
                 logit_softcapping = logit_softcapping,
                 logit_scaling = logit_scaling,
-                n_items = kwargs.get("n_items", None),
+                n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
             )
         else:
             if logit_scaling != 0:
@@ -1777,6 +1777,9 @@ def from_pretrained(
         patch_saving_functions(model)
         Trainer._inner_training_loop = _fast_inner_training_loop
 
+        # Fix gradient accumulation
+        patch_gradient_accumulation_fix(Trainer)
+
         # Save tokenizer for inference purposes
         tokenizer.padding_side = "left" # Force inference
         internal_model = model
diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py
index 63d07c927..ffe9933f4 100644
--- a/unsloth/tokenizer_utils.py
+++ b/unsloth/tokenizer_utils.py
@@ -928,8 +928,23 @@ def patch_sft_trainer_tokenizer():
     "    torch.cuda.empty_cache()\n"\
     "pass\n"\
     "\n"\
-    "fix_untrained_tokens(self.model, self.tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
-    "fix_zero_training_loss(self.model, self.tokenizer, self.train_dataset)\n\n"
+    "tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\n"\
+    "fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
+    "fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\n\n"
+
+    # Warn on gradient accumulation steps if it's used
+    check_text += \
+    "\n"\
+    "try:\n"\
+    "    gradient_accumulation_steps = self.args.gradient_accumulation_steps\n"\
+    "    if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\n"\
+    "        from transformers import __version__ as transformers_version\n"\
+    "        from packaging.version import Version\n"\
+    "        if Version(transformers_version) <= Version('4.45.2'):\n"\
+    "            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers and Unsloth!')\n"\
+    "except:\n"\
+    "    pass\n"\
+    "\n\n"
 
     # Add NEFTune since it doesn't seem to work?? We need to manually inject it
     check_text += \
diff --git a/unsloth/trainer.py b/unsloth/trainer.py
index c9c0ca2d0..25bb43402 100644
--- a/unsloth/trainer.py
+++ b/unsloth/trainer.py
@@ -22,7 +22,25 @@
 from transformers import TrainingArguments
 pass
 from . import is_bfloat16_supported
-from unsloth_zoo.training_utils import unsloth_train
+from unsloth_zoo.training_utils import unsloth_train as _unsloth_train
+from packaging.version import Version
+
+# Unsloth gradient accumulation fix:
+from transformers import __version__ as transformers_version
+if Version(transformers_version) > Version("4.45.2"):
+    def unsloth_train(trainer):
+        return trainer.train()
+    pass
+else:
+    def unsloth_train(trainer):
+        print(
+            "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\
+            "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
+            '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
+        )
+        return _unsloth_train(trainer)
+    pass
+pass
 
 __all__ = [
     "UnslothTrainingArguments",
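
Not part of the patch above: a minimal standalone sketch of the normalization issue the diff targets, assuming the same counting rule as `_unsloth_get_batch_samples` (shift the labels by one position and count everything that is not the -100 ignore index). The `micro_batches` tensors and the per-micro-batch loss values are invented purely for illustration.

import torch

# Two accumulated micro-batches with different numbers of unmasked label tokens.
micro_batches = [
    {"labels": torch.tensor([[10, 11, 12, 13]])},      # 3 shifted targets
    {"labels": torch.tensor([[10, 11, -100, -100]])},  # 1 shifted target
]

# Same counting rule as _unsloth_get_batch_samples: shift labels by one and
# count every position that is not the -100 ignore index.
num_items_in_batch = sum(
    torch.count_nonzero(x["labels"][..., 1:] != -100) for x in micro_batches
)
print(num_items_in_batch)  # tensor(4)

# Invented summed cross-entropy per micro-batch (sum over its unmasked tokens).
summed_losses = torch.tensor([3.0, 5.0])
token_counts  = torch.tensor([3.0, 1.0])

# Averaging per-micro-batch means over-weights the short micro-batch;
# dividing the total by num_items_in_batch weights every token equally.
per_batch_mean = (summed_losses / token_counts).mean()      # tensor(3.)
global_norm    = summed_losses.sum() / num_items_in_batch   # tensor(2.)
print(per_batch_mean, global_norm)

This is why the patch threads `num_items_in_batch` from `get_batch_samples` through `compute_loss` down to `fast_cross_entropy_loss`: the divisor must be the token count of the whole accumulation step, not of each micro-batch.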