@@ -1621,7 +1621,7 @@ def from_pretrained(
         )

         model, tokenizer = patch_tokenizer(model, tokenizer)
-        model, tokenizer = model_patcher.post_patch(model, tokenizer)
+        model, tokenizer = model_patcher.post_patch(model, tokenizer, max_position_embeddings)

         # Patch up QKV / O and MLP
         for idx, layer in enumerate(model.model.layers):
@@ -1827,7 +1827,18 @@ def from_pretrained(


     @staticmethod
-    def post_patch(model, tokenizer):
+    def post_patch(model, tokenizer, max_seq_length):
+        # Add max_seq_length to all modules
+        extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
+        internal_model = model
+        while hasattr(internal_model, "model"):
+            internal_model.max_seq_length = max_seq_length
+            internal_model.extra_ignored_labels = extra_ignored_labels
+            internal_model = internal_model.model
+        pass
+        internal_model.max_seq_length = max_seq_length
+        internal_model.extra_ignored_labels = extra_ignored_labels
+
         # Torch.compile fails on embedding matrix??
         try: old_input_embedding = model.get_input_embeddings().weight
         except: return model, tokenizer
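The `while hasattr(internal_model, "model")` walk in the added block exists because the model is a chain of wrappers (e.g. a PEFT wrapper around a `*ForCausalLM` around the base transformer), any level of which may later read these attributes from `self`. A minimal standalone sketch of the same traversal — `set_on_all_wrappers`, `Outer`, and `Inner` are illustrative names, not identifiers from the repository:

```python
import torch
import torch.nn as nn

class Inner(nn.Module):
    pass

class Outer(nn.Module):
    # Stand-in for a wrapper whose `.model` attribute holds the next level.
    def __init__(self):
        super().__init__()
        self.model = Inner()

def set_on_all_wrappers(model, name, value):
    # Descend through every `.model` wrapper and set the attribute at each
    # level, including the innermost module, which has no `.model` of its own.
    m = model
    while hasattr(m, "model"):
        setattr(m, name, value)
        m = m.model
    setattr(m, name, value)

wrapped = Outer()
buf = torch.full((4096, 1), -100)  # same shape/fill as extra_ignored_labels
set_on_all_wrappers(wrapped, "extra_ignored_labels", buf)
assert wrapped.extra_ignored_labels is wrapped.model.extra_ignored_labels
```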
@@ -2459,18 +2470,6 @@ def patch_peft_model(
         )
         patch_saving_functions(model)

-        # Patch cross entropy loss labels
-        # Fixes https://github.com/unslothai/unsloth/issues/10
-        max_seq_length = model.max_seq_length
-        extra_ignored_labels = torch.full((max_seq_length, 1), -100, device = "cuda:0")
-        model.model.extra_ignored_labels = extra_ignored_labels
-        internal_model = model
-        while hasattr(internal_model, "model"):
-            internal_model.max_seq_length = max_seq_length
-            internal_model = internal_model.model
-        pass
-        internal_model.max_seq_length = max_seq_length
-
         # Patch tokenizer to pad to the right
         internal_model = model
         while hasattr(internal_model, "model"):
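For reference, the buffer this deleted block used to build (now created once in `post_patch`) lets the causal-LM loss shift labels one step left without allocating a fresh `-100` pad on every forward, the fix for https://github.com/unslothai/unsloth/issues/10. A sketch of how such a buffer can be consumed — `shifted_label_loss` is a hypothetical helper, and stock `F.cross_entropy` stands in for the repository's fused loss kernel:

```python
import torch
import torch.nn.functional as F

def shifted_label_loss(logits, labels, extra_ignored_labels):
    # logits: (bsz, seq_len, vocab); labels: (bsz, seq_len).
    # Logits at position t predict the token at t+1, so instead of slicing
    # logits (which copies), shift labels left and pad the final position
    # with -100 taken from the pre-allocated (max_seq_length, 1) buffer.
    shift_labels = torch.hstack((labels[..., 1:], extra_ignored_labels[:labels.shape[0]]))
    return F.cross_entropy(
        logits.view(-1, logits.shape[-1]).float(),
        shift_labels.view(-1),
        ignore_index = -100,  # padded positions contribute no loss
    )

logits = torch.randn(2, 8, 32)
labels = torch.randint(0, 32, (2, 8))
buf = torch.full((4096, 1), -100)
print(shifted_label_loss(logits, labels, buf))
```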