
Commit 54f9fbe

Rework TF trainer (#6038)
* Fully rework training/prediction loops
* fix method name
* Fix variable name
* Fix property name
* Fix scope
* Fix method name
* Fix tuple index
* Fix tuple index
* Fix indentation
* Fix variable name
* fix eval before log
* Add drop remainder for test dataset
* Fix step number + fix logging datetime
* fix eval loss value
* use global step instead of step + fix logging at step 0
* Fix logging datetime
* Fix global_step usage
* Fix breaking loop + logging datetime
* Fix step in prediction loop
* Fix step breaking
* Fix train/test loops
* Force TF at least 2.2 for the trainer
* Use assert_cardinality to facilitate the dataset size computation
* Log steps per epoch
* Make tfds compliant with TPU
* Make tfds compliant with TPU
* Use TF dataset enumerate instead of the Python one
* revert previous commit
* Fix data_dir
* Apply style
* rebase on master
* Address Sylvain's comments
* Address Sylvain's and Lysandre comments
* Trigger CI
* Remove unused import
1 parent 3f94170 commit 54f9fbe
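
Several of the changes below revolve around `tf.data.experimental.assert_cardinality` (new in TF 2.2, hence the version bump): datasets built with `from_generator` report unknown cardinality, so the reworked trainer could not derive the dataset size or the steps per epoch without it. A minimal sketch, not part of this commit, illustrating the behavior (the size of 100 is arbitrary):

import tensorflow as tf

# Generator-backed datasets cannot report their size upfront.
dataset = tf.data.Dataset.from_generator(lambda: iter(range(100)), output_types=tf.int32)
print(tf.data.experimental.cardinality(dataset).numpy())  # -2, i.e. UNKNOWN_CARDINALITY

# Asserting the known example count makes the size queryable again,
# which is what the trainer needs to compute steps per epoch.
dataset = dataset.apply(tf.data.experimental.assert_cardinality(100))
print(tf.data.experimental.cardinality(dataset).numpy())  # 100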

9 files changed: +248 −215 lines

examples/README.md

+1 −1

@@ -1,7 +1,7 @@
 # Examples
 
 Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
-Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
+Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.2+.
 
 Here is the list of all our examples:
 - **grouped by task** (all official examples work for multiple models)

examples/multiple-choice/utils_multiple_choice.py

+2 −0

@@ -204,6 +204,8 @@ def gen():
         )
 
     def get_dataset(self):
+        self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
         return self.dataset
 
     def __len__(self):

examples/question-answering/run_tf_squad.py

+9 −2

@@ -21,6 +21,8 @@
 from dataclasses import dataclass, field
 from typing import Optional
 
+import tensorflow as tf
+
 from transformers import (
     AutoConfig,
     AutoTokenizer,
@@ -68,6 +70,7 @@ class DataTrainingArguments:
     data_dir: Optional[str] = field(
         default=None, metadata={"help": "The input data dir. Should contain the .json files for the SQuAD task."}
     )
+    use_tfds: Optional[bool] = field(default=True, metadata={"help": "If TFDS should be used or not."})
     max_seq_length: int = field(
         default=128,
         metadata={
@@ -170,7 +173,7 @@ def main():
     )
 
     # Get datasets
-    if not data_args.data_dir:
+    if data_args.use_tfds:
         if data_args.version_2_with_negative:
             logger.warn("tensorflow_datasets does not handle version 2 of SQuAD. Switch to version 1 automatically")
 
@@ -179,7 +182,7 @@ def main():
         except ImportError:
             raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
 
-        tfds_examples = tfds.load("squad")
+        tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
         train_examples = (
             SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=False)
             if training_args.do_train
@@ -209,6 +212,8 @@ def main():
         else None
     )
 
+    train_dataset = train_dataset.apply(tf.data.experimental.assert_cardinality(len(train_examples)))
+
     eval_dataset = (
         squad_convert_examples_to_features(
             examples=eval_examples,
@@ -223,6 +228,8 @@ def main():
         else None
     )
 
+    eval_dataset = eval_dataset.apply(tf.data.experimental.assert_cardinality(len(eval_examples)))
+
     # Initialize our Trainer
     trainer = TFTrainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset,)
 
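The new `use_tfds` flag decouples "load from tensorflow_datasets" from "a data_dir was given", so the directory can now be forwarded to `tfds.load` as the download/cache location. This is part of what makes the script TPU-friendly: the prepared dataset can live somewhere both the host and the TPU workers can read. A hedged sketch of the call, with a hypothetical bucket path that is not from this commit:

import tensorflow_datasets as tfds

# data_dir overrides the default ~/tensorflow_datasets cache location;
# "gs://my-bucket/tfds" is a hypothetical path used for illustration.
tfds_examples = tfds.load("squad", data_dir="gs://my-bucket/tfds")
print(tfds_examples.keys())  # the SQuAD splits, e.g. 'train' and 'validation'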

examples/text-classification/run_tf_glue.py

+22 −5

@@ -9,6 +9,7 @@
 from typing import Dict, Optional
 
 import numpy as np
+import tensorflow as tf
 import tensorflow_datasets as tfds
 
 from transformers import (
@@ -35,7 +36,11 @@ class Split(Enum):
 
 
 def get_tfds(
-    task_name: str, tokenizer: PreTrainedTokenizer, max_seq_length: Optional[int] = None, mode: Split = Split.train
+    task_name: str,
+    tokenizer: PreTrainedTokenizer,
+    max_seq_length: Optional[int] = None,
+    mode: Split = Split.train,
+    data_dir: str = None,
 ):
     if task_name == "mnli-mm" and mode == Split.dev:
         tfds_name = "mnli_mismatched"
@@ -50,9 +55,11 @@ def get_tfds(
     else:
         tfds_name = task_name
 
-    ds = tfds.load("glue/" + tfds_name, split=mode.value)
+    ds, info = tfds.load("glue/" + tfds_name, split=mode.value, with_info=True, data_dir=data_dir)
+    ds = glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    ds = ds.apply(tf.data.experimental.assert_cardinality(info.splits[mode.value].num_examples))
 
-    return glue_convert_examples_to_features(ds, tokenizer, max_seq_length, task_name)
+    return ds
 
 
 logger = logging.getLogger(__name__)
@@ -69,6 +76,7 @@ class GlueDataTrainingArguments:
     """
 
     task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(glue_processors.keys())})
+    data_dir: Optional[str] = field(default=None, metadata={"help": "The input/output data dir for TFDS."})
     max_seq_length: int = field(
         default=128,
         metadata={
@@ -171,13 +179,22 @@ def main():
 
     # Get datasets
     train_dataset = (
-        get_tfds(task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length)
+        get_tfds(
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            data_dir=data_args.data_dir,
+        )
         if training_args.do_train
         else None
     )
     eval_dataset = (
         get_tfds(
-            task_name=data_args.task_name, tokenizer=tokenizer, max_seq_length=data_args.max_seq_length, mode=Split.dev
+            task_name=data_args.task_name,
+            tokenizer=tokenizer,
+            max_seq_length=data_args.max_seq_length,
+            mode=Split.dev,
+            data_dir=data_args.data_dir,
         )
         if training_args.do_eval
         else None
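`with_info=True` is what makes the subsequent `assert_cardinality` call possible without iterating the data: the returned `DatasetInfo` object carries per-split example counts. A small sketch, assuming the MRPC task as an example (the count comes from the GLUE metadata, not this commit):

import tensorflow_datasets as tfds

# DatasetInfo exposes split sizes without touching the examples;
# "validation" here corresponds to Split.dev in the script.
ds, info = tfds.load("glue/mrpc", split="validation", with_info=True)
print(info.splits["validation"].num_examples)  # 408 for MRPC's dev split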

examples/token-classification/run_tf_ner.py

+0 −6

@@ -17,7 +17,6 @@
 
 import logging
 import os
-import warnings
 from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Tuple
 
@@ -185,11 +184,6 @@ def align_predictions(predictions: np.ndarray, label_ids: np.ndarray) -> Tuple[L
 
     for i in range(batch_size):
         for j in range(seq_len):
-            if label_ids[i, j] == -1:
-                label_ids[i, j] = -100
-                warnings.warn(
-                    "Using `-1` to mask the loss for the token is depreciated. Please use `-100` instead."
-                )
             if label_ids[i, j] != -100:
                 out_label_list[i].append(label_map[label_ids[i][j]])
                 preds_list[i].append(label_map[preds[i][j]])

examples/token-classification/utils_ner.py

+3 −1

@@ -146,7 +146,7 @@ class TFNerDataset:
     """
 
     features: List[InputFeatures]
-    pad_token_label_id: int = -1
+    pad_token_label_id: int = -100
     # Use cross entropy ignore_index as padding label id so that only
    # real label ids contribute to the loss later.
 
@@ -221,6 +221,8 @@ def gen():
         )
 
     def get_dataset(self):
+        self.dataset = self.dataset.apply(tf.data.experimental.assert_cardinality(len(self.features)))
+
         return self.dataset
 
     def __len__(self):

src/transformers/modeling_tf_utils.py

+1 −6

@@ -17,7 +17,6 @@
 import functools
 import logging
 import os
-import warnings
 from typing import Dict, List, Optional, Union
 
 import h5py
@@ -174,11 +173,7 @@ def compute_loss(self, labels, logits):
         )
         # make sure only labels that are not equal to -100
         # are taken into account as loss
-        if tf.math.reduce_any(labels == -1).numpy() is True:
-            warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
-            active_loss = tf.reshape(labels, (-1,)) != -1
-        else:
-            active_loss = tf.reshape(labels, (-1,)) != -100
+        active_loss = tf.reshape(labels, (-1,)) != -100
         reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
         labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
 