Skip to content

Commit 8277afe

Browse files
authored
Raise user error when any of the query instances has missing values (#403)
Missing values in query instances cause unexpected failures when the explainer calls predict()/predict_proba(). This change therefore raises a user error up front, asking the user to impute the missing values before requesting counterfactuals.
2 parents e5d2e27 + 623ac9c commit 8277afe

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

dice_ml/explainer_interfaces/explainer_base.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pickle
66
from abc import ABC, abstractmethod
77
from collections.abc import Iterable
8+
from typing import Any, List
89

910
import numpy as np
1011
import pandas as pd
@@ -47,13 +48,37 @@ def __init__(self, data_interface, model_interface=None):
4748
# self.cont_precisions = \
4849
# [self.data_interface.get_decimal_precisions()[ix] for ix in self.encoded_continuous_feature_indexes]
4950

51+
def _find_features_having_missing_values(
        self, data: Any) -> List[str]:
    """Return list of features which have missing values.

    Only pandas DataFrames are inspected; any other input type is
    reported as having no missing values.

    :param data: The dataset to check.
    :type data: Any
    :return: List of feature names which have missing values, in
        column order. Empty when nothing is missing or when ``data``
        is not a DataFrame.
    :rtype: List[str]
    """
    if not isinstance(data, pd.DataFrame):
        return []

    # A single vectorised pass yields one boolean per column. Iterating
    # the (label, flag) pairs avoids repeated ``data[feature]`` lookups,
    # which would return a DataFrame instead of a Series when column
    # labels are duplicated.
    column_has_missing = data.isnull().any(axis=0)
    return [
        feature
        for feature, has_missing in column_has_missing.items()
        if has_missing
    ]
68+
5069
def _validate_counterfactual_configuration(
5170
self, query_instances, total_CFs,
5271
desired_class="opposite", desired_range=None,
5372
permitted_range=None, features_to_vary="all",
5473
stopping_threshold=0.5, posthoc_sparsity_param=0.1,
5574
posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
5675

76+
if len(self._find_features_having_missing_values(query_instances)) > 0:
77+
raise UserConfigValidationException(
78+
"The query instance(s) should not have any missing values. "
79+
"Please impute the missing values and try again."
80+
)
81+
5782
if total_CFs <= 0:
5883
raise UserConfigValidationException(
5984
"The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.")

tests/test_dice_interface/test_explainer_base.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import re
2+
3+
import numpy as np
14
import pandas as pd
25
import pytest
36
from rai_test_utils.datasets.tabular import create_housing_data
@@ -501,6 +504,27 @@ def test_generate_counterfactuals_user_config_validations(
501504
method=method)
502505

503506
explainer_function = getattr(exp, explainer_function)
507+
508+
regex_pattern = re.escape(
509+
'The query instance(s) should not have any missing values. '
510+
'Please impute the missing values and try again.')
511+
512+
query_instances_missing_values_numerical = pd.DataFrame({'Categorical': ['a'], 'Numerical': [np.nan]})
513+
with pytest.raises(
514+
UserConfigValidationException,
515+
match=regex_pattern):
516+
explainer_function(
517+
query_instances=query_instances_missing_values_numerical, desired_class='opposite',
518+
total_CFs=10)
519+
520+
query_instances_missing_values_categorical = pd.DataFrame({'Categorical': [np.nan], 'Numerical': [1]})
521+
with pytest.raises(
522+
UserConfigValidationException,
523+
match=regex_pattern):
524+
explainer_function(
525+
query_instances=query_instances_missing_values_categorical, desired_class='opposite',
526+
total_CFs=10)
527+
504528
with pytest.raises(
505529
UserConfigValidationException,
506530
match=r"The number of counterfactuals generated per query instance \(total_CFs\) "

0 commit comments

Comments
 (0)