# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
+""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """

from __future__ import absolute_import, division, print_function

import csv
import glob
import tqdm
+from typing import List
+from transformers import PreTrainedTokenizer


logger = logging.getLogger(__name__)

class InputExample(object):
    """A single training/test example for multiple choice"""

-    def __init__(self, example_id, question, contexts, endings, label=None):
+    def __init__(self, example_id, question, contexts, endings, label=None):
        """Constructs a InputExample.

        Args:
            example_id: Unique id for the example.
            contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
-            question: string. The untokenized text of the second sequence (qustion).
+            question: string. The untokenized text of the second sequence (question).
            endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
@@ -66,7 +68,7 @@ def __init__(self,
                'input_mask': input_mask,
                'segment_ids': segment_ids
            }
-            for _, input_ids, input_mask, segment_ids in choices_features
+            for input_ids, input_mask, segment_ids in choices_features
        ]
        self.label = label
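
For reference, the choices_features dicts built above ('input_ids', 'input_mask', 'segment_ids') are usually stacked into tensors of shape (num_examples, num_choices, max_length) before they reach a multiple-choice head. A minimal sketch of that step, assuming a hypothetical select_field helper and that features is the list returned by convert_examples_to_features further down:

import torch

def select_field(features, field):
    # Pull one field out of every choice of every example, keeping the
    # (num_examples, num_choices, max_length) nesting intact.
    return [[choice[field] for choice in f.choices_features] for f in features]

all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
# each tensor: (num_examples, num_choices, max_length)
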
@@ -192,7 +194,7 @@ def _read_csv(self, input_file):
            return lines


-    def _create_examples(self, lines, type):
+    def _create_examples(self, lines: List[List[str]], type: str):
        """Creates examples for the training and dev sets."""
        if type == "train" and lines[0][-1] != 'label':
            raise ValueError(
@@ -300,24 +302,18 @@ def normalize(truth):
        return examples


-def convert_examples_to_features(examples, label_list, max_seq_length,
-                                 tokenizer,
-                                 cls_token_at_end=False,
-                                 cls_token='[CLS]',
-                                 cls_token_segment_id=1,
-                                 sep_token='[SEP]',
-                                 sequence_a_segment_id=0,
-                                 sequence_b_segment_id=1,
-                                 sep_token_extra=False,
-                                 pad_token_segment_id=0,
-                                 pad_on_left=False,
-                                 pad_token=0,
-                                 mask_padding_with_zero=True):
-    """ Loads a data file into a list of `InputBatch`s
-        `cls_token_at_end` define the location of the CLS token:
-            - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
-            - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
-        `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
+def convert_examples_to_features(
+    examples: List[InputExample],
+    label_list: List[str],
+    max_length: int,
+    tokenizer: PreTrainedTokenizer,
+    pad_token_segment_id=0,
+    pad_on_left=False,
+    pad_token=0,
+    mask_padding_with_zero=True,
+) -> List[InputFeatures]:
+    """
+    Loads a data file into a list of `InputFeatures`
    """

    label_map = {label: i for i, label in enumerate(label_list)}
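
The slimmed-down signature works because tokenizer.encode_plus now inserts the model-specific special tokens and segment ids itself, so the old cls_token/sep_token/segment-id arguments have no remaining role. A minimal sketch of a call site under the new signature; the checkpoint name, data directory and max_length are illustrative assumptions, not part of this change:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = SwagProcessor()  # any entry from the `processors` mapping at the bottom of this file
examples = processor.get_dev_examples("./data/swag")  # path is a placeholder

features = convert_examples_to_features(
    examples,
    processor.get_labels(),
    max_length=128,                    # truncation now happens inside encode_plus
    tokenizer=tokenizer,
    pad_token=tokenizer.pad_token_id,  # pad with the tokenizer's own pad id
)
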
@@ -328,125 +324,70 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))
        choices_features = []
        for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
-            tokens_a = tokenizer.tokenize(context)
-            tokens_b = None
+            text_a = context
            if example.question.find("_") != -1:
-                #this is for cloze question
-                tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
-            else:
-                tokens_b = tokenizer.tokenize(example.question + " " + ending)
-                # you can add seq token between quesiotn and ending. This does not make too much difference.
-                # tokens_b = tokenizer.tokenize(example.question)
-                # tokens_b += [sep_token]
-                # if sep_token_extra:
-                # tokens_b += [sep_token]
-                # tokens_b += tokenizer.tokenize(ending)
-
-            special_tokens_count = 4 if sep_token_extra else 3
-            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
-
-            # The convention in BERT is:
-            # (a) For sequence pairs:
-            #  tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
-            #  type_ids:   0   0  0    0    0     0       0   0   1  1  1  1   1   1
-            # (b) For single sequences:
-            #  tokens:   [CLS] the dog is hairy . [SEP]
-            #  type_ids:   0   0   0   0  0     0   0
-            #
-            # Where "type_ids" are used to indicate whether this is the first
-            # sequence or the second sequence. The embedding vectors for `type=0` and
-            # `type=1` were learned during pre-training and are added to the wordpiece
-            # embedding vector (and position vector). This is not *strictly* necessary
-            # since the [SEP] token unambiguously separates the sequences, but it makes
-            # it easier for the model to learn the concept of sequences.
-            #
-            # For classification tasks, the first vector (corresponding to [CLS]) is
-            # used as as the "sentence vector". Note that this only makes sense because
-            # the entire model is fine-tuned.
-            tokens = tokens_a + [sep_token]
-            if sep_token_extra:
-                # roberta uses an extra separator b/w pairs of sentences
-                tokens += [sep_token]
-
-            segment_ids = [sequence_a_segment_id] * len(tokens)
-
-            if tokens_b:
-                tokens += tokens_b + [sep_token]
-                segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
-
-            if cls_token_at_end:
-                tokens = tokens + [cls_token]
-                segment_ids = segment_ids + [cls_token_segment_id]
+                # this is for cloze question
+                text_b = example.question.replace("_", ending)
            else:
-                tokens = [cls_token] + tokens
-                segment_ids = [cls_token_segment_id] + segment_ids
+                text_b = example.question + " " + ending

-            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+            inputs = tokenizer.encode_plus(
+                text_a,
+                text_b,
+                add_special_tokens=True,
+                max_length=max_length,
+            )
+            if 'num_truncated_tokens' in inputs and inputs['num_truncated_tokens'] > 0:
+                logger.info('Attention! you are cropping tokens (swag task is ok). '
+                            'If you are training ARC and RACE and you are poping question + options,'
+                            'you need to try to use a bigger max seq length!')
+
+            input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
-            input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
+            attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

            # Zero-pad up to the sequence length.
-            padding_length = max_seq_length - len(input_ids)
+            padding_length = max_length - len(input_ids)
            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
-                input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
-                segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
+                attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
+                token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
-                input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
-                segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
+                attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
+                token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
+
+            assert len(input_ids) == max_length
+            assert len(attention_mask) == max_length
+            assert len(token_type_ids) == max_length
+            choices_features.append((input_ids, attention_mask, token_type_ids))
+

-            assert len(input_ids) == max_seq_length
-            assert len(input_mask) == max_seq_length
-            assert len(segment_ids) == max_seq_length
-            choices_features.append((tokens, input_ids, input_mask, segment_ids))

        label = label_map[example.label]

        if ex_index < 2:
            logger.info("*** Example ***")
            logger.info("race_id: {}".format(example.example_id))
-            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
+            for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
                logger.info("choice: {}".format(choice_idx))
-                logger.info("tokens: {}".format(' '.join(tokens)))
                logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
-                logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
-                logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
+                logger.info("attention_mask: {}".format(' '.join(map(str, attention_mask))))
+                logger.info("token_type_ids: {}".format(' '.join(map(str, token_type_ids))))
                logger.info("label: {}".format(label))

        features.append(
            InputFeatures(
-                example_id=example.example_id,
-                choices_features=choices_features,
-                label=label
+                example_id=example.example_id,
+                choices_features=choices_features,
+                label=label,
            )
        )

    return features


-def _truncate_seq_pair(tokens_a, tokens_b, max_length):
-    """Truncates a sequence pair in place to the maximum length."""
-
-    # This is a simple heuristic which will always truncate the longer sequence
-    # one token at a time. This makes more sense than truncating an equal percent
-    # of tokens from each, since if one sequence is very short then each token
-    # that's truncated likely contains more information than a longer sequence.
-
-    # However, since we'd better not to remove tokens of options and questions, you can choose to use a bigger
-    # length or only pop from context
-    while True:
-        total_length = len(tokens_a) + len(tokens_b)
-        if total_length <= max_length:
-            break
-        if len(tokens_a) > len(tokens_b):
-            tokens_a.pop()
-        else:
-            logger.info('Attention! you are removing from token_b (swag task is ok). '
-                        'If you are training ARC and RACE (you are poping question + options), '
-                        'you need to try to use a bigger max seq length!')
-            tokens_b.pop()


processors = {
@@ -456,7 +397,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
}


-GLUE_TASKS_NUM_LABELS = {
+MULTIPLE_CHOICE_TASKS_NUM_LABELS = {
    "race", 4,
    "swag", 4,
    "arc", 4
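
Taken together, the change collapses the old tokenize → _truncate_seq_pair → convert_tokens_to_ids pipeline into a single encode_plus call per (context, question + ending) pair, which is why the truncation helper can be deleted outright. A rough sketch of what that one call produces for a single choice; the checkpoint and texts are made up, and the exact truncation reporting (num_truncated_tokens) depends on the tokenizer version:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

context = "The orchestra tuned their instruments before the concert began."
question_plus_ending = "What did the orchestra do first? tuned their instruments"

inputs = tokenizer.encode_plus(
    context,
    question_plus_ending,
    add_special_tokens=True,  # [CLS]/[SEP] are added here, so no manual token plumbing
    max_length=64,            # the pair is truncated to fit, replacing _truncate_seq_pair
)

# input_ids and token_type_ids feed straight into the padding logic above;
# the attention mask is still built and padded by hand in this version of the code.
print(len(inputs["input_ids"]), inputs["token_type_ids"][:8])
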