
Commit 9d07be0

ignored labels
1 parent 02437a8 commit 9d07be0

File tree

unsloth/models/gemma.py
unsloth/models/gemma2.py
unsloth/models/llama.py

3 files changed: +37 -16 lines changed

unsloth/models/gemma.py

Lines changed: 12 additions & 1 deletion
@@ -339,7 +339,18 @@ def pre_patch():
 
 
     @staticmethod
-    def post_patch(model, tokenizer):
+    def post_patch(model, tokenizer, max_seq_length):
+        # Add max_seq_length to all modules
+        extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model.max_seq_length = max_seq_length
+            internal_model.extra_ignored_labels = extra_ignored_labels
+            internal_model = internal_model.model
+        pass
+        internal_model.max_seq_length = max_seq_length
+        internal_model.extra_ignored_labels = extra_ignored_labels
+
         # Torch.compile fails on embedding matrix??
         # Workaround randomly fixes it for torch versions < 2.2
         model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
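
For context, extra_ignored_labels is a pre-allocated column of -100 values (the default ignore index of torch.nn.CrossEntropyLoss), sized to the longest sequence the model will see. Below is a minimal standalone sketch of the label-shifting pattern such a buffer supports in causal-LM training; the shapes and the labels tensor are illustrative and not taken from this commit, and unsloth's actual loss path may differ in detail.

    import torch

    max_seq_length = 8   # illustrative; the commit sizes this to the model's max sequence length
    batch_size = 2

    # Pre-allocated column of ignore-index values, as in the diff
    # (the diff hard-codes device = "cuda:0"; CPU is used here so the sketch runs anywhere).
    extra_ignored_labels = torch.full((max_seq_length, 1), -100)

    # Hypothetical labels for a causal-LM batch (not from the commit).
    labels = torch.randint(0, 10, (batch_size, max_seq_length))

    # Causal-LM loss compares position t's logits with position t+1's label,
    # so the labels are shifted left by one and the final position is padded
    # with -100, which CrossEntropyLoss ignores by default.
    shift_labels = torch.hstack((labels[..., 1:], extra_ignored_labels[:batch_size]))

    loss_fct = torch.nn.CrossEntropyLoss()  # ignore_index defaults to -100
    # Logits of shape (batch_size, max_seq_length, vocab_size) would be flattened
    # together with shift_labels and passed to loss_fct.

Pre-allocating the column once on the model avoids rebuilding it on every forward pass; slicing [:batch_size] reuses the same storage for any batch size up to max_seq_length.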

unsloth/models/gemma2.py

Lines changed: 12 additions & 1 deletion
@@ -490,7 +490,18 @@ def pre_patch():
 
 
     @staticmethod
-    def post_patch(model, tokenizer):
+    def post_patch(model, tokenizer, max_seq_length):
+        # Add max_seq_length to all modules
+        extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model.max_seq_length = max_seq_length
+            internal_model.extra_ignored_labels = extra_ignored_labels
+            internal_model = internal_model.model
+        pass
+        internal_model.max_seq_length = max_seq_length
+        internal_model.extra_ignored_labels = extra_ignored_labels
+
         # Torch.compile fails on embedding matrix??
         # Workaround randomly fixes it for torch versions < 2.2
         model.model.embed_tokens = torch.nn.Embedding.from_pretrained(model.model.embed_tokens.weight)
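
The same while hasattr(internal_model, "model") walk appears in all three files: it descends the wrapper chain (e.g. the causal-LM wrapper around the decoder stack), tags every level, and then tags the innermost module once the loop exits. A minimal sketch of that traversal with dummy modules (class names here are made up for illustration):

    import torch.nn as nn

    class Inner(nn.Module):      # stands in for the decoder stack
        pass

    class Middle(nn.Module):     # stands in for the causal-LM wrapper
        def __init__(self):
            super().__init__()
            self.model = Inner()

    class Outer(nn.Module):      # stands in for an outer wrapper
        def __init__(self):
            super().__init__()
            self.model = Middle()

    model = Outer()
    max_seq_length = 4096

    # Same traversal pattern as post_patch: tag every wrapper level while a
    # nested .model exists, then tag the innermost module after the loop exits.
    internal_model = model
    while hasattr(internal_model, "model"):
        internal_model.max_seq_length = max_seq_length
        internal_model = internal_model.model
    internal_model.max_seq_length = max_seq_length

    assert model.max_seq_length == model.model.max_seq_length == model.model.model.max_seq_length == 4096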

unsloth/models/llama.py

Lines changed: 13 additions & 14 deletions
@@ -1621,7 +1621,7 @@ def from_pretrained(
         )
 
         model, tokenizer = patch_tokenizer(model, tokenizer)
-        model, tokenizer = model_patcher.post_patch(model, tokenizer)
+        model, tokenizer = model_patcher.post_patch(model, tokenizer, max_position_embeddings)
 
         # Patch up QKV / O and MLP
         for idx, layer in enumerate(model.model.layers):
@@ -1827,7 +1827,18 @@ def from_pretrained(
 
 
     @staticmethod
-    def post_patch(model, tokenizer):
+    def post_patch(model, tokenizer, max_seq_length):
+        # Add max_seq_length to all modules
+        extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model.max_seq_length = max_seq_length
+            internal_model.extra_ignored_labels = extra_ignored_labels
+            internal_model = internal_model.model
+        pass
+        internal_model.max_seq_length = max_seq_length
+        internal_model.extra_ignored_labels = extra_ignored_labels
+
         # Torch.compile fails on embedding matrix??
         try: old_input_embedding = model.get_input_embeddings ().weight
         except: return model, tokenizer
@@ -2459,18 +2470,6 @@ def patch_peft_model(
         )
         patch_saving_functions(model)
 
-        # Patch cross entropy loss labels
-        # Fixes https://github.com/unslothai/unsloth/issues/10
-        max_seq_length = model.max_seq_length
-        extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
-        model.model.extra_ignored_labels = extra_ignored_labels
-        internal_model = model
-        while hasattr(internal_model, "model"):
-            internal_model.max_seq_length = max_seq_length
-            internal_model = internal_model.model
-        pass
-        internal_model.max_seq_length = max_seq_length
-
         # Patch tokenizer to pad to the right
         internal_model = model
         while hasattr(internal_model, "model"):
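
Net effect of the llama.py changes: max_seq_length and the ignored-labels buffer are now attached during from_pretrained (post_patch receives max_position_embeddings) rather than inside patch_peft_model, so they exist even when no PEFT adapter is ever added. A hedged sketch of checking this right after loading, assuming unsloth's public FastLanguageModel.from_pretrained entry point (the checkpoint name is illustrative):

    from unsloth import FastLanguageModel

    max_seq_length = 2048

    # Loading alone now runs post_patch(model, tokenizer, max_position_embeddings),
    # so the attributes exist before any LoRA / PEFT patching.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name     = "unsloth/llama-3-8b-bnb-4bit",  # illustrative checkpoint
        max_seq_length = max_seq_length,
        load_in_4bit   = True,
    )

    # Walk the wrapper chain and confirm every level was tagged by post_patch.
    internal_model = model
    while hasattr(internal_model, "model"):
        assert hasattr(internal_model, "extra_ignored_labels")
        assert hasattr(internal_model, "max_seq_length")
        internal_model = internal_model.model
    assert hasattr(internal_model, "extra_ignored_labels")

Because post_patch now receives max_position_embeddings, the stored max_seq_length attribute may differ from the max_seq_length the caller passed in.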
