
Commit aafd8cb

BenasdTW, kashif, and qgallouedec authored
🍟 [SFT] Handles the dataset if it has been preprocessed (#2863)
* return dataset if it's preprocessed
* add is_processed flag variable
* add test
* move test_sft_trainer_directly_with_pretokenized_data to Tester2
* Update sft_trainer.py
* no need for padding and truncation
* minor reorganization
* Update trl/trainer/sft_trainer.py
* let the collator pad
* style
* fix tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
1 parent 8226538 commit aafd8cb
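
In practice this means a dataset that has already been tokenized can be handed to SFTTrainer as-is. The sketch below mirrors the new test added in this commit (the tiny model and the zen dataset are the internal testing assets used there; output_dir="out" is just an illustrative path, not part of the commit):

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Tokenize up front; no padding or truncation is needed here, the collator pads at batch time.
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
tokenized_dataset = dataset.map(lambda ex: tokenizer(ex["text"]), remove_columns=["text"])

# The trainer detects the `input_ids` column and skips its own tokenization step.
trainer = SFTTrainer(
    args=SFTConfig(output_dir="out", report_to="none"),
    model=model,
    train_dataset=tokenized_dataset,
)
trainer.train()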

File tree

2 files changed: +58 -6 lines changed


tests/test_sft_trainer.py (+33 -1)
@@ -288,7 +288,7 @@ def test_sft_trainer(self):
 
             self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
 
-    def test_sft_trainer_with_pretokenzied_data_packing(self):
+    def test_sft_trainer_with_pretokenized_data_packing(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = SFTConfig(
                 output_dir=tmp_dir,
@@ -1400,3 +1400,35 @@ def rename_fields(example: list[dict]):
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_sft_trainer_with_pretokenized_data(self):
+        # Get the model and dataset
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        def tokenize_example(example):
+            return tokenizer(example["text"])
+
+        # Apply tokenization
+        tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Initialize the trainer
+            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
+            trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset)
+
+            # Save the initial parameters to compare them later
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            trainer.train()
+
+            # Check that the training loss is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
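
The new and renamed tests both contain "pretokenized" in their names, so they can be run on their own with pytest's keyword filter, for example:

    pytest tests/test_sft_trainer.py -k pretokenized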

trl/trainer/sft_trainer.py (+25 -5)
@@ -109,6 +109,8 @@ class SFTTrainer(Trainer):
             - [Standard](dataset_formats#standard): Each sample contains plain text.
             - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
               and content).
+
+            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
         eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
             Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
         processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
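
As a sketch of what this added docstring line promises, the only column a pre-processed dataset needs is input_ids; the token ids below are arbitrary and purely illustrative:

from datasets import Dataset

# Any dataset exposing an `input_ids` column is treated as already tokenized.
processed = Dataset.from_dict({"input_ids": [[101, 2023, 102], [101, 2003, 2026, 102]]})
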
@@ -370,14 +372,26 @@ def _prepare_dataset(
         if isinstance(dataset, ConstantLengthDataset):
             return dataset
 
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "input_ids" in column_names
+
         # Build the kwargs for the `map` function
         map_kwargs = {}
         if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
             map_kwargs["num_proc"] = args.dataset_num_proc
 
         with PartialState().local_main_process_first():
             # Apply the formatting function if any
-            if formatting_func is not None:
+            if formatting_func is not None and is_processed:
+                warnings.warn(
+                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
+                    "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
+                    "`formatting_func` or pass a dataset that is not already processed.",
+                    UserWarning,
+                )
+
+            if formatting_func is not None and not is_processed:
                 if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                     map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
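
From the user's side, passing a formatting function together with an already-tokenized dataset now emits a UserWarning instead of re-processing the data. A sketch of that behavior, reusing model and tokenized_dataset from the example near the top of this page (the lambda is an arbitrary stand-in for a formatting function):

import tempfile
import warnings

from trl import SFTConfig, SFTTrainer

with tempfile.TemporaryDirectory() as tmp_dir, warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    SFTTrainer(
        args=SFTConfig(output_dir=tmp_dir, report_to="none"),
        model=model,
        train_dataset=tokenized_dataset,
        formatting_func=lambda example: example["text"],  # never called: the dataset already has `input_ids`
    )
    assert any(issubclass(w.category, UserWarning) for w in caught)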

@@ -416,10 +430,16 @@ def concat_prompt_completion(example):
                     **map_kwargs,
                 )
 
-            # Tokenize the dataset
-            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
-                map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
-            dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
+            # Tokenize the dataset if needed
+            if not is_processed:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+                def tokenize(ex):
+                    tokenized = processing_class(ex[args.dataset_text_field])
+                    return {"input_ids": tokenized["input_ids"], "attention_mask": tokenized["attention_mask"]}
+
+                dataset = dataset.map(tokenize, **map_kwargs)
 
             # Pack or truncate
             if packing:
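
The reason neither the new test nor the sketch above pads or truncates at tokenization time is that padding is deferred to the data collator, which pads each batch to the length of its longest example when the batch is built. A small illustration of that batch-time padding using the tokenizer's own pad method (the sentences are arbitrary, and this is not the trainer's actual collator, just the same idea):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
if tokenizer.pad_token is None:  # make sure a pad token exists for this illustration
    tokenizer.pad_token = tokenizer.eos_token

# Two examples tokenized without padding, exactly as they would be stored in the dataset.
features = [tokenizer("a short one"), tokenizer("a noticeably longer example sentence")]

# At batch time they are padded to a common length (the longest example in the batch).
batch = tokenizer.pad(features, padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # (2, length_of_longest_example)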
