
Commit 539ee45

[Examples] Replicates the new --log_level feature to all trainer-based PyTorch examples (#12359)
* added log_level
* fix comment
* fixed log_level
* Trigger CI
* Unified logging
* simplified args for log_level
1 parent 64e6098 commit 539ee45
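
The commit drops the same logging block into every script listed below, replacing the old should_log on/off switch. A minimal sketch of that pattern, assuming a transformers version that ships TrainingArguments.get_process_log_level() (the wrapper function name and the format string are illustrative and not part of the diff; training_args is the parsed TrainingArguments each script already has):

import logging
import sys

import datasets
import transformers


def configure_logging(training_args):
    # Hypothetical helper bundling the block this commit adds to each script.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Resolve the level for this process (main process and replicas may differ),
    # then apply it to the script logger and to both libraries' loggers.
    log_level = training_args.get_process_log_level()
    logging.getLogger(__name__).setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

With this in place, --log_level (and --log_level_replica on non-main processes) drives the script logger, datasets, and transformers consistently instead of a single should_log switch.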

File tree

13 files changed: +202 -165 lines changed


examples/pytorch/language-modeling/run_clm.py

+18 -14
@@ -28,6 +28,7 @@
 from dataclasses import dataclass, field
 from typing import Optional

+import datasets
 from datasets import load_dataset

 import transformers
@@ -203,18 +204,19 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()

     # Log on each process the small summary:
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if training_args.should_log:
-        transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
     logger.info(f"Training/evaluation parameters {training_args}")

     # Detecting last checkpoint.
@@ -246,15 +248,17 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
-        if "validation" not in datasets.keys():
-            datasets["validation"] = load_dataset(
+        raw_datasets = load_dataset(
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+        )
+        if "validation" not in raw_datasets.keys():
+            raw_datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
-            datasets["train"] = load_dataset(
+            raw_datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
@@ -273,7 +277,7 @@ def main():
         )
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

@@ -334,9 +338,9 @@ def main():
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
-        column_names = datasets["train"].column_names
+        column_names = raw_datasets["train"].column_names
     else:
-        column_names = datasets["validation"].column_names
+        column_names = raw_datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]

     # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
@@ -352,7 +356,7 @@ def tokenize_function(examples):
             )
         return output

-    tokenized_datasets = datasets.map(
+    tokenized_datasets = raw_datasets.map(
         tokenize_function,
         batched=True,
         num_proc=data_args.preprocessing_num_workers,
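
The rename from datasets to raw_datasets in the hunks above is what makes the new module-level import datasets safe: binding the DatasetDict returned by load_dataset to the module's name inside main() would shadow the library for the whole function. A minimal sketch of the conflict the rename avoids (the text file path is hypothetical):

import logging

import datasets  # the library, now imported at module level
from datasets import load_dataset


def main():
    # Binding the DatasetDict to the name "datasets" anywhere in main() would
    # make that name local to the function and shadow the module, breaking the
    # datasets.utils.logging call below. A distinct name keeps the module reachable.
    raw_datasets = load_dataset("text", data_files={"train": "train.txt"})

    datasets.utils.logging.set_verbosity(logging.WARNING)
    print(raw_datasets["train"].column_names)


if __name__ == "__main__":
    main()

The same rename is repeated, hunk for hunk, in the remaining files of this commit.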

examples/pytorch/language-modeling/run_mlm.py

+19 -14
@@ -28,6 +28,7 @@
 from dataclasses import dataclass, field
 from typing import Optional

+import datasets
 from datasets import load_dataset

 import transformers
@@ -212,18 +213,20 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()

     # Log on each process the small summary:
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if training_args.should_log:
-        transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
     logger.info(f"Training/evaluation parameters {training_args}")

     # Detecting last checkpoint.
@@ -255,15 +258,17 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
-        if "validation" not in datasets.keys():
-            datasets["validation"] = load_dataset(
+        raw_datasets = load_dataset(
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+        )
+        if "validation" not in raw_datasets.keys():
+            raw_datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
-            datasets["train"] = load_dataset(
+            raw_datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
@@ -278,7 +283,7 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

@@ -337,9 +342,9 @@ def main():
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
-        column_names = datasets["train"].column_names
+        column_names = raw_datasets["train"].column_names
     else:
-        column_names = datasets["validation"].column_names
+        column_names = raw_datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]

     if data_args.max_seq_length is None:
@@ -377,7 +382,7 @@ def tokenize_function(examples):
                 return_special_tokens_mask=True,
             )

-        tokenized_datasets = datasets.map(
+        tokenized_datasets = raw_datasets.map(
             tokenize_function,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
@@ -392,7 +397,7 @@ def tokenize_function(examples):
         def tokenize_function(examples):
             return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

-        tokenized_datasets = datasets.map(
+        tokenized_datasets = raw_datasets.map(
             tokenize_function,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,

examples/pytorch/language-modeling/run_plm.py

+19 -15
@@ -25,6 +25,7 @@
 from dataclasses import dataclass, field
 from typing import Optional

+import datasets
 from datasets import load_dataset

 import transformers
@@ -209,18 +210,19 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
+
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()

     # Log on each process the small summary:
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if training_args.should_log:
-        transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
     logger.info(f"Training/evaluation parameters {training_args}")

     # Detecting last checkpoint.
@@ -252,15 +254,17 @@ def main():
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
-        if "validation" not in datasets.keys():
-            datasets["validation"] = load_dataset(
+        raw_datasets = load_dataset(
+            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir
+        )
+        if "validation" not in raw_datasets.keys():
+            raw_datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
             )
-            datasets["train"] = load_dataset(
+            raw_datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
@@ -275,7 +279,7 @@ def main():
         extension = data_args.train_file.split(".")[-1]
         if extension == "txt":
             extension = "text"
-        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

@@ -334,9 +338,9 @@ def main():
     # Preprocessing the datasets.
     # First we tokenize all the texts.
     if training_args.do_train:
-        column_names = datasets["train"].column_names
+        column_names = raw_datasets["train"].column_names
     else:
-        column_names = datasets["validation"].column_names
+        column_names = raw_datasets["validation"].column_names
     text_column_name = "text" if "text" in column_names else column_names[0]

     if data_args.max_seq_length > tokenizer.model_max_length:
@@ -355,7 +359,7 @@ def tokenize_function(examples):
             examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
             return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length)

-        tokenized_datasets = datasets.map(
+        tokenized_datasets = raw_datasets.map(
             tokenize_function,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,
@@ -368,7 +372,7 @@ def tokenize_function(examples):
         def tokenize_function(examples):
             return tokenizer(examples[text_column_name])

-        tokenized_datasets = datasets.map(
+        tokenized_datasets = raw_datasets.map(
             tokenize_function,
             batched=True,
             num_proc=data_args.preprocessing_num_workers,

examples/pytorch/multiple-choice/run_swag.py

+13 -12
@@ -24,6 +24,7 @@
 from dataclasses import dataclass, field
 from typing import Optional, Union

+import datasets
 import numpy as np
 import torch
 from datasets import load_dataset
@@ -220,18 +221,18 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
+    log_level = training_args.get_process_log_level()
+    logger.setLevel(log_level)
+    datasets.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.set_verbosity(log_level)
+    transformers.utils.logging.enable_default_handler()
+    transformers.utils.logging.enable_explicit_format()

     # Log on each process the small summary:
     logger.warning(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
         + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
     )
-    # Set the verbosity to info of the Transformers logger (on main process only):
-    if training_args.should_log:
-        transformers.utils.logging.set_verbosity_info()
-        transformers.utils.logging.enable_default_handler()
-        transformers.utils.logging.enable_explicit_format()
     logger.info(f"Training/evaluation parameters {training_args}")

     # Detecting last checkpoint.
@@ -268,10 +269,10 @@ def main():
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
         extension = data_args.train_file.split(".")[-1]
-        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)
     else:
         # Downloading and loading the swag dataset from the hub.
-        datasets = load_dataset("swag", "regular", cache_dir=model_args.cache_dir)
+        raw_datasets = load_dataset("swag", "regular", cache_dir=model_args.cache_dir)
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

@@ -347,9 +348,9 @@ def preprocess_function(examples):
         return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}

     if training_args.do_train:
-        if "train" not in datasets:
+        if "train" not in raw_datasets:
             raise ValueError("--do_train requires a train dataset")
-        train_dataset = datasets["train"]
+        train_dataset = raw_datasets["train"]
         if data_args.max_train_samples is not None:
             train_dataset = train_dataset.select(range(data_args.max_train_samples))
         train_dataset = train_dataset.map(
@@ -360,9 +361,9 @@ def preprocess_function(examples):
         )

     if training_args.do_eval:
-        if "validation" not in datasets:
+        if "validation" not in raw_datasets:
             raise ValueError("--do_eval requires a validation dataset")
-        eval_dataset = datasets["validation"]
+        eval_dataset = raw_datasets["validation"]
         if data_args.max_eval_samples is not None:
             eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
         eval_dataset = eval_dataset.map(
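
A quick way to check what the new flag resolves to, without launching training (the output_dir is a placeholder, and this assumes a transformers release that includes the --log_level feature):

import logging

from transformers import TrainingArguments

# log_level mirrors what the scripts receive from the --log_level CLI flag.
args = TrainingArguments(output_dir="/tmp/example-output", log_level="debug")
print(args.get_process_log_level())  # 10, i.e. logging.DEBUG on the main process

On replicas in distributed runs, --log_level_replica plays the same role, and get_process_log_level() picks whichever value applies to the current process.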
