
Commit 3783d18

migrate to latest main for hf transformers.

1 parent 7a95b21

File tree

4 files changed: +125 −45 lines


README.md

+55 −15
@@ -15,6 +15,7 @@ This is the repo for the Stanford Alpaca project, which aims to build and share
 - The [52K data](#data-release) used for fine-tuning the model.
 - The code for [generating the data](#data-generation-process).
 - The code for [fine-tuning the model](#fine-tuning).
+- The code for [recovering Alpaca-7B weights from our released weight diff](#recovering-alpaca-weights).
 
 Note: We thank the community for feedback on Stanford-Alpaca and supporting our research. Our live demo is suspended until further notice.
 
@@ -115,10 +116,7 @@ We fine-tune LLaMA-7B and LLaMA-13B with the following hyperparameters:
 | Max length | 512 | 512 |
 | Weight decay | 0 | 0 |
 
-We have also fine-tuned larger variants of LLaMA and are in the process of evaluating those models.
-
-Given Hugging Face hasn't officially supported the LLaMA models, we fine-tuned LLaMA with Hugging Face's transformers library by installing it from a particular fork (i.e. this [PR](https://github.com/huggingface/transformers/pull/21955) to be merged).
-The hash of the specific commit we installed was `68d640f7c368bcaaaecfc678f11908ebbd3d6176`.
+We have also fine-tuned larger variants of LLaMA, performed subsequent RLHF, and are in the process of evaluating those models.
 
 To reproduce our fine-tuning runs for LLaMA, first install the requirements
 
@@ -153,20 +151,10 @@ torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
     --lr_scheduler_type "cosine" \
     --logging_steps 1 \
     --fsdp "full_shard auto_wrap" \
-    --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
+    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
     --tf32 True
 ```
 
-### Warning
-
-`fsdp_transformer_layer_cls_to_wrap` must be set to the name of the specific decoder layer.
-The LLaMA Hugging Face PR is not stable.
-Earlier commits used the name `LLaMADecoderLayer` for their decoder layer (the commit hash our code is based on this).
-More recent commits use `LlamaDecoderLayer` (notice the small case difference).
-Not setting `fsdp_transformer_layer_cls_to_wrap` to the correct name will lead to drastic slowdowns in training.
-
-### Side notes
-
 The same script also works for OPT fine-tuning. Here's an example for fine-tuning OPT-6.7B
 
 ```bash
@@ -196,6 +184,58 @@ torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
 Note the given training script is meant to be simple and easy to use, and is not particularly optimized.
 To run on more gpus, you may prefer to turn down `gradient_accumulation_steps` to keep a global batch size of 128. Global batch size has not been tested for optimality.
 
+### Addressing OOM
+
+Naively, fine-tuning a 7B model requires about 7 x 4 x 4 = 112 GB of VRAM (7B parameters, times four fp32 tensors for the weights, gradients, and the two AdamW moments, times 4 bytes each). The commands given above enable parameter sharding, so no redundant model copy is stored on any GPU.
+If you'd like to further reduce the memory footprint, here are some options:
+
+- Turn on CPU offload for FSDP with `--fsdp "full_shard auto_wrap offload"`. This saves VRAM at the cost of longer runtime.
+- In our experience, DeepSpeed stage-3 (with offload) can at times be more memory efficient than FSDP. Here's an example of using DeepSpeed stage-3 with 4 GPUs, with both parameter and optimizer offload:
+    ```bash
+    pip install deepspeed
+    torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
+        --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
+        --data_path ./alpaca_data.json \
+        --bf16 True \
+        --output_dir <your_output_dir> \
+        --num_train_epochs 3 \
+        --per_device_train_batch_size 4 \
+        --per_device_eval_batch_size 4 \
+        --gradient_accumulation_steps 8 \
+        --evaluation_strategy "no" \
+        --save_strategy "steps" \
+        --save_steps 2000 \
+        --save_total_limit 1 \
+        --learning_rate 2e-5 \
+        --weight_decay 0. \
+        --warmup_ratio 0.03 \
+        --deepspeed "./configs/default_offload_opt_param.json" \
+        --tf32 True
+    ```
+- The DeepSpeed library also provides some [helpful functions](https://deepspeed.readthedocs.io/en/latest/memory.html) to estimate memory usage.
+- [LoRA](https://arxiv.org/abs/2106.09685) fine-tunes low-rank updates to the query, key, and value projection matrices instead of the full weights. This can reduce the total memory footprint from 112GB to about 7x4=28GB. We may release our re-implementation of this in the future, but for now the [peft](https://github.com/huggingface/peft) codebase can be a useful resource.
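The LoRA bullet above points to the peft codebase rather than shipping an implementation. Purely as an illustrative sketch of what that option could look like for a LLaMA checkpoint (not part of this commit), the snippet below wraps the model with low-rank adapters; the rank, alpha, dropout, and target module names are assumptions, not settings from this repository:

```python
import transformers
from peft import LoraConfig, get_peft_model

# Load the base model (path placeholder reused from the README commands above).
model = transformers.AutoModelForCausalLM.from_pretrained(
    "<your_path_to_hf_converted_llama_ckpt_and_tokenizer>"
)

# Illustrative LoRA settings; tune these for your own runs.
lora_config = LoraConfig(
    r=8,  # rank of the low-rank update matrices
    lora_alpha=16,  # scaling applied to the low-rank update
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],  # LLaMA attention projections
    task_type="CAUSAL_LM",
)

# Wrap the model so that only the injected adapter weights are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```

Only the adapter parameters need gradients and optimizer states, so what remains in memory is essentially the frozen base weights (roughly 7B x 4 bytes = 28GB), which matches the estimate in the bullet above.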
+
+## Recovering Alpaca Weights
+
+The weight diff between Alpaca-7B and LLaMA-7B is located [here](https://huggingface.co/tatsu-lab/alpaca-7b-wdiff/tree/main).
+To recover the original Alpaca-7B weights, follow these steps:
+```text
+1. Convert Meta's released weights into Hugging Face format. Follow this guide:
+    https://huggingface.co/docs/transformers/main/model_doc/llama
+2. Make sure you cloned the released weight diff into your local machine. The weight diff is located at:
+    https://huggingface.co/tatsu-lab/alpaca-7b-wdiff/tree/main
+3. Run this function with the correct paths. E.g.,
+    python weight_diff.py recover --path_raw <path_to_step_1_dir> --path_diff <path_to_step_2_dir> --path_tuned <path_to_store_recovered_weights>
+```
+
+Once step 3 completes, you should have a directory with the recovered weights, from which you can load the model as follows:
+
+```python
+import transformers
+alpaca_model = transformers.AutoModelForCausalLM.from_pretrained("<path_to_store_recovered_weights>")
+alpaca_tokenizer = transformers.AutoTokenizer.from_pretrained("<path_to_store_recovered_weights>")
+```
+
 ### Authors
 
 All grad students below contributed equally and the order is determined by random draw.
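The recovery step in the README diff above is performed by the repository's `weight_diff.py recover` command, which is not part of this diff. As a conceptual sketch only, and assuming the usual additive convention (diff = tuned minus raw), the idea looks roughly like this; it is not the actual `weight_diff.py`:

```python
import torch
import transformers


def recover_sketch(path_raw: str, path_diff: str, path_tuned: str):
    """Conceptual sketch: add the released weight diff back onto the raw LLaMA weights."""
    raw = transformers.AutoModelForCausalLM.from_pretrained(path_raw, torch_dtype=torch.float32)
    diff = transformers.AutoModelForCausalLM.from_pretrained(path_diff, torch_dtype=torch.float32)

    raw_state = raw.state_dict()
    with torch.no_grad():
        for name, tensor in diff.state_dict().items():
            # Assumed convention: diff = tuned - raw, so adding raw recovers the tuned weights.
            tensor.add_(raw_state[name])

    diff.save_pretrained(path_tuned)  # `diff` now holds the recovered (tuned) weights
```

In practice, use the repository's `weight_diff.py` as shown above, since it encapsulates the exact convention used for the released diff.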
configs/default_offload_opt_param.json

+49

@@ -0,0 +1,49 @@
+{
+  "bf16": {
+    "enabled": "auto"
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  },
+  "scheduler": {
+    "type": "WarmupDecayLR",
+    "params": {
+      "total_num_steps": "auto",
+      "warmup_min_lr": "auto",
+      "warmup_max_lr": "auto",
+      "warmup_num_steps": "auto"
+    }
+  },
+  "zero_optimization": {
+    "stage": 3,
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "offload_param": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "overlap_comm": true,
+    "contiguous_gradients": true,
+    "sub_group_size": 1e9,
+    "reduce_bucket_size": "auto",
+    "stage3_prefetch_bucket_size": "auto",
+    "stage3_param_persistence_threshold": "auto",
+    "stage3_max_live_parameters": 1e9,
+    "stage3_max_reuse_distance": 1e9,
+    "stage3_gather_16bit_weights_on_model_save": false
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "steps_per_print": 5,
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
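The "helpful functions" bullet in the README diff links to DeepSpeed's memory-requirement docs. As a hedged example of how such an estimate can be obtained for the stage-3 configuration added above (the estimator below is the one documented there, assuming a recent DeepSpeed; the checkpoint placeholder is from this README):

```python
import transformers
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live

# Load the model on CPU just to count parameters; no GPU is needed for the estimate.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "<your_path_to_hf_converted_llama_ckpt_and_tokenizer>"
)

# Prints estimated per-GPU and host memory needs for ZeRO stage-3,
# with and without parameter/optimizer offload, for a single 4-GPU node.
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=4, num_nodes=1)
```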

requirements.txt

+2 −2

@@ -2,8 +2,8 @@ numpy
 rouge_score
 fire
 openai
-transformers>=4.26.1
+transformers>=4.28.1
 torch
 sentencepiece
-tokenizers==0.12.1
+tokenizers>=0.13.3
 wandb
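Since the point of this commit is to move off the development fork onto a released transformers with built-in LLaMA support, a quick environment check against the pins above may be useful. A minimal sketch (assuming `packaging`, which ships as a transformers dependency):

```python
from packaging import version

import tokenizers
import transformers

# The migrated code path expects released LLaMA support rather than the old fork.
assert version.parse(transformers.__version__) >= version.parse("4.28.1"), transformers.__version__
assert version.parse(tokenizers.__version__) >= version.parse("0.13.3"), tokenizers.__version__
print("transformers", transformers.__version__, "| tokenizers", tokenizers.__version__)
```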

train.py

+19 −28

@@ -15,20 +15,19 @@
 import copy
 import logging
 from dataclasses import dataclass, field
-from typing import Optional, Dict, Sequence
+from typing import Dict, Optional, Sequence
 
 import torch
 import transformers
+import utils
 from torch.utils.data import Dataset
 from transformers import Trainer
 
-import utils
-
 IGNORE_INDEX = -100
 DEFAULT_PAD_TOKEN = "[PAD]"
 DEFAULT_EOS_TOKEN = "</s>"
-DEFAULT_BOS_TOKEN = "</s>"
-DEFAULT_UNK_TOKEN = "</s>"
+DEFAULT_BOS_TOKEN = "<s>"
+DEFAULT_UNK_TOKEN = "<unk>"
 PROMPT_DICT = {
     "prompt_input": (
         "Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -63,15 +62,6 @@ class TrainingArguments(transformers.TrainingArguments):
     )
 
 
-def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
-    """Collects the state dict and dump to disk."""
-    state_dict = trainer.model.state_dict()
-    if trainer.args.should_save:
-        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
-        del state_dict
-        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
-
-
 def smart_tokenizer_and_embedding_resize(
     special_tokens_dict: Dict,
     tokenizer: transformers.PreTrainedTokenizer,
@@ -205,26 +195,27 @@ def train():
         padding_side="right",
         use_fast=False,
     )
+    special_tokens_dict = dict()
     if tokenizer.pad_token is None:
-        smart_tokenizer_and_embedding_resize(
-            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
-            tokenizer=tokenizer,
-            model=model,
-        )
-    if "llama" in model_args.model_name_or_path:
-        tokenizer.add_special_tokens(
-            {
-                "eos_token": DEFAULT_EOS_TOKEN,
-                "bos_token": DEFAULT_BOS_TOKEN,
-                "unk_token": DEFAULT_UNK_TOKEN,
-            }
-        )
+        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
+    if tokenizer.eos_token is None:
+        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
+    if tokenizer.bos_token is None:
+        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
+    if tokenizer.unk_token is None:
+        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
+
+    smart_tokenizer_and_embedding_resize(
+        special_tokens_dict=special_tokens_dict,
+        tokenizer=tokenizer,
+        model=model,
+    )
 
     data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
     trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
     trainer.save_state()
-    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
+    trainer.save_model(output_dir=training_args.output_dir)
 
 
 if __name__ == "__main__":
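The train.py changes above consolidate special-token handling: instead of branching on a "llama" substring in the model path, the code now adds only the tokens the tokenizer is actually missing and routes them all through `smart_tokenizer_and_embedding_resize`. As a standalone sketch of the same idea using only stock Hugging Face APIs (the repository's helper is not shown in this diff; it is stood in for here by `resize_token_embeddings`):

```python
import transformers

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"

model_name_or_path = "<your_path_to_hf_converted_llama_ckpt_and_tokenizer>"
model = transformers.AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, padding_side="right", use_fast=False
)

# Only add the tokens the tokenizer is actually missing
# (converted LLaMA tokenizers typically ship without a pad token).
special_tokens_dict = {}
if tokenizer.pad_token is None:
    special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
if tokenizer.eos_token is None:
    special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
if tokenizer.bos_token is None:
    special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
if tokenizer.unk_token is None:
    special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

num_added = tokenizer.add_special_tokens(special_tokens_dict)
if num_added > 0:
    # Grow the input (and tied output) embeddings to cover the newly added tokens.
    model.resize_token_embeddings(len(tokenizer))
```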
