
Commit fa78c4f

training code.
1 parent 7ad0c6b commit fa78c4f

3 files changed: +301 −3 lines changed

README.md

+63 −2
@@ -91,7 +91,7 @@ The inner circle of the plot represents the root verb of the instructions, and t
[<img src="assets/parse_analysis.png" width="750" />](./assets/parse_analysis.png)

## Fine-tuning
-We fine-tune our model using standard Hugging Face training code with the following hyperparameters:
+We fine-tune our models using standard Hugging Face training code with the following hyperparameters:

| Hyperparameter | Value |
|----------------|-------|
@@ -101,7 +101,68 @@ We fine-tune our model using standard Hugging Face training code with the follow
| Max length | 512 |
| Weight decay | 1 |

-We are waiting for Hugging Face to officially support the llama models (i.e. this [PR](https://github.com/huggingface/transformers/pull/21955) to be merged) before we release a stable version of the finetuning code.
+Since Hugging Face has not yet officially supported the LLaMA models, we fine-tuned LLaMA with Hugging Face's transformers library installed from a particular fork (i.e., the branch behind this [PR](https://github.com/huggingface/transformers/pull/21955), which has not yet been merged).
+The hash of the specific commit we installed was `68d640f7c368bcaaaecfc678f11908ebbd3d6176`.
+
+To reproduce our fine-tuning runs for LLaMA, first install the requirements:
+```bash
+pip install -r requirements.txt
+```
+Then, install the particular fork of Hugging Face's transformers library (a sketch of one way to do this follows).
+
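The install command for the fork is not spelled out in this commit; a minimal sketch of one way to do it, assuming the PR's source fork is reachable as a git remote. `<fork_repo_url>` is a placeholder rather than a value from this commit; take the fork URL from the PR page linked above.

```bash
# Install transformers from the PR's source fork, pinned to the commit hash named in the README.
# <fork_repo_url> is a placeholder for that fork's git URL (see the PR page); it is not given here.
pip install git+<fork_repo_url>@68d640f7c368bcaaaecfc678f11908ebbd3d6176
```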
+Below is a command that fine-tunes LLaMA-7B with our dataset on a machine with 4 A100 80G GPUs in FSDP `full_shard` mode.
+Replace `<your_random_port>` with a port of your own, `<your_path_to_hf_converted_llama_ckpt_and_tokenizer>` with the
+path to your converted checkpoint and tokenizer (following the instructions in the PR), and `<your_output_dir>` with where you want to store your outputs.
+
+```bash
+torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
+    --model_name_or_path <your_path_to_hf_converted_llama_ckpt_and_tokenizer> \
+    --data_path ./alpaca_data.json \
+    --bf16 True \
+    --output_dir <your_output_dir> \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 2000 \
+    --save_total_limit 1 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --fsdp "full_shard auto_wrap" \
+    --fsdp_transformer_layer_cls_to_wrap 'LLaMADecoderLayer' \
+    --tf32 True
+```
+
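As a quick sanity check on the flags above (not stated in the commit itself): the effective batch size is `nproc_per_node` × `per_device_train_batch_size` × `gradient_accumulation_steps` = 4 × 4 × 8 = 128.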
+The same script also works for OPT fine-tuning. Here's an example for fine-tuning OPT-6.7B:
+
+```bash
+torchrun --nproc_per_node=4 --master_port=<your_random_port> train.py \
+    --model_name_or_path "facebook/opt-6.7b" \
+    --data_path ./alpaca_data.json \
+    --bf16 True \
+    --output_dir <your_output_dir> \
+    --num_train_epochs 3 \
+    --per_device_train_batch_size 4 \
+    --per_device_eval_batch_size 4 \
+    --gradient_accumulation_steps 8 \
+    --evaluation_strategy "no" \
+    --save_strategy "steps" \
+    --save_steps 2000 \
+    --save_total_limit 1 \
+    --learning_rate 2e-5 \
+    --weight_decay 0. \
+    --warmup_ratio 0.03 \
+    --lr_scheduler_type "cosine" \
+    --logging_steps 1 \
+    --fsdp "full_shard auto_wrap" \
+    --fsdp_transformer_layer_cls_to_wrap 'OPTDecoderLayer' \
+    --tf32 True
+```

### Authors
All grad students below contributed equally and the order is determined by random draw.

requirements.txt

+6 −1
@@ -1,4 +1,9 @@
numpy
rouge_score
fire
-openai
+openai
+transformers>=4.26.1
+torch
+sentencepiece
+tokenizers==0.12.1
+wandb

train.py

+232
@@ -0,0 +1,232 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

import utils

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dumps it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        # Mask out the prompt (source) tokens so the loss is computed only on the response tokens.
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
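A note on the data format, inferred from `SupervisedDataset` above rather than stated anywhere in this commit: each record in `alpaca_data.json` is expected to be an object with an `instruction` field, an optional `input` field, and an `output` field. A quick, hypothetical way to peek at the first record:

```bash
# Print the first record of the training data. The field names (instruction / input / output)
# are inferred from SupervisedDataset, not documented in this commit.
python -c 'import json; print(json.load(open("alpaca_data.json"))[0])'
```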
