@@ -73,10 +73,12 @@ def postprocess_qa_predictions(
73
73
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
74
74
``logging`` log level (e.g., ``logging.WARNING``)
75
75
"""
76
- assert len (predictions ) == 2 , "`predictions` should be a tuple with two elements (start_logits, end_logits)."
76
+ if len (predictions ) != 2 :
77
+ raise ValueError ("`predictions` should be a tuple with two elements (start_logits, end_logits)." )
77
78
all_start_logits , all_end_logits = predictions
78
79
79
- assert len (predictions [0 ]) == len (features ), f"Got { len (predictions [0 ])} predictions and { len (features )} features."
80
+ if len (predictions [0 ]) != len (features ):
81
+ raise ValueError (f"Got { len (predictions [0 ])} predictions and { len (features )} features." )
80
82
81
83
# Build a map example to its corresponding features.
82
84
example_id_to_index = {k : i for i , k in enumerate (examples ["id" ])}
@@ -212,7 +214,8 @@ def postprocess_qa_predictions(
212
214
213
215
# If we have an output_dir, let's save all those dicts.
214
216
if output_dir is not None :
215
- assert os .path .isdir (output_dir ), f"{ output_dir } is not a directory."
217
+ if not os .path .isdir (output_dir ):
218
+ raise EnvironmentError (f"{ output_dir } is not a directory." )
216
219
217
220
prediction_file = os .path .join (
218
221
output_dir , "predictions.json" if prefix is None else f"{ prefix } _predictions.json"
@@ -283,12 +286,12 @@ def postprocess_qa_predictions_with_beam_search(
283
286
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
284
287
``logging`` log level (e.g., ``logging.WARNING``)
285
288
"""
286
- assert len (predictions ) == 5 , "`predictions` should be a tuple with five elements."
289
+ if len (predictions ) != 5 :
290
+ raise ValueError ("`predictions` should be a tuple with five elements." )
287
291
start_top_log_probs , start_top_index , end_top_log_probs , end_top_index , cls_logits = predictions
288
292
289
- assert len (predictions [0 ]) == len (
290
- features
291
- ), f"Got { len (predictions [0 ])} predicitions and { len (features )} features."
293
+ if len (predictions [0 ]) != len (features ):
294
+ raise ValueError (f"Got { len (predictions [0 ])} predictions and { len (features )} features." )
292
295
293
296
# Build a map example to its corresponding features.
294
297
example_id_to_index = {k : i for i , k in enumerate (examples ["id" ])}
@@ -400,7 +403,8 @@ def postprocess_qa_predictions_with_beam_search(
400
403
401
404
# If we have an output_dir, let's save all those dicts.
402
405
if output_dir is not None :
403
- assert os .path .isdir (output_dir ), f"{ output_dir } is not a directory."
406
+ if not os .path .isdir (output_dir ):
407
+ raise EnvironmentError (f"{ output_dir } is not a directory." )
404
408
405
409
prediction_file = os .path .join (
406
410
output_dir , "predictions.json" if prefix is None else f"{ prefix } _predictions.json"
0 commit comments