@@ -3,7 +3,7 @@
 
 import torch
 from huggingface_hub import hf_hub_download
-from model_training.utils.utils import get_model, get_tokenizer
+from model_training.utils.utils import get_all_linear_layers, get_model, get_tokenizer, merge_dicts
 from peft import LoraConfig, PeftModel, PrefixTuningConfig, get_peft_model, prepare_model_for_int8_training
 
 
@@ -18,11 +18,15 @@ def load_peft_model(model, peft_model_path, tokenizer):
         torch_dtype=model.dtype,
     )
     model.eos_token_id = tokenizer.eos_token_id
-    extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
-    embed_weights = torch.load(extra_embeds, map_location=model.device)
-    model.base_model.model.model.embed_tokens.weight[len(tokenizer) - embed_weights.shape[0] :, :] = embed_weights.to(
-        model.base_model.model.model.embed_tokens.weight.dtype
-    )
+    try:
+        extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
+        embed_weights = torch.load(extra_embeds, map_location=model.device)
+        model.base_model.model.model.embed_tokens.weight[
+            len(tokenizer) - embed_weights.shape[0] :, :
+        ] = embed_weights.to(model.base_model.model.model.embed_tokens.weight.dtype)
+    except Exception:
+        print("Warning: Extra embeddings not added. This is expected if the adapter file contains the WTE.")
+
     return model
 
 
@@ -42,27 +46,40 @@ def make_inputs_require_grad(module, input, output):
     return model
 
 
-def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpointing=False):
+def peft_model(model, training_config):
+    peft_config = training_config.peft_config
+    peft_type = peft_config.pop("peft_type", "lora")
     if peft_type == "lora":
-        config = LoraConfig(
-            r=16,
-            lora_alpha=32,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-            lora_dropout=0.05,
-            bias="none",
-            task_type="CAUSAL_LM",
-        )
+        default_args = {
+            "r": 16,
+            "lora_alpha": 32,
+            "target_modules": "all",
+            "lora_dropout": 0.05,
+            "bias": "none",
+            "task_type": "CAUSAL_LM",
+            "modules_to_save": ["wte", "lm_head"],
+        }
+        kwargs = merge_dicts(default_args, peft_config)
+        if kwargs.get("target_modules") == "all":
+            kwargs.update({"target_modules": get_all_linear_layers(model)})
+        config = LoraConfig(**kwargs)
     elif peft_type == "prefix-tuning":
-        config = PrefixTuningConfig(
-            num_virtual_tokens=30, prefix_projection=True, encoder_hidden_size=1024, task_type="CAUSAL_LM"
-        )
+        default_args = {
+            "num_virtual_tokens": 30,
+            "prefix_projection": True,
+            "encoder_hidden_size": 1024,
+            "task_type": "CAUSAL_LM",
+        }
+        kwargs = merge_dicts(default_args, peft_config)
+        config = PrefixTuningConfig(**kwargs)
     else:
         raise ValueError("peft_method config is lora or prefix-tuning")
     model = get_peft_model(model, config)
-    if int8_training:
+
+    if training_config.int8_training:
         model = prepare_model_for_int8_training(model)
 
-    if gradient_checkpointing:
+    if training_config.gradient_checkpointing:
         model = prepare_model_for_gradient_checkpointing(model)
     model.print_trainable_parameters()
     return model
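
Below is a minimal usage sketch of the reworked entry point, assuming a training config object that exposes peft_config, int8_training, and gradient_checkpointing attributes. The SimpleNamespace wrapper and the concrete values are illustrative assumptions only, not part of this diff; the repository's own training config class would be used in practice.

# Hypothetical usage sketch (assumption, not part of this PR): driving the new
# peft_model(model, training_config) signature from a plain config object.
from types import SimpleNamespace

training_config = SimpleNamespace(
    # "peft_type" is popped inside peft_model; the remaining keys override the
    # defaults defined in the diff above (r, lora_alpha, target_modules, ...).
    peft_config={"peft_type": "lora", "r": 8, "lora_dropout": 0.1, "target_modules": "all"},
    int8_training=False,
    gradient_checkpointing=False,
)

# With a model loaded through the repo's existing helpers (get_model/get_tokenizer),
# the PEFT-wrapped model would then be obtained as:
#     model = peft_model(model, training_config)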