Commit 3b469cb

updating squad for compatibility with XLNet
1 parent 8ca767f commit 3b469cb

File tree

5 files changed (+403, -63 lines)


examples/run_squad.py

+49 -18
@@ -41,7 +41,9 @@
 
 from pytorch_transformers import AdamW, WarmupLinearSchedule
 
-from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
+from utils_squad import (read_squad_examples, convert_examples_to_features,
+                         RawResult, write_predictions,
+                         RawResultExtended, write_predictions_extended)
 
 # The follwing import is the official SQuAD evaluation script (2.0).
 # You can remove it from the dependencies if you are using this script outside of the library
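The newly imported RawResultExtended and write_predictions_extended mirror the existing RawResult / write_predictions pair. Judging from how the two result containers are constructed later in this diff, they carry the fields sketched below; whether they are actually implemented as namedtuples in utils_squad is an assumption of this sketch.

```python
import collections

# Field lists inferred from the keyword arguments used in evaluate() below;
# the namedtuple implementation itself is an assumption, not taken from utils_squad.
RawResult = collections.namedtuple("RawResult",
    ["unique_id", "start_logits", "end_logits"])

RawResultExtended = collections.namedtuple("RawResultExtended",
    ["unique_id", "start_top_log_probs", "start_top_index",
     "end_top_log_probs", "end_top_index", "cls_logits"])
```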
@@ -66,6 +68,8 @@ def set_seed(args):
     if args.n_gpu > 0:
         torch.cuda.manual_seed_all(args.seed)
 
+def to_list(tensor):
+    return tensor.detach().cpu().tolist()
 
 def train(args, train_dataset, model, tokenizer):
     """ Train the model """
@@ -118,10 +122,13 @@ def train(args, train_dataset, model, tokenizer):
             model.train()
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids':       batch[0],
-                      'token_type_ids':  batch[1] if args.model_type in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
+                      'token_type_ids':  None if args.model_type == 'xlm' else batch[1],  # XLM don't use segment_ids
                       'attention_mask':  batch[2],
                       'start_positions': batch[3],
                       'end_positions':   batch[4]}
+            if args.model_type in ['xlnet', 'xlm']:
+                inputs.update({'cls_index': batch[5],
+                               'p_mask':    batch[6]})
             ouputs = model(**inputs)
             loss = ouputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
 
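For XLNet and XLM the training batch carries two extra tensors, cls_index and p_mask, which are appended to the TensorDataset further down in this diff. A minimal sketch of the resulting input assembly, pulled out into a standalone helper for readability; build_train_inputs is not part of the script, and the comments on cls_index/p_mask describe the usual XLNet SQuAD convention rather than anything verified against utils_squad.

```python
def build_train_inputs(batch, model_type):
    # Training batch layout (see load_and_cache_examples below):
    # 0 input_ids, 1 input_mask, 2 segment_ids, 3 start_positions,
    # 4 end_positions, 5 cls_index, 6 p_mask
    inputs = {'input_ids':       batch[0],
              'token_type_ids':  None if model_type == 'xlm' else batch[1],
              'attention_mask':  batch[2],
              'start_positions': batch[3],
              'end_positions':   batch[4]}
    if model_type in ['xlnet', 'xlm']:
        # cls_index: position of the classification token for each example;
        # p_mask: marks tokens that should not be considered as answer tokens
        inputs.update({'cls_index': batch[5],
                       'p_mask':    batch[6]})
    return inputs
```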
@@ -197,31 +204,50 @@ def evaluate(args, model, tokenizer, prefix=""):
197204
for batch in tqdm(eval_dataloader, desc="Evaluating"):
198205
model.eval()
199206
batch = tuple(t.to(args.device) for t in batch)
200-
example_indices = batch[3]
201207
with torch.no_grad():
202208
inputs = {'input_ids': batch[0],
203-
'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
204-
'attention_mask': batch[2]}
209+
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
210+
'attention_mask': batch[2]}
211+
example_indices = batch[3]
212+
if args.model_type in ['xlnet', 'xlm']:
213+
inputs.update({'cls_index': batch[4],
214+
'p_mask': batch[5]})
205215
outputs = model(**inputs)
206216
batch_start_logits, batch_end_logits = outputs[:2]
207217

208218
for i, example_index in enumerate(example_indices):
209-
start_logits = batch_start_logits[i].detach().cpu().tolist()
210-
end_logits = batch_end_logits[i].detach().cpu().tolist()
211219
eval_feature = features[example_index.item()]
212220
unique_id = int(eval_feature.unique_id)
213-
all_results.append(RawResult(unique_id=unique_id,
214-
start_logits=start_logits,
215-
end_logits=end_logits))
221+
if args.model_type in ['xlnet', 'xlm']:
222+
# XLNet uses a more complex post-processing procedure
223+
result = RawResultExtended(unique_id = unique_id,
224+
start_top_log_probs = to_list(outputs[0][i]),
225+
start_top_index = to_list(outputs[1][i]),
226+
end_top_log_probs = to_list(outputs[2][i]),
227+
end_top_index = to_list(outputs[3][i]),
228+
cls_logits = to_list(outputs[4][i]))
229+
else:
230+
result = RawResult(unique_id = unique_id,
231+
start_logits = to_list(outputs[0][i]),
232+
end_logits = to_list(outputs[1][i]))
233+
all_results.append(result)
216234

217235
# Compute predictions
218236
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
219237
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
220238
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
221-
write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length,
222-
args.do_lower_case, output_prediction_file, output_nbest_file,
223-
output_null_log_odds_file, args.verbose_logging,
224-
args.version_2_with_negative, args.null_score_diff_threshold)
239+
240+
if args.model_type in ['xlnet', 'xlm']:
241+
# XLNet uses a more complex post-processing procedure
242+
write_predictions_extended(examples, features, all_results, args.n_best_size,
243+
args.max_answer_length, output_prediction_file,
244+
output_nbest_file, output_null_log_odds_file, args.predict_file,
245+
args.start_n_top, args.end_n_top, args.version_2_with_negative)
246+
else:
247+
write_predictions(examples, features, all_results, args.n_best_size,
248+
args.max_answer_length, args.do_lower_case, output_prediction_file,
249+
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
250+
args.version_2_with_negative, args.null_score_diff_threshold)
225251

226252
# Evaluate with the official SQuAD script
227253
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
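Unlike the BERT-style head, which emits one start and one end logit per token, the XLNet/XLM head returns the top start_n_top start candidates, the top end_n_top end candidates per start candidate, and an answerability score (cls_logits). The sketch below shows only the core idea of scoring candidate spans from one RawResultExtended; it is a simplification, not the actual write_predictions_extended, which also builds n-best lists, maps token spans back to text, and applies the null-answer threshold.

```python
def best_span_from_extended(result, start_n_top, end_n_top, max_answer_length):
    """Pick the highest-scoring (start, end) token span from beam outputs.

    Simplified sketch: assumes end candidates are stored flat, end_n_top per
    start candidate, matching how the fields are filled in evaluate() above.
    """
    best = (None, None, float("-inf"))
    for i in range(start_n_top):
        for j in range(end_n_top):
            start_index = result.start_top_index[i]
            flat_j = i * end_n_top + j            # j-th end candidate for the i-th start
            end_index = result.end_top_index[flat_j]
            # discard ill-formed or overly long spans
            if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                continue
            score = result.start_top_log_probs[i] + result.end_top_log_probs[flat_j]
            if score > best[2]:
                best = (start_index, end_index, score)
    return best  # (start token index, end token index, joint log-probability)
```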
@@ -244,8 +270,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
     else:
         logger.info("Creating features from dataset file at %s", input_file)
         examples = read_squad_examples(input_file=input_file,
-                                       is_training=not evaluate,
-                                       version_2_with_negative=args.version_2_with_negative)
+                                                is_training=not evaluate,
+                                                version_2_with_negative=args.version_2_with_negative)
     features = convert_examples_to_features(examples=examples,
                                             tokenizer=tokenizer,
                                             max_seq_length=args.max_seq_length,
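The tensors assembled in the next hunk assume that every feature returned by convert_examples_to_features exposes at least the attributes read there and in evaluate(). A rough sketch of that interface as it can be inferred from this file; the real class lives in utils_squad, its name here is assumed, and it may carry additional fields.

```python
class InputFeatures(object):
    """Subset of per-span feature fields this script reads (inferred, not exhaustive)."""

    def __init__(self, unique_id, input_ids, input_mask, segment_ids,
                 cls_index, p_mask, start_position=None, end_position=None):
        self.unique_id = unique_id            # id used to match predictions back to features
        self.input_ids = input_ids            # token ids, padded to max_seq_length
        self.input_mask = input_mask          # 1 for real tokens, 0 for padding
        self.segment_ids = segment_ids        # token type ids (question vs. context)
        self.cls_index = cls_index            # position of the classification token
        self.p_mask = p_mask                  # marks tokens that cannot be answer tokens
        self.start_position = start_position  # gold start token index (training only)
        self.end_position = end_position      # gold end token index (training only)
```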
@@ -260,13 +286,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
     all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
     all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
+    all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
+    all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
     if evaluate:
         all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
+        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                all_example_index, all_cls_index, all_p_mask)
     else:
         all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
         all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
-        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
+        dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
+                                all_start_positions, all_end_positions,
+                                all_cls_index, all_p_mask)
 
     if output_examples:
         return dataset, examples, features
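Because all_cls_index and all_p_mask are now appended in both branches, the positional meaning of batch[i] back in train() and evaluate() differs between the two modes. A short usage sketch of the assumed ordering; dataset stands for the evaluation TensorDataset built above, and the sampler and batch_size are illustrative rather than taken from this diff.

```python
from torch.utils.data import DataLoader, SequentialSampler

# Training batches:   0 input_ids, 1 input_mask, 2 segment_ids,
#                     3 start_positions, 4 end_positions, 5 cls_index, 6 p_mask
# Evaluation batches: 0 input_ids, 1 input_mask, 2 segment_ids,
#                     3 example_index, 4 cls_index, 5 p_mask
eval_dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=8)
for batch in eval_dataloader:
    example_indices = batch[3]              # maps each row back to its feature/example
    cls_index, p_mask = batch[4], batch[5]  # only fed to the model for XLNet/XLM
```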

0 commit comments

Comments
 (0)