
Commit e210840

Gradient Accumulation Fix (#1146)
* Unsloth Zoo
* Update trainer.py
* Update trainer.py
* Update cross_entropy_loss.py
* n_items
* Update llama.py
* kwargs
* Remove extraneous f prefixes (#1133). Co-authored-by: Emil Sadek <esadek@users.noreply.github.com>
* Update __init__.py
* kwargs
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Fix GA
* Update _utils.py
* Update llama.py
* Update tokenizer_utils.py
* Warn on old versions
* Update tokenizer_utils.py (repeated 16 times)

---------

Co-authored-by: Emil Sadek <esadek@hotmail.com>
Co-authored-by: Emil Sadek <esadek@users.noreply.github.com>
1 parent a395211 commit e210840

5 files changed: +103 -6 lines changed
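For context on the bug this commit addresses: under gradient accumulation, averaging each micro-batch's mean cross-entropy weights every micro-batch equally even when the micro-batches contain different numbers of unmasked tokens, so the accumulated loss no longer matches the loss of one large batch. The fix threads the total count of non-ignored label tokens (num_items_in_batch) through to the loss so the denominator spans the whole accumulated batch. The snippet below is an illustrative sketch, not code from this commit; the tensors are made up.

# Illustrative sketch (not from this commit): why per-micro-batch mean losses
# diverge from the full-batch loss when micro-batches have unequal token counts.
import torch

# Per-token cross-entropy losses for two micro-batches; padded positions already removed.
mb1 = torch.tensor([2.0, 2.0, 2.0, 2.0])   # 4 unmasked tokens
mb2 = torch.tensor([1.0])                  # 1 unmasked token

# Naive gradient accumulation: average the per-micro-batch means.
naive = (mb1.mean() + mb2.mean()) / 2                  # (2.0 + 1.0) / 2 = 1.5

# Fixed behaviour: sum all token losses and divide by the total token count,
# which is what passing num_items_in_batch / n_items into the loss achieves.
num_items_in_batch = mb1.numel() + mb2.numel()         # 5
fixed = (mb1.sum() + mb2.sum()) / num_items_in_batch   # 9.0 / 5 = 1.8

print(naive.item(), fixed.item())  # 1.5 vs 1.8 -> the naive estimate is biased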

unsloth/kernels/cross_entropy_loss.py

Lines changed: 1 addition & 1 deletion
@@ -411,7 +411,7 @@ def fast_cross_entropy_loss(
             labels = shift_labels,
             logit_softcapping = logit_softcapping,
             logit_scaling = logit_scaling,
-            n_items = kwargs.get("n_items", None),
+            n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
         )
     else:
         if logit_scaling != 0:
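In both changed call sites, the practical effect of a non-None n_items is to turn the cross-entropy reduction into a sum of per-token losses divided by an externally supplied token count, rather than a mean over the local micro-batch. A minimal sketch of that behaviour with plain PyTorch (an assumed equivalent reduction, not Unsloth's Triton kernel):

# Minimal sketch (assumed behaviour, not Unsloth's kernel): an external n_items
# replaces the per-call mean with sum(per-token CE) / n_items, so the
# denominator can span a whole accumulated batch instead of one micro-batch.
import torch
import torch.nn.functional as F

logits = torch.randn(7, 32)                      # 7 tokens, vocab of 32
labels = torch.randint(0, 32, (7,))
labels[-2:] = -100                               # 2 masked positions

summed = F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum")

n_items = 20                                     # e.g. unmasked tokens across all micro-batches
loss = summed / n_items                          # instead of summed / 5 (local mean)
print(loss)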

unsloth/models/_utils.py

Lines changed: 62 additions & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "2024.10.0"
+__version__ = "2024.10.1"

 __all__ = [
     "prepare_model_for_kbit_training",
@@ -43,6 +43,7 @@
     "accelerate_new_send_to_device",
     "patch_gradient_checkpointing",
     "unpatch_gradient_checkpointing",
+    "patch_gradient_accumulation_fix",
 ]

 import torch
@@ -1138,3 +1139,63 @@ def test_mask_creation():
         assert(torch.all(correct_mask == our_mask))
     pass
 pass
+
+
+def _unsloth_get_batch_samples(self, epoch_iterator, num_batches):
+    batch_samples = []
+    num_items_in_batch = None
+    for _ in range(num_batches):
+        try:
+            batch_samples += [next(epoch_iterator)]
+        except StopIteration:
+            break
+    if len(batch_samples) > 0 and "labels" in batch_samples[0]:
+        try:
+            num_items_in_batch = sum(
+                [torch.count_nonzero(x["labels"][..., 1:] != -100) for x in batch_samples]
+            )
+        except TypeError:
+            pass
+    return batch_samples, num_items_in_batch
+pass
+
+
+def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
+    if "num_items_in_batch" in kwargs:
+        if "num_items_in_batch" not in inputs:
+            inputs["num_items_in_batch"] = kwargs["num_items_in_batch"]
+        pass
+    pass
+    return self._old_compute_loss(model, inputs, args, kwargs)
+pass
+
+
+def patch_gradient_accumulation_fix(Trainer):
+    # Fixes gradient accumulation
+    if hasattr(Trainer, "get_batch_samples"):
+        from inspect import getsource
+        if \
+            not getsource(Trainer.get_batch_samples).strip()\
+            .endswith("return batch_samples, num_items_in_batch"):
+
+            raise NotImplementedError("Unsloth: Please make a Github issue immediately!!")
+        else:
+            if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples":
+                Trainer.get_batch_samples = _unsloth_get_batch_samples
+            pass
+
+            # Also fix passing in num_items_in_batch
+            if not hasattr(Trainer, "_old_compute_loss"):
+                Trainer._old_compute_loss = Trainer.compute_loss
+                Trainer.compute_loss = _unsloth_pre_compute_loss
+            pass
+        pass
+    else:
+        logger.warning_once(
+            "Unsloth: We fixed a gradient accumulation bug, "\
+            "but it seems like you don't have the latest transformers version!\n"\
+            "Please update transformers via:\n"\
+            '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
+        )
+    pass
+pass
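The patch works by swapping Trainer.get_batch_samples for the counting version above and wrapping compute_loss so the accumulated token count reaches the model through inputs. The toy below demonstrates the same save-original-then-wrap monkey-patching pattern on a made-up class (ToyTrainer and its methods are hypothetical, not the transformers API):

# Toy illustration of the patching pattern used above (hypothetical names):
# save the original method, install a wrapper that injects extra state,
# and delegate back to the saved original.
class ToyTrainer:
    def compute_loss(self, model, inputs):
        return sum(inputs.get("losses", [])) / inputs.get("num_items_in_batch", 1)

def _patched_compute_loss(self, model, inputs, **kwargs):
    # Forward the accumulated token count into `inputs`, as the real patch does.
    if "num_items_in_batch" in kwargs and "num_items_in_batch" not in inputs:
        inputs["num_items_in_batch"] = kwargs["num_items_in_batch"]
    return self._old_compute_loss(model, inputs)

if not hasattr(ToyTrainer, "_old_compute_loss"):
    ToyTrainer._old_compute_loss = ToyTrainer.compute_loss
    ToyTrainer.compute_loss = _patched_compute_loss

trainer = ToyTrainer()
print(trainer.compute_loss(None, {"losses": [2.0, 1.0]}, num_items_in_batch=4))  # 0.75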

unsloth/models/llama.py

Lines changed: 4 additions & 1 deletion
@@ -982,7 +982,7 @@ def _CausalLM_fast_forward(
             labels = shift_labels,
             logit_softcapping = logit_softcapping,
             logit_scaling = logit_scaling,
-            n_items = kwargs.get("n_items", None),
+            n_items = kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None),
         )
     else:
         if logit_scaling != 0:
@@ -1777,6 +1777,9 @@ def from_pretrained(
         patch_saving_functions(model)
         Trainer._inner_training_loop = _fast_inner_training_loop

+        # Fix gradient accumulation
+        patch_gradient_accumulation_fix(Trainer)
+
         # Save tokenizer for inference purposes
         tokenizer.padding_side = "left" # Force inference
         internal_model = model
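Note that the kwargs lookup in the first hunk prefers the num_items_in_batch key that newer transformers passes down, and only falls back to the legacy n_items key when the former is missing or falsy. A quick standalone check of that precedence (throwaway dictionaries, not trainer internals):

# Quick check of the kwarg fallback used above, with throwaway dicts.
def resolve_n_items(kwargs):
    return kwargs.get("num_items_in_batch", None) or kwargs.get("n_items", None)

print(resolve_n_items({"num_items_in_batch": 40, "n_items": 7}))  # 40 -> new key wins
print(resolve_n_items({"n_items": 7}))                            # 7  -> falls back
print(resolve_n_items({}))                                        # None -> loss keeps its default mean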

unsloth/tokenizer_utils.py

Lines changed: 17 additions & 2 deletions
@@ -928,8 +928,23 @@ def patch_sft_trainer_tokenizer():
         "    torch.cuda.empty_cache()\n"\
         "pass\n"\
         "\n"\
-        "fix_untrained_tokens(self.model, self.tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
-        "fix_zero_training_loss(self.model, self.tokenizer, self.train_dataset)\n\n"
+        "tokenizer = self.processing_class if hasattr(self, 'processing_class') else self.tokenizer\n"\
+        "fix_untrained_tokens(self.model, tokenizer, self.train_dataset, IGNORED_TOKENIZER_NAMES, eps = 1e-16)\n\n"\
+        "fix_zero_training_loss(self.model, tokenizer, self.train_dataset)\n\n"
+
+    # Warn on gradient accumulation steps if it's used
+    check_text += \
+        "\n"\
+        "try:\n"\
+        "    gradient_accumulation_steps = self.args.gradient_accumulation_steps\n"\
+        "    if type(gradient_accumulation_steps) is int and gradient_accumulation_steps > 1:\n"\
+        "        from transformers import __version__ as transformers_version\n"\
+        "        from packaging.version import Version\n"\
+        "        if Version(transformers_version) <= Version('4.45.2'):\n"\
+        "            print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers and Unsloth!')\n"\
+        "except:\n"\
+        "    pass\n"\
+        "\n\n"

     # Add NEFTune since it doesn't seem to work?? We need to manually inject it
     check_text += \
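The injected block only prints the warning when gradient_accumulation_steps is an int greater than 1 and the installed transformers is at or below 4.45.2. The version gate it relies on can be checked in isolation with packaging (the version strings below are examples only):

# Standalone check of the version gate injected above (packaging is a real
# dependency of transformers; these version strings are just examples).
from packaging.version import Version

for v in ("4.44.0", "4.45.2", "4.46.0.dev0"):
    needs_warning = Version(v) <= Version("4.45.2")
    print(v, "-> warn" if needs_warning else "-> has the fix")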

unsloth/trainer.py

Lines changed: 19 additions & 1 deletion
@@ -22,7 +22,25 @@
     from transformers import TrainingArguments
 pass
 from . import is_bfloat16_supported
-from unsloth_zoo.training_utils import unsloth_train
+from unsloth_zoo.training_utils import unsloth_train as _unsloth_train
+from packaging.version import Version
+
+# Unsloth gradient accumulation fix:
+from transformers import __version__ as transformers_version
+if Version(transformers_version) > Version("4.45.2"):
+    def unsloth_train(trainer):
+        return trainer.train()
+    pass
+else:
+    def unsloth_train(trainer):
+        print(
+            "Unsloth: Using our custom gradient accumulation fixed trainer, which is not feature complete.\n"\
+            "If you want to use our fix inside of HF, please update `transformers` to the latest version via:\n"\
+            '`pip uninstall transformers -y && pip install --upgrade --no-cache-dir "git+https://github.com/huggingface/transformers.git"`'
+        )
+        return _unsloth_train(trainer)
+    pass
+pass

 __all__ = [
     "UnslothTrainingArguments",
