# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence

import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer

import utils

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    ),
}
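
# Illustration (hypothetical record, not from the training data): an example such as
#   {"instruction": "Translate to French.", "input": "Good morning.", "output": "Bonjour."}
# is rendered with PROMPT_DICT["prompt_input"]; a record whose "input" field is empty
# falls back to PROMPT_DICT["prompt_no_input"]. The model is trained to continue the
# text after the final "### Response:" marker.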


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")


@dataclass
class DataArguments:
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collect the state dict and dump it to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the new token embeddings with the mean of the existing embeddings.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

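# Usage sketch (this mirrors the call made in train() below): when the tokenizer has no
# pad token,
#   smart_tokenizer_and_embedding_resize(
#       special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), tokenizer=tokenizer, model=model
#   )
# grows both embedding matrices by one row and initializes the new row to the mean of
# the pre-existing embeddings instead of leaving it randomly initialized.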

def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    # Lengths count non-pad tokens only; preprocess() uses them to locate the prompt boundary.
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Mask out the prompt (source) tokens so the loss is computed only on the response.
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
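
# Worked illustration (hypothetical token counts): if a formatted prompt tokenizes to 40
# tokens and the full prompt+response example to 55, the first 40 label positions are set
# to IGNORE_INDEX (-100), so only the 15 response tokens (including the EOS token appended
# in SupervisedDataset) contribute to the language-modeling loss.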


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()
        logging.warning("Loading data...")
        list_data_dict = utils.jload(data_path)

        logging.warning("Formatting inputs...")
        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
        sources = [
            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]

        logging.warning("Tokenizing inputs... This may take some time...")
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        # Right-pad input_ids with pad_token_id and labels with IGNORE_INDEX so padding is ignored by the loss.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
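
# Sketch of a collated batch (hypothetical sizes): for two examples of lengths 10 and 7,
# the collator returns input_ids of shape (2, 10) right-padded with pad_token_id, labels
# of shape (2, 10) padded with IGNORE_INDEX, and a boolean attention_mask of shape (2, 10)
# that is False on the padded positions.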


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)


def train():
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )
    if "llama" in model_args.model_name_or_path:
        # Some LLaMA tokenizer exports leave these special tokens unset; fill them in explicitly.
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
    trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)


if __name__ == "__main__":
    train()
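
# Example launch (hypothetical paths and hyperparameters; any launcher compatible with
# transformers.Trainer works, e.g. plain `python` for single-GPU runs):
#   torchrun --nproc_per_node=4 train.py \
#       --model_name_or_path facebook/opt-125m \
#       --data_path ./alpaca_data.json \
#       --output_dir ./output \
#       --num_train_epochs 3 \
#       --per_device_train_batch_size 4 \
#       --learning_rate 2e-5 \
#       --bf16 True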