Skip to content

Commit 9ea7053

Browse files
committed
add warnings for extra columns in pandas dataframes
1 parent 7570630 commit 9ea7053

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

python/interpret-core/interpret/utils/_clean_x.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,7 +1102,7 @@ def unify_columns(
11021102
# X = np.asfortranarray(X)
11031103

11041104
n_cols = X.shape[1]
1105-
if n_cols == len(feature_names_in):
1105+
if len(feature_names_in) == n_cols:
11061106
if feature_types is None:
11071107
for feature_idx, categories in requests:
11081108
yield _process_numpy_column(
@@ -1125,7 +1125,7 @@ def unify_columns(
11251125
np.bool_,
11261126
count=len(feature_types),
11271127
)
1128-
if n_cols != keep_cols.sum():
1128+
if keep_cols.sum() != n_cols:
11291129
# called under: predict
11301130
msg = f"The model has {len(keep_cols)} features, but X has {n_cols} columns"
11311131
_log.error(msg)
@@ -1154,11 +1154,14 @@ def unify_columns(
11541154
mapping = dict(zip(map(str, cols), cols))
11551155
n_cols = len(cols)
11561156
if len(mapping) != n_cols:
1157-
# this can happen if for instance one column is "0" and another is int(0)
1158-
# Pandas also allows duplicate labels by default:
1159-
# https://pandas.pydata.org/docs/user_guide/duplicates.html#duplicates-disallow
1160-
# we can tolerate duplicate labels here, provided none of them are being used by our model
1157+
warn(
1158+
"Columns with duplicate names detected. This can happen for example if there are columns '0' and 0."
1159+
)
1160+
1161+
# We can handle duplicate names if they are not being used by the model.
11611162
counts = Counter(map(str, cols))
1163+
1164+
# sum() is used only to consume the map iterator without a Python-level loop. Its result is discarded.
11621165
sum(
11631166
map(
11641167
operator.truth,
@@ -1175,6 +1178,10 @@ def unify_columns(
11751178
if feature_types is None:
11761179
if all(map(operator.contains, repeat(mapping), feature_names_in)):
11771180
# we can index by name, which is a lot faster in pandas
1181+
1182+
if len(feature_names_in) != n_cols:
1183+
warn("Extra columns present in X that are not used by the model.")
1184+
11781185
for feature_idx, categories in requests:
11791186
yield _process_pandas_column(
11801187
X[mapping[feature_names_in[feature_idx]]],
@@ -1183,7 +1190,7 @@ def unify_columns(
11831190
min_unique_continuous,
11841191
)
11851192
else:
1186-
if n_cols != len(feature_names_in):
1193+
if len(feature_names_in) != n_cols:
11871194
msg = f"The model has {len(feature_names_in)} feature names, but X has {n_cols} columns."
11881195
_log.error(msg)
11891196
raise ValueError(msg)
@@ -1209,6 +1216,10 @@ def unify_columns(
12091216
)
12101217
):
12111218
# we can index by name, which is a lot faster in pandas
1219+
1220+
if len(feature_names_in) < n_cols:
1221+
warn("Extra columns present in X that are not used by the model.")
1222+
12121223
for feature_idx, categories in requests:
12131224
yield _process_pandas_column(
12141225
X[mapping[feature_names_in[feature_idx]]],
@@ -1218,7 +1229,7 @@ def unify_columns(
12181229
)
12191230
else:
12201231
X = X.iloc
1221-
if n_cols == len(feature_names_in):
1232+
if len(feature_names_in) == n_cols:
12221233
warn(
12231234
"Pandas dataframe X does not contain all feature names. Falling back to positional columns."
12241235
)
@@ -1235,9 +1246,9 @@ def unify_columns(
12351246
np.bool_,
12361247
count=len(feature_types),
12371248
)
1238-
if n_cols != keep_cols.sum():
1249+
if keep_cols.sum() != n_cols:
12391250
# called under: predict
1240-
msg = f"The model has {len(keep_cols)} features, but X has {n_cols} columns"
1251+
msg = f"The model has {len(keep_cols)} features, but X has {n_cols} columns."
12411252
_log.error(msg)
12421253
raise ValueError(msg)
12431254
col_map = np.empty(len(keep_cols), np.int64)
@@ -1266,7 +1277,7 @@ def unify_columns(
12661277

12671278
n_cols = X.shape[1]
12681279

1269-
if n_cols == len(feature_names_in):
1280+
if len(feature_names_in) == n_cols:
12701281
if feature_types is None:
12711282
for feature_idx, categories in requests:
12721283
yield _process_sparse_column(
@@ -1289,8 +1300,8 @@ def unify_columns(
12891300
np.bool_,
12901301
count=len(feature_types),
12911302
)
1292-
if n_cols != keep_cols.sum():
1293-
msg = f"The model has {len(feature_types)} features, but X has {n_cols} columns"
1303+
if keep_cols.sum() != n_cols:
1304+
msg = f"The model has {len(feature_types)} features, but X has {n_cols} columns."
12941305
_log.error(msg)
12951306
raise ValueError(msg)
12961307
col_map = np.empty(len(feature_types), np.int64)
@@ -1315,7 +1326,7 @@ def unify_columns(
13151326
elif safe_isinstance(X, "scipy.sparse.spmatrix"):
13161327
n_cols = X.shape[1]
13171328

1318-
if n_cols == len(feature_names_in):
1329+
if len(feature_names_in) == n_cols:
13191330
if feature_types is None:
13201331
for feature_idx, categories in requests:
13211332
yield _process_sparse_column(
@@ -1338,8 +1349,8 @@ def unify_columns(
13381349
np.bool_,
13391350
count=len(feature_types),
13401351
)
1341-
if n_cols != keep_cols.sum():
1342-
msg = f"The model has {len(feature_types)} features, but X has {n_cols} columns"
1352+
if keep_cols.sum() != n_cols:
1353+
msg = f"The model has {len(feature_types)} features, but X has {n_cols} columns."
13431354
_log.error(msg)
13441355
raise ValueError(msg)
13451356
col_map = np.empty(len(feature_types), np.int64)

0 commit comments

Comments
 (0)