@@ -2170,6 +2170,9 @@ def from_pretrained(
             m = m.model
         pass
         m.max_seq_length = max_seq_length
+        # Save to modules as well
+        for module in model.modules():
+            module.max_seq_length = max_seq_length
 
         # We check the tokenizer first for errors
         if fix_tokenizer:
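For reference, the loop added above simply stamps the configured max_seq_length onto every submodule returned by model.modules(), so inner layers can read it without walking back up to the outer wrapper. A minimal standalone sketch of the same pattern (toy module, not the unsloth classes):

    import torch

    max_seq_length = 2048
    # Hypothetical toy model used only to illustrate the propagation pattern.
    toy = torch.nn.Sequential(
        torch.nn.Embedding(32, 8),
        torch.nn.Linear(8, 8),
    )
    toy.max_seq_length = max_seq_length
    # Same idea as the diff: every submodule carries the attribute afterwards.
    for module in toy.modules():
        module.max_seq_length = max_seq_length
    assert toy[0].max_seq_length == max_seq_length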
@@ -2228,6 +2231,11 @@ def from_pretrained(
         # Add for_inference and for_training
         model.for_training  = functools.partial(FastLlamaModel.for_training,  model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
+        m = model
+        while hasattr(m, "model"):
+            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_inference = functools.partial(FastBaseModel.for_inference, m)
+            m = m.model
 
         # Patch generate
         is_classification = "Classification" in str(type(model))
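The while-loop added here rebinds for_training and for_inference onto every level of the nested .model chain, so the helpers stay reachable even when the caller only holds an inner wrapper. A rough sketch of that binding pattern with stand-in classes (ToyBase and ToyWrapper are hypothetical, not unsloth's FastBaseModel):

    import functools

    class ToyBase:
        @staticmethod
        def for_inference(model):
            print(f"inference mode set on {type(model).__name__}")

        @staticmethod
        def for_training(model):
            print(f"training mode set on {type(model).__name__}")

    class ToyWrapper:
        def __init__(self, inner=None):
            if inner is not None:
                self.model = inner  # mimics PEFT-style nesting: outer.model.model...

    base  = ToyWrapper()              # innermost object, no .model attribute
    mid   = ToyWrapper(base)
    outer = ToyWrapper(mid)

    # Same walk as the diff: bind the helpers at every level that has a .model.
    m = outer
    while hasattr(m, "model"):
        m.for_training  = functools.partial(ToyBase.for_training,  m)
        m.for_inference = functools.partial(ToyBase.for_inference, m)
        m = m.model

    outer.for_inference()  # works on the outermost wrapper
    mid.for_inference()    # ...and on inner wrappers reached through .model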
@@ -2236,6 +2244,13 @@ def from_pretrained(
         unsloth_fast_generate.__doc__ = model._old_generate.__doc__
         model.generate = types.MethodType(unsloth_fast_generate, model)
         pass
+        # Set weight[padding_idx] = 0
+        with torch.no_grad():
+            for name, module in model.named_modules():
+                if type(module) is torch.nn.Embedding:
+                    if getattr(module, "weight", None) is not None and getattr(module, "padding_idx", None) is not None:
+                        if module.padding_idx < module.weight.shape[0]:
+                            module.weight[module.padding_idx] = 0
         return model, tokenizer
     pass
 
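The block above re-zeroes the embedding row at padding_idx under torch.no_grad() (in-place edits to a parameter that requires grad are only legal with autograd disabled), with a bounds check in case the vocabulary was resized. A self-contained sketch of the same idea on a toy embedding:

    import torch

    emb = torch.nn.Embedding(num_embeddings=16, embedding_dim=4, padding_idx=0)
    emb.weight.data[0] = 1.0  # pretend the pad row drifted away from zero

    with torch.no_grad():
        for name, module in emb.named_modules():
            if type(module) is torch.nn.Embedding:
                if getattr(module, "weight", None) is not None and getattr(module, "padding_idx", None) is not None:
                    # Skip if padding_idx falls outside a resized embedding matrix.
                    if module.padding_idx < module.weight.shape[0]:
                        module.weight[module.padding_idx] = 0

    assert emb.weight[0].abs().sum().item() == 0.0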
@@ -2704,6 +2719,11 @@ def get_peft_model(
         # Add for_inference and for_training
         model.for_training  = functools.partial(FastLlamaModel.for_training,  model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
+        m = model
+        while hasattr(m, "model"):
+            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_inference = functools.partial(FastBaseModel.for_inference, m)
+            m = m.model
         return model
     pass
 
@@ -2892,6 +2912,9 @@ def patch_peft_model(
             internal_model = internal_model.model
         pass
         internal_model.max_seq_length = max_seq_length
+        # Save to modules as well
+        for module in model.modules():
+            module.max_seq_length = max_seq_length
 
         # Patch tokenizer to pad to the right
         internal_model = model
@@ -2916,6 +2939,11 @@ def patch_peft_model(
         # Add for_inference and for_training
         model.for_training  = functools.partial(FastLlamaModel.for_training,  model)
         model.for_inference = functools.partial(FastLlamaModel.for_inference, model)
+        m = model
+        while hasattr(m, "model"):
+            m.for_training  = functools.partial(FastBaseModel.for_training,  m)
+            m.for_inference = functools.partial(FastBaseModel.for_inference, m)
+            m = m.model
         return model
     pass
 