 )


-# Cache the result
-has_tensorboard = is_tensorboard_available()
-if has_tensorboard:
-    try:
-        from flax.metrics.tensorboard import SummaryWriter
-    except ImportError as ie:
-        has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
-
-else:
-    print(
-        "Unable to display metrics through TensorBoard because the package is not installed: "
-        "Please run pip install tensorboard to enable."
-    )
-
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

@@ -269,7 +253,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
     return batch_idx


-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)

     train_metrics = get_metrics(train_metrics)
@@ -278,6 +262,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)

+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)

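For reference, here is a simplified, self-contained sketch of the two helpers this change splits `write_metric` into. It assumes `flax` and `tensorboard` are installed and writes events to a temporary directory; unlike the real `write_train_metric`, it skips the `get_metrics` call that un-replicates and stacks the pmapped per-step outputs, and instead takes a plain dict of Python lists so it runs without an accelerator.

import tempfile

from flax.metrics.tensorboard import SummaryWriter


def write_train_metric(summary_writer, train_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        # A buffer of length N flushed at `step` back-fills steps step-N+1 .. step,
        # so each buffered value is attributed to the step that produced it.
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)


def write_eval_metric(summary_writer, eval_metrics, step):
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


if __name__ == "__main__":
    writer = SummaryWriter(log_dir=tempfile.mkdtemp())
    # Dummy buffers standing in for the metrics the training loop accumulates.
    train_buffer = {"loss": [2.3, 2.1, 1.9], "learning_rate": [1e-3, 1e-3, 1e-3]}
    write_train_metric(writer, train_buffer, train_time=12.5, step=3)
    write_eval_metric(writer, {"loss": 1.8, "perplexity": 6.05}, step=3)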
@@ -315,10 +301,6 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):

     # Log on each process the small summary:
     logger = logging.getLogger(__name__)
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )

     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
@@ -471,8 +453,22 @@ def group_texts(examples):
     )

     # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )

     # Data collator
     # This one will take care of randomly masking the tokens.
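This hunk replaces the module-level TensorBoard setup removed at the top of the diff with a guarded, lazy import inside `main()`, so only the master host creates a writer and tensorboard stays an optional dependency. A minimal sketch of the same pattern as a standalone helper, assuming only that `transformers` is importable; `make_summary_writer` and `output_dir` are illustrative names, not part of the script.

import logging
from pathlib import Path

from transformers import is_tensorboard_available

logger = logging.getLogger(__name__)


def make_summary_writer(output_dir):
    """Return a flax SummaryWriter, or None when TensorBoard support is missing."""
    if is_tensorboard_available():
        try:
            # Import lazily so tensorboard remains an optional dependency.
            from flax.metrics.tensorboard import SummaryWriter

            return SummaryWriter(log_dir=Path(output_dir))
        except ImportError as ie:
            logger.warning(f"Unable to enable TensorBoard logging, some packages are not installed: {ie}")
    else:
        logger.warning("TensorBoard is not installed; run `pip install tensorboard` to enable logging.")
    return None

In the script itself the check is additionally gated on `jax.process_index() == 0`, matching the "only on the master node" comment, so multi-host runs do not write duplicate event files.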
@@ -601,7 +597,7 @@ def eval_step(params, batch):
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

         # Gather the indexes for creating the batch and do a training step
-        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)

@@ -610,11 +606,20 @@ def eval_step(params, batch):
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)

-        train_time += time.time() - train_start
+            cur_step = epoch * num_train_samples + step

-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
+            if cur_step % training_args.logging_steps and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                )
+
+                train_metrics = []

         # ======================== Evaluating ==============================
         num_eval_samples = len(tokenized_datasets["validation"])
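This hunk moves training-metric logging from once per epoch to every `training_args.logging_steps` steps: it unreplicates the latest metric for the progress bar, flushes the accumulated `train_metrics` buffer through `write_train_metric`, and resets the buffer. One thing to note: as written, `cur_step % training_args.logging_steps and cur_step > 0` is truthy exactly on steps that are not multiples of `logging_steps`; an explicit `== 0` comparison gives the presumably intended cadence. A small sketch of that guard, with `should_log` and `steps_per_epoch` as illustrative names only:

def should_log(cur_step, logging_steps):
    # Log only on non-zero multiples of logging_steps.
    return cur_step > 0 and cur_step % logging_steps == 0


# Toy check: two epochs of 1000 steps with logging_steps=500 log at 500, 1000, 1500.
steps_per_epoch, logging_steps = 1000, 500
logged = [
    epoch * steps_per_epoch + step
    for epoch in range(2)
    for step in range(steps_per_epoch)
    if should_log(epoch * steps_per_epoch + step, logging_steps)
]
assert logged == [500, 1000, 1500]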
@@ -645,7 +650,7 @@ def eval_step(params, batch):
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)

         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0: