
Commit 15d8b12: update tokenizer - update squad example for xlnet
1 parent 3b469cb

20 files changed: +176 -116 lines

examples/run_glue.py (+9 -11)

@@ -242,7 +242,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     # Load data features from cache or dataset file
     cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
         'dev' if evaluate else 'train',
-        list(filter(None, args.model_name.split('/'))).pop(),
+        list(filter(None, args.model_name_or_path.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
     if os.path.exists(cached_features_file):

@@ -282,8 +282,10 @@ def main():
     ## Required parameters
     parser.add_argument("--data_dir", default=None, type=str, required=True,
                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--model_name", default=None, type=str, required=True,
-                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--task_name", default=None, type=str, required=True,
                         help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
     parser.add_argument("--output_dir", default=None, type=str, required=True,

@@ -400,15 +402,11 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name.lower():
-            args.model_type = key  # take the first match in model types
-            break
+    args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name, num_labels=num_labels, finetuning_task=args.task_name)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path, num_labels=num_labels, finetuning_task=args.task_name)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)

     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
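
Note: with this change --model_type must be given explicitly instead of being inferred
from the model name. A minimal sketch of driving run_glue with the new flags, using the
same sys.argv-patching pattern as examples/test_examples.py (the data, task, and output
paths below are placeholders, not from this commit):

    # Sketch only: paths and task name are illustrative.
    import sys
    from unittest.mock import patch

    import run_glue  # examples/run_glue.py

    testargs = ["run_glue.py",
                "--data_dir=./glue_data/MRPC",             # placeholder path
                "--task_name=mrpc",
                "--output_dir=./tmp/mrpc",                 # placeholder path
                "--do_train", "--do_eval",
                "--model_type=bert",                       # new required flag
                "--model_name_or_path=bert-base-uncased"]  # renamed from --model_name
    with patch.object(sys, 'argv', testargs):
        run_glue.main()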

examples/run_squad.py (+11 -13)

@@ -213,7 +213,6 @@ def evaluate(args, model, tokenizer, prefix=""):
                 inputs.update({'cls_index': batch[4],
                                'p_mask': batch[5]})
             outputs = model(**inputs)
-            batch_start_logits, batch_end_logits = outputs[:2]

         for i, example_index in enumerate(example_indices):
             eval_feature = features[example_index.item()]

@@ -242,7 +241,8 @@ def evaluate(args, model, tokenizer, prefix=""):
         write_predictions_extended(examples, features, all_results, args.n_best_size,
                                    args.max_answer_length, output_prediction_file,
                                    output_nbest_file, output_null_log_odds_file, args.predict_file,
-                                   args.start_n_top, args.end_n_top, args.version_2_with_negative)
+                                   model.config.start_n_top, model.config.end_n_top,
+                                   args.version_2_with_negative, tokenizer, args.verbose_logging)
     else:
         write_predictions(examples, features, all_results, args.n_best_size,
                           args.max_answer_length, args.do_lower_case, output_prediction_file,

@@ -262,7 +262,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
     input_file = args.predict_file if evaluate else args.train_file
     cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
         'dev' if evaluate else 'train',
-        list(filter(None, args.model_name.split('/'))).pop(),
+        list(filter(None, args.model_name_or_path.split('/'))).pop(),
         str(args.max_seq_length)))
     if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
         logger.info("Loading features from cached file %s", cached_features_file)

@@ -312,8 +312,10 @@ def main():
                         help="SQuAD json for training. E.g., train-v1.1.json")
     parser.add_argument("--predict_file", default=None, type=str, required=True,
                         help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
-    parser.add_argument("--model_name", default=None, type=str, required=True,
-                        help="Bert/XLNet/XLM pre-trained model selected in the list: " + ", ".join(ALL_MODELS))
+    parser.add_argument("--model_type", default=None, type=str, required=True,
+                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
+    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
+                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model checkpoints and predictions will be written.")

@@ -438,15 +440,11 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

-    args.model_type = ""
-    for key in MODEL_CLASSES:
-        if key in args.model_name.lower():
-            args.model_type = key  # take the first match in model types
-            break
+    args.model_type = args.model_type.lower()
     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_name, from_tf=bool('.ckpt' in args.model_name), config=config)
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)

     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
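
Note: start_n_top and end_n_top are now read from model.config rather than passed as
command-line arguments, so an XLNet run only needs the model type and checkpoint. A
sketch of the corresponding invocation for XLNet, using the same sys.argv pattern as
the tests (file paths are placeholders; xlnet-base-cased is assumed as the shortcut
name):

    # Sketch only: paths are illustrative, not from this commit.
    import sys
    from unittest.mock import patch

    import run_squad  # examples/run_squad.py

    testargs = ["run_squad.py",
                "--train_file=train-v1.1.json",   # placeholder path
                "--predict_file=dev-v1.1.json",   # placeholder path
                "--output_dir=./tmp/squad",       # placeholder path
                "--do_train", "--do_eval",
                "--model_type=xlnet",
                "--model_name_or_path=xlnet-base-cased"]
    with patch.object(sys, 'argv', testargs):
        run_squad.main()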

examples/test_examples.py (+6 -4)

@@ -60,8 +60,9 @@ def test_run_glue(self):
                     "--warmup_steps=2",
                     "--overwrite_output_dir",
                     "--seed=42"]
-        model_name = "--model_name=bert-base-uncased"
-        with patch.object(sys, 'argv', testargs + [model_name]):
+        model_type, model_name = ("--model_type=bert",
+                                  "--model_name_or_path=bert-base-uncased")
+        with patch.object(sys, 'argv', testargs + [model_type, model_name]):
             result = run_glue.main()
             for value in result.values():
                 self.assertGreaterEqual(value, 0.75)

@@ -85,8 +86,9 @@ def test_run_squad(self):
                     "--per_gpu_eval_batch_size=1",
                     "--overwrite_output_dir",
                     "--seed=42"]
-        model_name = "--model_name=bert-base-uncased"
-        with patch.object(sys, 'argv', testargs + [model_name]):
+        model_type, model_name = ("--model_type=bert",
+                                  "--model_name_or_path=bert-base-uncased")
+        with patch.object(sys, 'argv', testargs + [model_type, model_name]):
             result = run_squad.main()
             self.assertGreaterEqual(result['f1'], 30)
             self.assertGreaterEqual(result['exact'], 30)

examples/utils_squad.py (+34 -9)

@@ -87,6 +87,7 @@ def __init__(self,
                  segment_ids,
                  cls_index,
                  p_mask,
+                 paragraph_len,
                  start_position=None,
                  end_position=None,
                  is_impossible=None):

@@ -101,6 +102,7 @@ def __init__(self,
         self.segment_ids = segment_ids
         self.cls_index = cls_index
         self.p_mask = p_mask
+        self.paragraph_len = paragraph_len
         self.start_position = start_position
         self.end_position = end_position
         self.is_impossible = is_impossible

@@ -292,6 +294,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
             tokens.append(all_doc_tokens[split_token_index])
             segment_ids.append(sequence_b_segment_id)
             p_mask.append(0)
+        paragraph_len = doc_span.length

         # SEP token
         tokens.append(sep_token)

@@ -385,6 +388,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
                     segment_ids=segment_ids,
                     cls_index=cls_index,
                     p_mask=p_mask,
+                    paragraph_len=paragraph_len,
                     start_position=start_position,
                     end_position=end_position,
                     is_impossible=span_is_impossible))

@@ -673,8 +677,9 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
 def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                                max_answer_length, output_prediction_file,
                                output_nbest_file,
-                               output_null_log_odds_file, orig_data,
-                               start_n_top, end_n_top, version_2_with_negative):
+                               output_null_log_odds_file, orig_data_file,
+                               start_n_top, end_n_top, version_2_with_negative,
+                               tokenizer, verbose_logging):
     """ XLNet write prediction logic (more complex than Bert's).
     Write final predictions to the json file and log-odds of null if needed.

@@ -764,13 +769,30 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
                 break
             feature = features[pred.feature_index]

-            tok_start_to_orig_index = feature.tok_start_to_orig_index
-            tok_end_to_orig_index = feature.tok_end_to_orig_index
-            start_orig_pos = tok_start_to_orig_index[pred.start_index]
-            end_orig_pos = tok_end_to_orig_index[pred.end_index]
-
-            paragraph_text = example.paragraph_text
-            final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
+            # XLNet un-tokenizer
+            # Let's keep it simple for now and see if we need all this later.
+            #
+            # tok_start_to_orig_index = feature.tok_start_to_orig_index
+            # tok_end_to_orig_index = feature.tok_end_to_orig_index
+            # start_orig_pos = tok_start_to_orig_index[pred.start_index]
+            # end_orig_pos = tok_end_to_orig_index[pred.end_index]
+            # paragraph_text = example.paragraph_text
+            # final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
+
+            # Previously used Bert untokenizer
+            tok_tokens = feature.tokens[pred.start_index:(pred.end_index + 1)]
+            orig_doc_start = feature.token_to_orig_map[pred.start_index]
+            orig_doc_end = feature.token_to_orig_map[pred.end_index]
+            orig_tokens = example.doc_tokens[orig_doc_start:(orig_doc_end + 1)]
+            tok_text = tokenizer.convert_tokens_to_string(tok_tokens)
+
+            # Clean whitespace
+            tok_text = tok_text.strip()
+            tok_text = " ".join(tok_text.split())
+            orig_text = " ".join(orig_tokens)
+
+            final_text = get_final_text(tok_text, orig_text, tokenizer.do_lower_case,
+                                        verbose_logging)

             if final_text in seen_predictions:
                 continue

@@ -829,6 +851,9 @@ def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
     with open(output_null_log_odds_file, "w") as writer:
         writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

+    with open(orig_data_file, "r", encoding='utf-8') as reader:
+        orig_data = json.load(reader)["data"]
+
     qid_to_has_ans = make_qid_to_has_ans(orig_data)
     has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
     no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
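
Note: write_predictions_extended now receives the path to the original SQuAD file
(orig_data_file) and loads it itself, instead of expecting pre-loaded data. A minimal
sketch of that loading convention, assuming the standard SQuAD v1.1/v2.0 JSON layout
(the file name is a placeholder):

    import json

    # SQuAD files look like {"version": ..., "data": [{"title": ..., "paragraphs": [...]}]}.
    with open("dev-v1.1.json", "r", encoding="utf-8") as reader:  # placeholder path
        orig_data = json.load(reader)["data"]

    # Each entry holds paragraphs of qas (question/answer) records, which
    # make_qid_to_has_ans walks to split question ids into answerable
    # and unanswerable sets.
    print(orig_data[0]["title"])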

pytorch_transformers/modeling_utils.py (+19 -10)

@@ -528,9 +528,9 @@ def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
             Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
             1.0 means token should be masked.
         """
-        slen, hsz = hidden_states.shape[-2:]
         assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
         if start_positions is not None:
+            slen, hsz = hidden_states.shape[-2:]
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)
             start_states = hidden_states.gather(-2, start_positions)  # shape (bsz, 1, hsz)
             start_states = start_states.expand(-1, slen, -1)  # shape (bsz, slen, hsz)

@@ -571,7 +571,7 @@ def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
             no dependency on end_feature so that we can obtain one single `cls_logits`
             for each sample
         """
-        slen, hsz = hidden_states.shape[-2:]
+        hsz = hidden_states.shape[-1]
         assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
         if start_positions is not None:
             start_positions = start_positions[:, None, None].expand(-1, -1, hsz)  # shape (bsz, 1, hsz)

@@ -614,12 +614,21 @@ class SQuADHead(nn.Module):
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
             Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
-        **last_hidden_state**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
-            Sequence of hidden-states at the last layer of the model.
-        **mems**:
-            list of ``torch.FloatTensor`` (one for each layer):
-            that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
+        **start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
+            Log probabilities for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
+            Indices for the top ``config.start_n_top`` start token possibilities (beam-search).
+        **end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
+            Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
+        **cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
+            ``torch.FloatTensor`` of shape ``(batch_size,)``
+            Log probabilities for the ``is_impossible`` label of the answers.
     """
     def __init__(self, config):
         super(SQuADHead, self).__init__()

@@ -667,8 +676,8 @@ def forward(self, hidden_states, start_positions=None, end_positions=None,
             start_log_probs = F.softmax(start_logits, dim=-1)  # shape (bsz, slen)

             start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1)  # shape (bsz, start_n_top)
-            start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
-            start_states = torch.gather(hidden_states, -2, start_top_index)  # shape (bsz, start_n_top, hsz)
+            start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
+            start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
             start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

             hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states)  # shape (bsz, slen, start_n_top, hsz)
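
Note: the rename to start_top_index_exp matters because start_top_index is one of the
tensors SQuADHead returns to the caller; expanding it in place overwrote the
(bsz, start_n_top) index tensor with its (bsz, start_n_top, hsz) copy. A self-contained
sketch of the top-k gather pattern, with invented sizes:

    # Sketch with made-up dimensions, not code from this commit.
    import torch

    bsz, slen, hsz, start_n_top = 2, 8, 4, 5
    hidden_states = torch.randn(bsz, slen, hsz)
    start_log_probs = torch.randn(bsz, slen).softmax(dim=-1)

    start_top_log_probs, start_top_index = torch.topk(start_log_probs, start_n_top, dim=-1)
    # Expand into a separate tensor so start_top_index keeps its original shape.
    start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # (bsz, start_n_top, hsz)
    start_states = torch.gather(hidden_states, -2, start_top_index_exp)      # (bsz, start_n_top, hsz)
    assert start_top_index.shape == (bsz, start_n_top)  # still usable as beam indices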
