
Commit b8637f2

Configure PEFT from config (LAION-AI#3571)
## What
Added support to configure PEFT from the config and to save `WTE` embeddings with the adapter files, to enable easy loading of OA LoRA weights.

## Why
Previously the PEFT modules were hardcoded for the llama model only, which was an issue when training other models with PEFT, such as RWModel, GPTNeoX, etc.

## How
Introduces an extra parameter, `peft_config`, in config.yaml.
1 parent 018657b commit b8637f2
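For orientation, here is a minimal sketch (not part of the commit) of how a `peft_config` block is consumed: `peft_type` selects the PEFT config class, and the remaining keys override the LoRA defaults defined in `peft_modeling.py`. Key names follow the diff below; the `target_modules` list here is only a placeholder, since the commit resolves the special value `"all"` dynamically.

```python
# Illustrative sketch only -- mirrors the override logic introduced by this commit.
from peft import LoraConfig

peft_config = {"peft_type": "lora", "r": 16}      # dict form of the new config.yaml block
peft_type = peft_config.pop("peft_type", "lora")  # chooses LoraConfig vs. PrefixTuningConfig

default_args = {
    "r": 16,
    "lora_alpha": 32,
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],  # placeholder; "all" is resolved at runtime
    "lora_dropout": 0.05,
    "bias": "none",
    "task_type": "CAUSAL_LM",
}
kwargs = {**default_args, **peft_config}  # same effect as merge_dicts(default_args, peft_config)
config = LoraConfig(**kwargs)
print(config.r, config.lora_alpha)        # 16 32
```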


4 files changed: +65 -24


model/model_training/configs/config.yaml

+6 -1
@@ -2,6 +2,7 @@ defaults:
   rng_seed: 0xa1221f97
   learning_rate: 1e-5
   gradient_checkpointing: false
+  int8_training: false
   gradient_accumulation_steps: 32
   per_device_train_batch_size: 2
   per_device_eval_batch_size: 2
@@ -803,8 +804,12 @@ rope_scaling_test:
   residual_dropout_lima: true
   log_wandb: true
   peft_model: true
-  peft_type: "lora"
+  peft_config:
+    peft_type: "lora"
+    r: 16
   superhot: true
   superhot_config:
     type: linear
     scale: 2
+  datasets:
+    - dolly15k

model/model_training/models/peft_modeling.py

+37 -20
@@ -3,7 +3,7 @@
 
 import torch
 from huggingface_hub import hf_hub_download
-from model_training.utils.utils import get_model, get_tokenizer
+from model_training.utils.utils import get_all_linear_layers, get_model, get_tokenizer, merge_dicts
 from peft import LoraConfig, PeftModel, PrefixTuningConfig, get_peft_model, prepare_model_for_int8_training
 
 
@@ -18,11 +18,15 @@ def load_peft_model(model, peft_model_path, tokenizer):
         torch_dtype=model.dtype,
     )
     model.eos_token_id = tokenizer.eos_token_id
-    extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
-    embed_weights = torch.load(extra_embeds, map_location=model.device)
-    model.base_model.model.model.embed_tokens.weight[len(tokenizer) - embed_weights.shape[0] :, :] = embed_weights.to(
-        model.base_model.model.model.embed_tokens.weight.dtype
-    )
+    try:
+        extra_embeds = hf_hub_download(peft_model_path, "extra_embeddings.pt")
+        embed_weights = torch.load(extra_embeds, map_location=model.device)
+        model.base_model.model.model.embed_tokens.weight[
+            len(tokenizer) - embed_weights.shape[0] :, :
+        ] = embed_weights.to(model.base_model.model.model.embed_tokens.weight.dtype)
+    except Exception:
+        print("Warning:Extra embeddings not added. This is expected if adapter file contains WTE")
+
     return model
 
 
@@ -42,27 +46,40 @@ def make_inputs_require_grad(module, input, output):
     return model
 
 
-def peft_model(model, peft_type="lora", int8_training=False, gradient_checkpointing=False):
+def peft_model(model, training_config):
+    peft_config = training_config.peft_config
+    peft_type = peft_config.pop("peft_type", "lora")
     if peft_type == "lora":
-        config = LoraConfig(
-            r=16,
-            lora_alpha=32,
-            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
-            lora_dropout=0.05,
-            bias="none",
-            task_type="CAUSAL_LM",
-        )
+        default_args = {
+            "r": 16,
+            "lora_alpha": 32,
+            "target_modules": "all",
+            "lora_dropout": 0.05,
+            "bias": "none",
+            "task_type": "CAUSAL_LM",
+            "modules_to_save": ["wte", "lm_head"],
+        }
+        kwargs = merge_dicts(default_args, peft_config)
+        if kwargs.get("target_modules") == "all":
+            kwargs.update({"target_modules": get_all_linear_layers(model)})
+        config = LoraConfig(**kwargs)
     elif peft_type == "prefix-tuning":
-        config = PrefixTuningConfig(
-            num_virtual_tokens=30, prefix_projection=True, encoder_hidden_size=1024, task_type="CAUSAL_LM"
-        )
+        default_args = {
+            "num_virtual_tokens": 30,
+            "prefix_projection": True,
+            "encoder_hidden_size": 1024,
+            "task_type": "CAUSAL_LM",
+        }
+        kwargs = merge_dicts(default_args, peft_config)
+        config = PrefixTuningConfig(**kwargs)
     else:
         raise ValueError("peft_method config is lora or prefix-tuning")
     model = get_peft_model(model, config)
-    if int8_training:
+
+    if training_config.int8_training:
         model = prepare_model_for_int8_training(model)
 
-    if gradient_checkpointing:
+    if training_config.gradient_checkpointing:
         model = prepare_model_for_gradient_checkpointing(model)
     model.print_trainable_parameters()
     return model
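
A hedged usage sketch of the new signature: `training_config` below is a stand-in `SimpleNamespace`, not the real parsed training configuration object, but the attribute names (`peft_config`, `int8_training`, `gradient_checkpointing`) are the ones `peft_model` now reads.

```python
# Sketch only: the attributes peft_model() expects on the training config object.
from types import SimpleNamespace

training_config = SimpleNamespace(
    peft_config={"peft_type": "lora", "r": 16, "lora_alpha": 32},  # mirrors the new config.yaml block
    int8_training=False,
    gradient_checkpointing=False,
)

# model = get_model(...)                      # the model is loaded elsewhere by the training script
# model = peft_model(model, training_config)  # replaces the old keyword-argument call
```

Defaulting `target_modules` to `"all"` and resolving it with `get_all_linear_layers(model)` is what removes the llama-only hardcoding: the LoRA target layers no longer need to be listed by name for each architecture.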

model/model_training/trainer_sft.py

+1 -3
@@ -424,9 +424,7 @@ def main():
 
     if training_conf.peft_model:
         print("Using PEFT model")
-        model = peft_model(
-            model, peft_type=training_conf.peft_type, gradient_checkpointing=training_conf.gradient_checkpointing
-        )
+        model = peft_model(model, training_conf)
 
     if training_conf.quantization:
         import bitsandbytes  # This is noisy, so delay importing until after argument parsing so it doesn't make --help noisy

model/model_training/utils/utils.py

+21
@@ -432,3 +432,24 @@ def process_output(output: str, method: str = "v2", bot_name: str = "Joi") -> str:
     answer = output.split("\n\n{}:".format(bot_name))[-1]
     answer = answer.split("</s>")[0].replace("<|endoftext|>", "").lstrip().split("\n\n{}:".format(bot_name))[0]
     return answer
+
+
+def merge_dicts(default: dict, config: dict):
+    """
+    merge default dict with config dict to override params
+    """
+    for k, v in default.items():
+        if k not in config.keys():
+            config.update({k: v})
+
+    return config
+
+
+def get_all_linear_layers(model):
+    cls = torch.nn.Linear
+
+    modules = {name.split(".")[-1] for name, module in model.named_modules() if isinstance(module, cls)}
+    if "lm_head" in modules:
+        modules.remove("lm_head")
+
+    return list(modules)
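
To make the two new helpers concrete, here is a self-contained sketch (the toy module below is hypothetical, not from the repo): `merge_dicts` only fills in keys that are missing from the user config, and `get_all_linear_layers` collects the leaf names of every `torch.nn.Linear` except `lm_head`.

```python
# Standalone illustration of the helpers added above; the toy model is made up for the example.
import torch


def merge_dicts(default: dict, config: dict):
    # defaults only fill in missing keys; values already in config win
    for k, v in default.items():
        if k not in config.keys():
            config.update({k: v})
    return config


def get_all_linear_layers(model):
    # leaf names of all nn.Linear modules, excluding the output head
    cls = torch.nn.Linear
    modules = {name.split(".")[-1] for name, module in model.named_modules() if isinstance(module, cls)}
    if "lm_head" in modules:
        modules.remove("lm_head")
    return list(modules)


class ToyLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(8, 8)
        self.v_proj = torch.nn.Linear(8, 8)
        self.lm_head = torch.nn.Linear(8, 32)


print(merge_dicts({"r": 16, "lora_alpha": 32}, {"r": 8}))  # {'r': 8, 'lora_alpha': 32}
print(sorted(get_all_linear_layers(ToyLM())))              # ['q_proj', 'v_proj']
```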
