@@ -23,6 +23,7 @@
 # You can also adapt this script on your own mlm task. Pointers for this are left as comments.
 
 import argparse
+import json
 import logging
 import math
 import os
@@ -457,9 +458,11 @@ def group_texts(examples):
     train_dataset = tokenized_datasets["train"]
     eval_dataset = tokenized_datasets["validation"]
 
-    # Log a few random samples from the training set:
-    for index in random.sample(range(len(train_dataset)), 3):
-        logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+    # Conditional for small test subsets
+    if len(train_dataset) > 3:
+        # Log a few random samples from the training set:
+        for index in random.sample(range(len(train_dataset)), 3):
+            logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
 
     # Data collator
     # This one will take care of randomly masking the tokens.
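The guard above matters because random.sample raises a ValueError whenever the requested sample size exceeds the population, which is exactly what a tiny CI or debug subset would trigger. A minimal sketch of the failure mode the check avoids:

    import random

    population = range(2)  # e.g. a 2-example debug dataset
    try:
        random.sample(population, 3)  # 3 > len(population)
    except ValueError as err:
        print(err)  # "Sample larger than population or is negative"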
@@ -581,7 +584,10 @@ def group_texts(examples):
 
             if isinstance(checkpointing_steps, int):
                 if completed_steps % checkpointing_steps == 0:
-                    accelerator.save_state(f"step_{completed_steps}")
+                    output_dir = f"step_{completed_steps}"
+                    if args.output_dir is not None:
+                        output_dir = os.path.join(args.output_dir, output_dir)
+                    accelerator.save_state(output_dir)
 
             if completed_steps >= args.max_train_steps:
                 break
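This step-level checkpoint (and the epoch-level one in the next hunk) is written with accelerator.save_state, so a later run can restore the same state with its Accelerate counterpart, load_state. A minimal, self-contained sketch of that round trip; the path and the toy model are illustrative, not part of this script:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 2)  # toy stand-in for the MLM model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    accelerator.save_state("output/step_500")  # what the loop above does at each interval
    accelerator.load_state("output/step_500")  # how a resumed run can pick the state back up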
@@ -625,7 +631,10 @@ def group_texts(examples):
                 )
 
         if args.checkpointing_steps == "epoch":
-            accelerator.save_state(f"epoch_{epoch}")
+            output_dir = f"epoch_{epoch}"
+            if args.output_dir is not None:
+                output_dir = os.path.join(args.output_dir, output_dir)
+            accelerator.save_state(output_dir)
 
     if args.output_dir is not None:
         accelerator.wait_for_everyone()
@@ -636,6 +645,9 @@ def group_texts(examples):
             if args.push_to_hub:
                 repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)
 
+        with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
+            json.dump({"perplexity": perplexity}, f)
+
 
 if __name__ == "__main__":
     main()
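The all_results.json dump is what the import json added at the top of this diff supports; it leaves a machine-readable record of the final metric next to the saved model. A hedged sketch of how a caller (for example a test harness) might consume it, with the directory name illustrative:

    import json
    import os

    output_dir = "output"  # illustrative; whatever was passed as --output_dir
    with open(os.path.join(output_dir, "all_results.json")) as f:
        results = json.load(f)
    print(results["perplexity"])  # e.g. sanity-check the recorded metric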