Commit bb4ac2b

[Flax] Correct flax training scripts (#12514)
* fix_torch_device_generate_test
* remove @
* add logging steps
* correct training scripts
* correct readme
* correct
1 parent ea55675 commit bb4ac2b
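The fix shared by all three scripts is the same: TensorBoard support is probed and imported lazily inside `main()`, and train metrics are buffered and flushed every `logging_steps` optimizer steps instead of once per epoch. Below is a minimal, runnable sketch of that loop with stubbed metrics in place of the real pmap'd train step; note the sketch uses the conventional `% logging_steps == 0` test, whereas the diffs below gate on the raw modulo value.

```python
import time

# Illustrative constants; in the real scripts these come from TrainingArguments
# (e.g. --logging_steps="500" as added to the README below).
num_epochs = 2
steps_per_epoch = 1000
logging_steps = 500

train_time = 0.0
train_metrics = []  # buffered across steps, flushed at every log point
train_start = time.time()

for epoch in range(num_epochs):
    for step in range(steps_per_epoch):
        # Stub metric standing in for the output of p_train_step.
        train_metrics.append({"loss": 1.0 / (step + 1)})

        cur_step = epoch * steps_per_epoch + step
        if cur_step % logging_steps == 0 and cur_step > 0:
            train_time += time.time() - train_start
            print(f"Step... ({cur_step} | buffered metrics: {len(train_metrics)})")
            train_metrics = []  # reset the buffer after flushing
```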

File tree

4 files changed: +87 -62 lines changed

examples/flax/language-modeling/README.md (+3 -1)
@@ -137,10 +137,10 @@ Next we can run the example script to pretrain the model:
     --learning_rate="3e-4" \
     --warmup_steps="1000" \
     --overwrite_output_dir \
-    --pad_to_max_length \
     --num_train_epochs="18" \
     --adam_beta1="0.9" \
     --adam_beta2="0.98" \
+    --logging_steps="500" \
     --push_to_hub
 ```
 

@@ -233,6 +233,7 @@ Next we can run the example script to pretrain the model:
     --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
     --overwrite_output_dir \
     --num_train_epochs="20" \
+    --logging_steps="500" \
     --push_to_hub
 ```
 

@@ -368,6 +369,7 @@ Next we can run the example script to pretrain the model:
     --warmup_steps="5000" \
     --overwrite_output_dir \
     --num_train_epochs="10" \
+    --logging_steps="500" \
     --push_to_hub
 ```
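For context (a hedged sketch, not part of the commit): the example scripts parse these flags with `HfArgumentParser` into `transformers.TrainingArguments`, so the flag added above surfaces as `training_args.logging_steps` in the training loops below. This assumes a transformers install with its torch dependency available, since `TrainingArguments` initializes device settings on construction; the output directory here is illustrative.

```python
from transformers import HfArgumentParser, TrainingArguments

# Parse the same flag the README now passes; output_dir is a required field.
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "./output", "--logging_steps", "500"]
)
print(training_args.logging_steps)  # 500
```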

examples/flax/language-modeling/run_clm_flax.py (+33 -26)
@@ -57,22 +57,6 @@
 
 logger = logging.getLogger(__name__)
 
-# Cache the result
-has_tensorboard = is_tensorboard_available()
-if has_tensorboard:
-    try:
-        from flax.metrics.tensorboard import SummaryWriter
-    except ImportError as ie:
-        has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
-
-else:
-    print(
-        "Unable to display metrics through TensorBoard because the package is not installed: "
-        "Please run pip install tensorboard to enable."
-    )
-
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 

@@ -214,7 +198,7 @@ def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuf
         yield batch
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
 
     train_metrics = get_metrics(train_metrics)

@@ -223,6 +207,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
     for i, val in enumerate(vals):
         summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
 

@@ -450,8 +436,22 @@ def group_texts(examples):
         eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
 
     # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )
 
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)

@@ -554,31 +554,38 @@ def eval_step(params, batch):
     logger.info(f" Total optimization steps = {total_train_steps}")
 
     train_time = 0
+    train_metrics = []
     epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
     for epoch in epochs:
         # ======================== Training ================================
         train_start = time.time()
 
         # Create sampling rng
         rng, input_rng = jax.random.split(rng)
-        train_metrics = []
 
         # Generate an epoch by shuffling sampling indices from the train dataset
         train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
         steps_per_epoch = len(train_dataset) // train_batch_size
         # train
-        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
+        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
             batch = next(train_loader)
             state, train_metric = p_train_step(state, batch)
             train_metrics.append(train_metric)
 
-        train_time += time.time() - train_start
+            cur_step = epoch * (len(train_dataset) // train_batch_size) + step
 
-        train_metric = unreplicate(train_metric)
+            if cur_step % training_args.logging_steps and cur_step > 0:
+                # Save metrics
+                train_metric = unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
 
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+                )
+
+                train_metrics = []
 
         # ======================== Evaluating ==============================
         eval_metrics = []

@@ -608,7 +615,7 @@ def eval_step(params, batch):
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
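What the renamed helpers actually emit, as a small self-contained sketch: a stub writer stands in for `flax.metrics.tensorboard.SummaryWriter`, and the buffer is passed as a plain dict of lists (the real `write_train_metric` first calls flax's `get_metrics` to stack the pmap'd per-step outputs). The point is the back-dating arithmetic: each buffered value lands on the step that produced it.

```python
class StubWriter:
    """Stand-in for flax.metrics.tensorboard.SummaryWriter."""

    def scalar(self, tag, value, step):
        print(f"{tag}@{step}: {value}")


def write_train_metric(summary_writer, train_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)
    # Back-date each buffered value: the last one lands exactly on `step`.
    for tag, vals in train_metrics.items():
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)


def write_eval_metric(summary_writer, eval_metrics, step):
    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


# Three buffered losses flushed at step 500 land on steps 498, 499, 500.
writer = StubWriter()
write_train_metric(writer, {"loss": [0.9, 0.8, 0.7]}, 12.3, 500)
write_eval_metric(writer, {"loss": 0.65}, 500)
```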

examples/flax/language-modeling/run_mlm_flax.py (+33 -28)
@@ -56,22 +56,6 @@
 )
 
 
-# Cache the result
-has_tensorboard = is_tensorboard_available()
-if has_tensorboard:
-    try:
-        from flax.metrics.tensorboard import SummaryWriter
-    except ImportError as ie:
-        has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
-
-else:
-    print(
-        "Unable to display metrics through TensorBoard because the package is not installed: "
-        "Please run pip install tensorboard to enable."
-    )
-
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 

@@ -269,7 +253,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
     return batch_idx
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
 
     train_metrics = get_metrics(train_metrics)

@@ -278,6 +262,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
     for i, val in enumerate(vals):
         summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
 

@@ -315,10 +301,6 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
     # Log on each process the small summary:
     logger = logging.getLogger(__name__)
-    logger.warning(
-        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
 
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")

@@ -471,8 +453,22 @@ def group_texts(examples):
     )
 
     # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )
 
     # Data collator
     # This one will take care of randomly masking the tokens.

@@ -601,7 +597,7 @@ def eval_step(params, batch):
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
 
         # Gather the indexes for creating the batch and do a training step
-        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)
 

@@ -610,11 +606,20 @@
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-        train_time += time.time() - train_start
+            cur_step = epoch * num_train_samples + step
 
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
-        )
+            if cur_step % training_args.logging_steps and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                )
+
+                train_metrics = []
 
         # ======================== Evaluating ==============================
         num_eval_samples = len(tokenized_datasets["validation"])

@@ -645,7 +650,7 @@ def eval_step(params, batch):
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
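The scripts call `jax_utils.unreplicate` (plain `unreplicate` in run_clm_flax.py) before formatting metrics because `p_train_step` is pmap'd and returns one copy of every metric per device. A short sketch of that round trip, assuming a working JAX runtime:

```python
import jax
import jax.numpy as jnp
from flax import jax_utils

# Simulate a pmap'd train_metric: replicate puts one copy on each local device.
replicated = jax_utils.replicate({"loss": jnp.array(0.7, dtype=jnp.float32)})
print(jax.tree_util.tree_map(lambda x: x.shape, replicated))  # leading device axis

# unreplicate keeps a single copy, safe to format in epochs.write(...).
train_metric = jax_utils.unreplicate(replicated)
print(train_metric["loss"])
```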

examples/flax/language-modeling/run_t5_mlm_flax.py (+18 -7)
@@ -382,7 +382,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
     return batch_idx
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
 
     train_metrics = get_metrics(train_metrics)

@@ -391,6 +391,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
     for i, val in enumerate(vals):
         summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
 

@@ -711,7 +713,7 @@ def eval_step(params, batch):
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
 
         # Gather the indexes for creating the batch and do a training step
-        for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples)
 

@@ -720,11 +722,20 @@
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-        train_time += time.time() - train_start
+            cur_step = epoch * num_train_samples + step
 
-        epochs.write(
-            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
-        )
+            if cur_step % training_args.logging_steps and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
+                )
+
+                train_metrics = []
 
         # ======================== Evaluating ==============================
         num_eval_samples = len(tokenized_datasets["validation"])

@@ -753,7 +764,7 @@ def eval_step(params, batch):
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
