Skip to content

Commit 319beb6

Browse files
authored
#12789 Replace assert statements with exceptions (#13909)
* #12789 Replace assert statements with exceptions * fix-copies: made copy changes to utils_qa.py in examples/pytorch/question-answering and examples/tensorflow/question-answering * minor refactor for clarity
1 parent 279ce5b commit 319beb6

File tree

3 files changed

+36
-24
lines changed

3 files changed

+36
-24
lines changed

examples/flax/question-answering/utils_qa.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ def postprocess_qa_predictions(
7373
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
7474
``logging`` log level (e.g., ``logging.WARNING``)
7575
"""
76-
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
76+
if len(predictions) != 2:
77+
raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
7778
all_start_logits, all_end_logits = predictions
7879

79-
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
80+
if len(predictions[0]) != len(features):
81+
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
8082

8183
# Build a map example to its corresponding features.
8284
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -212,7 +214,8 @@ def postprocess_qa_predictions(
212214

213215
# If we have an output_dir, let's save all those dicts.
214216
if output_dir is not None:
215-
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
217+
if not os.path.isdir(output_dir):
218+
raise EnvironmentError(f"{output_dir} is not a directory.")
216219

217220
prediction_file = os.path.join(
218221
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
@@ -283,12 +286,12 @@ def postprocess_qa_predictions_with_beam_search(
283286
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
284287
``logging`` log level (e.g., ``logging.WARNING``)
285288
"""
286-
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
289+
if len(predictions) != 5:
290+
raise ValueError("`predictions` should be a tuple with five elements.")
287291
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
288292

289-
assert len(predictions[0]) == len(
290-
features
291-
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
293+
if len(predictions[0]) != len(features):
294+
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
292295

293296
# Build a map example to its corresponding features.
294297
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -400,7 +403,8 @@ def postprocess_qa_predictions_with_beam_search(
400403

401404
# If we have an output_dir, let's save all those dicts.
402405
if output_dir is not None:
403-
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
406+
if not os.path.isdir(output_dir):
407+
raise EnvironmentError(f"{output_dir} is not a directory.")
404408

405409
prediction_file = os.path.join(
406410
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"

examples/pytorch/question-answering/utils_qa.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ def postprocess_qa_predictions(
7373
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
7474
``logging`` log level (e.g., ``logging.WARNING``)
7575
"""
76-
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
76+
if len(predictions) != 2:
77+
raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
7778
all_start_logits, all_end_logits = predictions
7879

79-
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
80+
if len(predictions[0]) != len(features):
81+
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
8082

8183
# Build a map example to its corresponding features.
8284
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -212,7 +214,8 @@ def postprocess_qa_predictions(
212214

213215
# If we have an output_dir, let's save all those dicts.
214216
if output_dir is not None:
215-
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
217+
if not os.path.isdir(output_dir):
218+
raise EnvironmentError(f"{output_dir} is not a directory.")
216219

217220
prediction_file = os.path.join(
218221
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
@@ -283,12 +286,12 @@ def postprocess_qa_predictions_with_beam_search(
283286
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
284287
``logging`` log level (e.g., ``logging.WARNING``)
285288
"""
286-
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
289+
if len(predictions) != 5:
290+
raise ValueError("`predictions` should be a tuple with five elements.")
287291
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
288292

289-
assert len(predictions[0]) == len(
290-
features
291-
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
293+
if len(predictions[0]) != len(features):
294+
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
292295

293296
# Build a map example to its corresponding features.
294297
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -400,7 +403,8 @@ def postprocess_qa_predictions_with_beam_search(
400403

401404
# If we have an output_dir, let's save all those dicts.
402405
if output_dir is not None:
403-
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
406+
if not os.path.isdir(output_dir):
407+
raise EnvironmentError(f"{output_dir} is not a directory.")
404408

405409
prediction_file = os.path.join(
406410
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"

examples/tensorflow/question-answering/utils_qa.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,12 @@ def postprocess_qa_predictions(
7373
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
7474
``logging`` log level (e.g., ``logging.WARNING``)
7575
"""
76-
assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
76+
if len(predictions) != 2:
77+
raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
7778
all_start_logits, all_end_logits = predictions
7879

79-
assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
80+
if len(predictions[0]) != len(features):
81+
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
8082

8183
# Build a map example to its corresponding features.
8284
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -212,7 +214,8 @@ def postprocess_qa_predictions(
212214

213215
# If we have an output_dir, let's save all those dicts.
214216
if output_dir is not None:
215-
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
217+
if not os.path.isdir(output_dir):
218+
raise EnvironmentError(f"{output_dir} is not a directory.")
216219

217220
prediction_file = os.path.join(
218221
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
@@ -283,12 +286,12 @@ def postprocess_qa_predictions_with_beam_search(
283286
log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
284287
``logging`` log level (e.g., ``logging.WARNING``)
285288
"""
286-
assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
289+
if len(predictions) != 5:
290+
raise ValueError("`predictions` should be a tuple with five elements.")
287291
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
288292

289-
assert len(predictions[0]) == len(
290-
features
291-
), f"Got {len(predictions[0])} predicitions and {len(features)} features."
293+
if len(predictions[0]) != len(features):
294+
raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")
292295

293296
# Build a map example to its corresponding features.
294297
example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
@@ -400,7 +403,8 @@ def postprocess_qa_predictions_with_beam_search(
400403

401404
# If we have an output_dir, let's save all those dicts.
402405
if output_dir is not None:
403-
assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
406+
if not os.path.isdir(output_dir):
407+
raise EnvironmentError(f"{output_dir} is not a directory.")
404408

405409
prediction_file = os.path.join(
406410
output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"

0 commit comments

Comments
 (0)