Skip to content

Commit e84470e

Browse files
authored
Merge pull request #1384 from huggingface/encoding-qol
Quality of life enhancements in encoding + patch MLM masking
2 parents 69629c4 + 78ef1a9 commit e84470e

17 files changed

+421
-256
lines changed

examples/README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ similar API between the different models.
99
| [Language Generation](#language-generation) | Conditional text generation using the auto-regressive models of the library: GPT, GPT-2, Transformer-XL and XLNet. |
1010
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
1111
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
12-
| [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks.
12+
| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
1313

1414
## Language model fine-tuning
1515

@@ -283,17 +283,17 @@ The results are the following:
283283
loss = 0.04755385363816904
284284
```
285285

286-
##Multiple Choice
286+
## Multiple Choice
287287

288288
Based on the script [`run_multiple_choice.py`](./run_multiple_choice.py).
289289

290290
#### Fine-tuning on SWAG
291291
Download [swag](https://github.com/rowanz/swagaf/tree/master/data) data
292292

293-
```
293+
```bash
294294
#training on 4 tesla V100(16GB) GPUS
295295
export SWAG_DIR=/path/to/swag_data_dir
296-
python ./examples/single_model_scripts/run_multiple_choice.py \
296+
python ./examples/run_multiple_choice.py \
297297
--model_type roberta \
298298
--task_name swag \
299299
--model_name_or_path roberta-base \

examples/run_glue.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
271271
list(filter(None, args.model_name_or_path.split('/'))).pop(),
272272
str(args.max_seq_length),
273273
str(task)))
274-
if os.path.exists(cached_features_file):
274+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
275275
logger.info("Loading features from cached file %s", cached_features_file)
276276
features = torch.load(cached_features_file)
277277
else:

examples/run_lm_finetuning.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class TextDataset(Dataset):
6161
def __init__(self, tokenizer, file_path='train', block_size=512):
6262
assert os.path.isfile(file_path)
6363
directory, filename = os.path.split(file_path)
64-
cached_features_file = os.path.join(directory, 'cached_lm_{}_{}'.format(block_size, filename))
64+
cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename)
6565

6666
if os.path.exists(cached_features_file):
6767
logger.info("Loading features from cached file %s", cached_features_file)
@@ -77,7 +77,7 @@ def __init__(self, tokenizer, file_path='train', block_size=512):
7777
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
7878

7979
for i in range(0, len(tokenized_text)-block_size+1, block_size): # Truncate in block of block_size
80-
self.examples.append(tokenizer.add_special_tokens_single_sequence(tokenized_text[i:i+block_size]))
80+
self.examples.append(tokenizer.build_inputs_with_special_tokens(tokenized_text[i:i+block_size]))
8181
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
8282
# If your dataset is small, first you should look for a bigger one :-) and second you
8383
# can change this behavior by adding (model specific) padding.
@@ -139,7 +139,10 @@ def mask_tokens(inputs, tokenizer, args):
139139
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
140140
labels = inputs.clone()
141141
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
142-
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).bool()
142+
probability_matrix = torch.full(labels.shape, args.mlm_probability)
143+
special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
144+
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
145+
masked_indices = torch.bernoulli(probability_matrix).bool()
143146
labels[~masked_indices] = -1 # We only compute loss on masked tokens
144147

145148
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])

examples/run_multiple_choice.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
293293
list(filter(None, args.model_name_or_path.split('/'))).pop(),
294294
str(args.max_seq_length),
295295
str(task)))
296-
if os.path.exists(cached_features_file):
296+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
297297
logger.info("Loading features from cached file %s", cached_features_file)
298298
features = torch.load(cached_features_file)
299299
else:
@@ -306,14 +306,14 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
306306
else:
307307
examples = processor.get_train_examples(args.data_dir)
308308
logger.info("Training number: %s", str(len(examples)))
309-
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
310-
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
311-
cls_token=tokenizer.cls_token,
312-
sep_token=tokenizer.sep_token,
313-
sep_token_extra=bool(args.model_type in ['roberta']),
314-
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
309+
features = convert_examples_to_features(
310+
examples,
311+
label_list,
312+
args.max_seq_length,
313+
tokenizer,
315314
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
316-
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
315+
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0
316+
)
317317
if args.local_rank in [-1, 0]:
318318
logger.info("Saving features into cached file %s", cached_features_file)
319319
torch.save(features, cached_features_file)
@@ -362,7 +362,7 @@ def main():
362362
help="Whether to run eval on the dev set.")
363363
parser.add_argument("--do_test", action='store_true', help='Whether to run test on the test set')
364364
parser.add_argument("--evaluate_during_training", action='store_true',
365-
help="Rul evaluation during training at each logging step.")
365+
help="Run evaluation during training at each logging step.")
366366
parser.add_argument("--do_lower_case", action='store_true',
367367
help="Set this flag if you are using an uncased model.")
368368

examples/utils_multiple_choice.py

+54-113
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16-
""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
16+
""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
1717

1818
from __future__ import absolute_import, division, print_function
1919

@@ -26,6 +26,8 @@
2626
import csv
2727
import glob
2828
import tqdm
29+
from typing import List
30+
from transformers import PreTrainedTokenizer
2931

3032

3133
logger = logging.getLogger(__name__)
@@ -34,13 +36,13 @@
3436
class InputExample(object):
3537
"""A single training/test example for multiple choice"""
3638

37-
def __init__(self, example_id, question, contexts, endings, label=None):
39+
def __init__(self, example_id, question, contexts, endings, label=None):
3840
"""Constructs a InputExample.
3941
4042
Args:
4143
example_id: Unique id for the example.
4244
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
43-
question: string. The untokenized text of the second sequence (qustion).
45+
question: string. The untokenized text of the second sequence (question).
4446
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
4547
label: (Optional) string. The label of the example. This should be
4648
specified for train and dev examples, but not for test examples.
@@ -66,7 +68,7 @@ def __init__(self,
6668
'input_mask': input_mask,
6769
'segment_ids': segment_ids
6870
}
69-
for _, input_ids, input_mask, segment_ids in choices_features
71+
for input_ids, input_mask, segment_ids in choices_features
7072
]
7173
self.label = label
7274

@@ -192,7 +194,7 @@ def _read_csv(self, input_file):
192194
return lines
193195

194196

195-
def _create_examples(self, lines, type):
197+
def _create_examples(self, lines: List[List[str]], type: str):
196198
"""Creates examples for the training and dev sets."""
197199
if type == "train" and lines[0][-1] != 'label':
198200
raise ValueError(
@@ -300,24 +302,18 @@ def normalize(truth):
300302
return examples
301303

302304

303-
def convert_examples_to_features(examples, label_list, max_seq_length,
304-
tokenizer,
305-
cls_token_at_end=False,
306-
cls_token='[CLS]',
307-
cls_token_segment_id=1,
308-
sep_token='[SEP]',
309-
sequence_a_segment_id=0,
310-
sequence_b_segment_id=1,
311-
sep_token_extra=False,
312-
pad_token_segment_id=0,
313-
pad_on_left=False,
314-
pad_token=0,
315-
mask_padding_with_zero=True):
316-
""" Loads a data file into a list of `InputBatch`s
317-
`cls_token_at_end` define the location of the CLS token:
318-
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
319-
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
320-
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
305+
def convert_examples_to_features(
306+
examples: List[InputExample],
307+
label_list: List[str],
308+
max_length: int,
309+
tokenizer: PreTrainedTokenizer,
310+
pad_token_segment_id=0,
311+
pad_on_left=False,
312+
pad_token=0,
313+
mask_padding_with_zero=True,
314+
) -> List[InputFeatures]:
315+
"""
316+
Loads a data file into a list of `InputFeatures`
321317
"""
322318

323319
label_map = {label : i for i, label in enumerate(label_list)}
@@ -328,125 +324,70 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
328324
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
329325
choices_features = []
330326
for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
331-
tokens_a = tokenizer.tokenize(context)
332-
tokens_b = None
327+
text_a = context
333328
if example.question.find("_") != -1:
334-
#this is for cloze question
335-
tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
336-
else:
337-
tokens_b = tokenizer.tokenize(example.question + " " + ending)
338-
# you can add seq token between quesiotn and ending. This does not make too much difference.
339-
# tokens_b = tokenizer.tokenize(example.question)
340-
# tokens_b += [sep_token]
341-
# if sep_token_extra:
342-
# tokens_b += [sep_token]
343-
# tokens_b += tokenizer.tokenize(ending)
344-
345-
special_tokens_count = 4 if sep_token_extra else 3
346-
_truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
347-
348-
# The convention in BERT is:
349-
# (a) For sequence pairs:
350-
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
351-
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
352-
# (b) For single sequences:
353-
# tokens: [CLS] the dog is hairy . [SEP]
354-
# type_ids: 0 0 0 0 0 0 0
355-
#
356-
# Where "type_ids" are used to indicate whether this is the first
357-
# sequence or the second sequence. The embedding vectors for `type=0` and
358-
# `type=1` were learned during pre-training and are added to the wordpiece
359-
# embedding vector (and position vector). This is not *strictly* necessary
360-
# since the [SEP] token unambiguously separates the sequences, but it makes
361-
# it easier for the model to learn the concept of sequences.
362-
#
363-
# For classification tasks, the first vector (corresponding to [CLS]) is
364-
# used as as the "sentence vector". Note that this only makes sense because
365-
# the entire model is fine-tuned.
366-
tokens = tokens_a + [sep_token]
367-
if sep_token_extra:
368-
# roberta uses an extra separator b/w pairs of sentences
369-
tokens += [sep_token]
370-
371-
segment_ids = [sequence_a_segment_id] * len(tokens)
372-
373-
if tokens_b:
374-
tokens += tokens_b + [sep_token]
375-
segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
376-
377-
if cls_token_at_end:
378-
tokens = tokens + [cls_token]
379-
segment_ids = segment_ids + [cls_token_segment_id]
329+
# this is for cloze question
330+
text_b = example.question.replace("_", ending)
380331
else:
381-
tokens = [cls_token] + tokens
382-
segment_ids = [cls_token_segment_id] + segment_ids
332+
text_b = example.question + " " + ending
383333

384-
input_ids = tokenizer.convert_tokens_to_ids(tokens)
334+
inputs = tokenizer.encode_plus(
335+
text_a,
336+
text_b,
337+
add_special_tokens=True,
338+
max_length=max_length,
339+
)
340+
if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0:
341+
logger.info('Attention! you are cropping tokens (swag task is ok). '
342+
'If you are training ARC and RACE and you are poping question + options, '
343+
'you need to try to use a bigger max seq length!')
344+
345+
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
385346

386347
# The mask has 1 for real tokens and 0 for padding tokens. Only real
387348
# tokens are attended to.
388-
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
349+
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
389350

390351
# Zero-pad up to the sequence length.
391-
padding_length = max_seq_length - len(input_ids)
352+
padding_length = max_length - len(input_ids)
392353
if pad_on_left:
393354
input_ids = ([pad_token] * padding_length) + input_ids
394-
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
395-
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
355+
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
356+
token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
396357
else:
397358
input_ids = input_ids + ([pad_token] * padding_length)
398-
input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
399-
segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
359+
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
360+
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
361+
362+
assert len(input_ids) == max_length
363+
assert len(attention_mask) == max_length
364+
assert len(token_type_ids) == max_length
365+
choices_features.append((input_ids, attention_mask, token_type_ids))
366+
400367

401-
assert len(input_ids) == max_seq_length
402-
assert len(input_mask) == max_seq_length
403-
assert len(segment_ids) == max_seq_length
404-
choices_features.append((tokens, input_ids, input_mask, segment_ids))
405368
label = label_map[example.label]
406369

407370
if ex_index < 2:
408371
logger.info("*** Example ***")
409372
logger.info("race_id: {}".format(example.example_id))
410-
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
373+
for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
411374
logger.info("choice: {}".format(choice_idx))
412-
logger.info("tokens: {}".format(' '.join(tokens)))
413375
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
414-
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
415-
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
376+
logger.info("attention_mask: {}".format(' '.join(map(str, attention_mask))))
377+
logger.info("token_type_ids: {}".format(' '.join(map(str, token_type_ids))))
416378
logger.info("label: {}".format(label))
417379

418380
features.append(
419381
InputFeatures(
420-
example_id = example.example_id,
421-
choices_features = choices_features,
422-
label = label
382+
example_id=example.example_id,
383+
choices_features=choices_features,
384+
label=label,
423385
)
424386
)
425387

426388
return features
427389

428390

429-
def _truncate_seq_pair(tokens_a, tokens_b, max_length):
430-
"""Truncates a sequence pair in place to the maximum length."""
431-
432-
# This is a simple heuristic which will always truncate the longer sequence
433-
# one token at a time. This makes more sense than truncating an equal percent
434-
# of tokens from each, since if one sequence is very short then each token
435-
# that's truncated likely contains more information than a longer sequence.
436-
437-
# However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
438-
# length or only pop from context
439-
while True:
440-
total_length = len(tokens_a) + len(tokens_b)
441-
if total_length <= max_length:
442-
break
443-
if len(tokens_a) > len(tokens_b):
444-
tokens_a.pop()
445-
else:
446-
logger.info('Attention! you are removing from token_b (swag task is ok). '
447-
'If you are training ARC and RACE (you are poping question + options), '
448-
'you need to try to use a bigger max seq length!')
449-
tokens_b.pop()
450391

451392

452393
processors = {
@@ -456,7 +397,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
456397
}
457398

458399

459-
GLUE_TASKS_NUM_LABELS = {
400+
MULTIPLE_CHOICE_TASKS_NUM_LABELS = {
460401
"race": 4,
461402
"swag": 4,
462403
"arc": 4

transformers/data/processors/glue.py

-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ def glue_convert_examples_to_features(examples, tokenizer,
8686
example.text_b,
8787
add_special_tokens=True,
8888
max_length=max_length,
89-
truncate_first_sequence=True # We're truncating the first sequence in priority
9089
)
9190
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
9291

transformers/tests/tokenization_bert_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ def test_sequence_builders(self):
131131
text = tokenizer.encode("sequence builders")
132132
text_2 = tokenizer.encode("multi-sequence build")
133133

134-
encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
135-
encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
134+
encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
135+
encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)
136136

137137
assert encoded_sentence == [101] + text + [102]
138138
assert encoded_pair == [101] + text + [102] + text_2 + [102]

0 commit comments

Comments
 (0)