
Commit 319d840

examples: add keep_linebreaks option to CLM examples (#13150)
* examples: add keep_linebreaks option to text dataset loader for all CLM examples
* examples: introduce new keep_linebreaks option as data argument in CLM examples
1 parent 45a8eb6 commit 319d840

File tree

4 files changed (+19 −1 lines):

examples/flax/language-modeling/run_clm_flax.py
examples/pytorch/language-modeling/run_clm.py
examples/pytorch/language-modeling/run_clm_no_trainer.py
examples/tensorflow/language-modeling/run_clm.py

examples/flax/language-modeling/run_clm_flax.py (+5)

@@ -156,6 +156,9 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
+    keep_linebreaks: bool = field(
+        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+    )

     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -314,12 +317,14 @@ def main():
         if "validation" not in dataset.keys():
             dataset["validation"] = load_dataset(
                 extension,
+                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
             dataset["train"] = load_dataset(
                 extension,
+                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
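For context, each of these hunks (and the matching ones in the files below) just forwards the new option to the "text" dataset builder in Hugging Face Datasets, whose keep_linebreaks parameter controls whether each loaded line keeps its trailing newline. A minimal sketch of that behavior, assuming a throwaway file created only for illustration (the file name and contents are made up; in the datasets library itself the loader defaults to keep_linebreaks=False):

from datasets import load_dataset

# Hypothetical two-line text file, used only to illustrate the option.
with open("toy.txt", "w") as f:
    f.write("first line\nsecond line\n")

with_breaks = load_dataset("text", data_files="toy.txt", keep_linebreaks=True, split="train")
without_breaks = load_dataset("text", data_files="toy.txt", keep_linebreaks=False, split="train")

print(repr(with_breaks[0]["text"]))     # expected: 'first line\n'
print(repr(without_breaks[0]["text"]))  # expected: 'first line'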

examples/pytorch/language-modeling/run_clm.py (+5)

@@ -172,6 +172,9 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
+    keep_linebreaks: bool = field(
+        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+    )

     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -282,12 +285,14 @@ def main():
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
                 extension,
+                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
             raw_datasets["train"] = load_dataset(
                 extension,
+                keep_linebreaks=data_args.keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
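In run_clm.py (and the Flax and TensorFlow variants) the option is an ordinary dataclass field, so it is exposed on the command line through HfArgumentParser like the other data arguments. A minimal, self-contained sketch of that pattern, assuming a hypothetical MiniDataArgs stand-in for the real DataTrainingArguments (how boolean values are spelled on the command line can vary between transformers versions):

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class MiniDataArgs:
    # Hypothetical stand-in mirroring the field added in this commit.
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
    )


parser = HfArgumentParser(MiniDataArgs)
# Recent transformers releases accept an explicit boolean value after the flag.
(data_args,) = parser.parse_args_into_dataclasses(args=["--keep_linebreaks", "False"])
print(data_args.keep_linebreaks)  # expected: False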

examples/pytorch/language-modeling/run_clm_no_trainer.py (+5)

@@ -173,6 +173,9 @@ def parse_args():
     parser.add_argument(
         "--overwrite_cache", type=bool, default=False, help="Overwrite the cached training and evaluation sets"
     )
+    parser.add_argument(
+        "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using CSV/JSON/TXT files."
+    )

     args = parser.parse_args()

@@ -257,11 +260,13 @@ def main():
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
                 extension,
+                keep_linebreaks=not args.no_keep_linebreaks,
                 data_files=data_files,
                 split=f"train[:{args.validation_split_percentage}%]",
             )
             raw_datasets["train"] = load_dataset(
                 extension,
+                keep_linebreaks=not args.no_keep_linebreaks,
                 data_files=data_files,
                 split=f"train[{args.validation_split_percentage}%:]",
             )
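Because run_clm_no_trainer.py uses plain argparse rather than HfArgumentParser, the default-on behavior is expressed as a negative flag that is inverted when calling load_dataset. A small standalone sketch of that pattern (the flag name mirrors the diff; everything else is illustrative):

import argparse

parser = argparse.ArgumentParser()
# Same idea as the diff: the flag is off by default, so line breaks are kept unless it is passed.
parser.add_argument(
    "--no_keep_linebreaks", action="store_true", help="Do not keep line breaks when using CSV/JSON/TXT files."
)

args = parser.parse_args([])                        # flag absent -> keep line breaks
print(not args.no_keep_linebreaks)                  # expected: True

args = parser.parse_args(["--no_keep_linebreaks"])  # flag present -> drop line breaks
print(not args.no_keep_linebreaks)                  # expected: False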

examples/tensorflow/language-modeling/run_clm.py (+4 −1)

@@ -186,6 +186,9 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    keep_linebreaks: bool = field(
+        default=True, metadata={"help": "Whether to keep line breaks when using CSV/JSON/TXT files or not."}
+    )

     def __post_init__(self):
         if self.dataset_name is None and self.train_file is None and self.validation_file is None:
@@ -325,7 +328,7 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        raw_datasets = load_dataset(extension, data_files=data_files)
+        raw_datasets = load_dataset(extension, keep_linebreaks=data_args.keep_linebreaks, data_files=data_files)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
     # endregion
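All four scripts later concatenate the tokenized examples into fixed-size blocks for causal language modeling, so whether the trailing newline survives loading changes what the model sees at line boundaries. A purely string-level illustration of that effect (the real scripts concatenate token ids rather than raw strings):

# Lines as the "text" loader would return them with and without keep_linebreaks.
kept = ["first line\n", "second line\n"]
stripped = ["first line", "second line"]

print(repr("".join(kept)))      # 'first line\nsecond line\n' -- line boundary preserved
print(repr("".join(stripped)))  # 'first linesecond line'     -- lines run together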
