41
41
42
42
from pytorch_transformers import AdamW , WarmupLinearSchedule
43
43
44
- from utils_squad import read_squad_examples , convert_examples_to_features , RawResult , write_predictions
44
+ from utils_squad import (read_squad_examples , convert_examples_to_features ,
45
+ RawResult , write_predictions ,
46
+ RawResultExtended , write_predictions_extended )
45
47
46
48
# The follwing import is the official SQuAD evaluation script (2.0).
47
49
# You can remove it from the dependencies if you are using this script outside of the library
@@ -66,6 +68,8 @@ def set_seed(args):
66
68
if args .n_gpu > 0 :
67
69
torch .cuda .manual_seed_all (args .seed )
68
70
71
+ def to_list (tensor ):
72
+ return tensor .detach ().cpu ().tolist ()
69
73
70
74
def train (args , train_dataset , model , tokenizer ):
71
75
""" Train the model """
@@ -118,10 +122,13 @@ def train(args, train_dataset, model, tokenizer):
118
122
model .train ()
119
123
batch = tuple (t .to (args .device ) for t in batch )
120
124
inputs = {'input_ids' : batch [0 ],
121
- 'token_type_ids' : batch [ 1 ] if args .model_type in [ 'bert' , 'xlnet' ] else None , # XLM don't use segment_ids
125
+ 'token_type_ids' : None if args .model_type == 'xlm' else batch [ 1 ] , # XLM don't use segment_ids
122
126
'attention_mask' : batch [2 ],
123
127
'start_positions' : batch [3 ],
124
128
'end_positions' : batch [4 ]}
129
+ if args .model_type in ['xlnet' , 'xlm' ]:
130
+ inputs .update ({'cls_index' : batch [5 ],
131
+ 'p_mask' : batch [6 ]})
125
132
ouputs = model (** inputs )
126
133
loss = ouputs [0 ] # model outputs are always tuple in pytorch-transformers (see doc)
127
134
@@ -197,31 +204,50 @@ def evaluate(args, model, tokenizer, prefix=""):
197
204
for batch in tqdm (eval_dataloader , desc = "Evaluating" ):
198
205
model .eval ()
199
206
batch = tuple (t .to (args .device ) for t in batch )
200
- example_indices = batch [3 ]
201
207
with torch .no_grad ():
202
208
inputs = {'input_ids' : batch [0 ],
203
- 'token_type_ids' : batch [1 ] if args .model_type in ['bert' , 'xlnet' ] else None , # XLM don't use segment_ids
204
- 'attention_mask' : batch [2 ]}
209
+ 'token_type_ids' : None if args .model_type == 'xlm' else batch [1 ], # XLM don't use segment_ids
210
+ 'attention_mask' : batch [2 ]}
211
+ example_indices = batch [3 ]
212
+ if args .model_type in ['xlnet' , 'xlm' ]:
213
+ inputs .update ({'cls_index' : batch [4 ],
214
+ 'p_mask' : batch [5 ]})
205
215
outputs = model (** inputs )
206
216
batch_start_logits , batch_end_logits = outputs [:2 ]
207
217
208
218
for i , example_index in enumerate (example_indices ):
209
- start_logits = batch_start_logits [i ].detach ().cpu ().tolist ()
210
- end_logits = batch_end_logits [i ].detach ().cpu ().tolist ()
211
219
eval_feature = features [example_index .item ()]
212
220
unique_id = int (eval_feature .unique_id )
213
- all_results .append (RawResult (unique_id = unique_id ,
214
- start_logits = start_logits ,
215
- end_logits = end_logits ))
221
+ if args .model_type in ['xlnet' , 'xlm' ]:
222
+ # XLNet uses a more complex post-processing procedure
223
+ result = RawResultExtended (unique_id = unique_id ,
224
+ start_top_log_probs = to_list (outputs [0 ][i ]),
225
+ start_top_index = to_list (outputs [1 ][i ]),
226
+ end_top_log_probs = to_list (outputs [2 ][i ]),
227
+ end_top_index = to_list (outputs [3 ][i ]),
228
+ cls_logits = to_list (outputs [4 ][i ]))
229
+ else :
230
+ result = RawResult (unique_id = unique_id ,
231
+ start_logits = to_list (outputs [0 ][i ]),
232
+ end_logits = to_list (outputs [1 ][i ]))
233
+ all_results .append (result )
216
234
217
235
# Compute predictions
218
236
output_prediction_file = os .path .join (args .output_dir , "predictions_{}.json" .format (prefix ))
219
237
output_nbest_file = os .path .join (args .output_dir , "nbest_predictions_{}.json" .format (prefix ))
220
238
output_null_log_odds_file = os .path .join (args .output_dir , "null_odds_{}.json" .format (prefix ))
221
- write_predictions (examples , features , all_results , args .n_best_size , args .max_answer_length ,
222
- args .do_lower_case , output_prediction_file , output_nbest_file ,
223
- output_null_log_odds_file , args .verbose_logging ,
224
- args .version_2_with_negative , args .null_score_diff_threshold )
239
+
240
+ if args .model_type in ['xlnet' , 'xlm' ]:
241
+ # XLNet uses a more complex post-processing procedure
242
+ write_predictions_extended (examples , features , all_results , args .n_best_size ,
243
+ args .max_answer_length , output_prediction_file ,
244
+ output_nbest_file , output_null_log_odds_file , args .predict_file ,
245
+ args .start_n_top , args .end_n_top , args .version_2_with_negative )
246
+ else :
247
+ write_predictions (examples , features , all_results , args .n_best_size ,
248
+ args .max_answer_length , args .do_lower_case , output_prediction_file ,
249
+ output_nbest_file , output_null_log_odds_file , args .verbose_logging ,
250
+ args .version_2_with_negative , args .null_score_diff_threshold )
225
251
226
252
# Evaluate with the official SQuAD script
227
253
evaluate_options = EVAL_OPTS (data_file = args .predict_file ,
@@ -244,8 +270,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
244
270
else :
245
271
logger .info ("Creating features from dataset file at %s" , input_file )
246
272
examples = read_squad_examples (input_file = input_file ,
247
- is_training = not evaluate ,
248
- version_2_with_negative = args .version_2_with_negative )
273
+ is_training = not evaluate ,
274
+ version_2_with_negative = args .version_2_with_negative )
249
275
features = convert_examples_to_features (examples = examples ,
250
276
tokenizer = tokenizer ,
251
277
max_seq_length = args .max_seq_length ,
@@ -260,13 +286,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
260
286
all_input_ids = torch .tensor ([f .input_ids for f in features ], dtype = torch .long )
261
287
all_input_mask = torch .tensor ([f .input_mask for f in features ], dtype = torch .long )
262
288
all_segment_ids = torch .tensor ([f .segment_ids for f in features ], dtype = torch .long )
289
+ all_cls_index = torch .tensor ([f .cls_index for f in features ], dtype = torch .long )
290
+ all_p_mask = torch .tensor ([f .p_mask for f in features ], dtype = torch .float )
263
291
if evaluate :
264
292
all_example_index = torch .arange (all_input_ids .size (0 ), dtype = torch .long )
265
- dataset = TensorDataset (all_input_ids , all_input_mask , all_segment_ids , all_example_index )
293
+ dataset = TensorDataset (all_input_ids , all_input_mask , all_segment_ids ,
294
+ all_example_index , all_cls_index , all_p_mask )
266
295
else :
267
296
all_start_positions = torch .tensor ([f .start_position for f in features ], dtype = torch .long )
268
297
all_end_positions = torch .tensor ([f .end_position for f in features ], dtype = torch .long )
269
- dataset = TensorDataset (all_input_ids , all_input_mask , all_segment_ids , all_start_positions , all_end_positions )
298
+ dataset = TensorDataset (all_input_ids , all_input_mask , all_segment_ids ,
299
+ all_start_positions , all_end_positions ,
300
+ all_cls_index , all_p_mask )
270
301
271
302
if output_examples :
272
303
return dataset , examples , features
0 commit comments