Skip to content

Commit 78673b6

Browse files
committed
Change deduplication to avoid consolidating identical objects and instead remove higher-order bins that are duplicates of lower-order bins
1 parent fc6f7ab commit 78673b6

File tree

5 files changed: +49 −49 lines changed

python/interpret-core/interpret/glassbox/_ebm/_bin.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,7 @@
3333
def eval_terms(X, n_samples, feature_names_in, feature_types_in, bins, term_features):
3434
# called under: predict
3535

36-
# prior to calling this function, call deduplicate_bins which will eliminate extra work in this function
36+
# prior to calling this function, call remove_extra_bins which will eliminate extra work in this function
3737

3838
# this generator function returns data in whatever order it thinks is most efficient. Normally for
3939
# mains it returns them in order, but pairs will be returned as their data completes and they can

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

+3-6
Original file line number | Diff line number | Diff line change
@@ -63,13 +63,12 @@
6363
from ._json import UNTESTED_from_jsonable, to_jsonable
6464
from ._tensor import remove_last, trim_tensor
6565
from ._utils import (
66-
deduplicate_bins,
6766
generate_term_names,
6867
generate_term_types,
6968
make_bag,
7069
order_terms,
7170
process_terms,
72-
remove_unused_higher_bins,
71+
remove_extra_bins,
7372
)
7473

7574
_log = logging.getLogger(__name__)
@@ -1429,8 +1428,7 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
14291428

14301429
best_iteration = np.array(best_iteration, np.int64)
14311430

1432-
remove_unused_higher_bins(term_features, bins)
1433-
deduplicate_bins(bins)
1431+
remove_extra_bins(term_features, bins)
14341432

14351433
bagged_scores = (
14361434
np.array([model[idx] for model in models], np.float64)
@@ -2445,8 +2443,7 @@ def sweep(self, terms=True, bins=True, features=False):
24452443
raise ValueError(msg)
24462444

24472445
if bins is True:
2448-
remove_unused_higher_bins(self.term_features_, self.bins_)
2449-
deduplicate_bins(self.bins_)
2446+
remove_extra_bins(self.term_features_, self.bins_)
24502447
elif bins is not False:
24512448
msg = "bins must be True or False"
24522449
_log.error(msg)

python/interpret-core/interpret/glassbox/_ebm/_merge_ebms.py

+2-6
Original file line number | Diff line number | Diff line change
@@ -11,11 +11,10 @@
1111
from ...utils._native import Native
1212
from ._utils import (
1313
convert_categorical_to_continuous,
14-
deduplicate_bins,
1514
generate_term_names,
1615
order_terms,
1716
process_terms,
18-
remove_unused_higher_bins,
17+
remove_extra_bins,
1918
)
2019

2120
_log = logging.getLogger(__name__)
@@ -512,7 +511,6 @@ def merge_ebms(models):
512511
new_leveled_bins.append(merged_bins)
513512
new_bins.append(new_leveled_bins)
514513
ebm.feature_types_in_ = new_feature_types
515-
deduplicate_bins(new_bins)
516514
ebm.bins_ = new_bins
517515

518516
feature_names_merged = [None] * n_features
@@ -768,9 +766,7 @@ def merge_ebms(models):
768766
]
769767

770768
# TODO: we might be able to do these operations earlier
771-
remove_unused_higher_bins(ebm.term_features_, ebm.bins_)
772-
# removing the higher order terms might allow us to eliminate some extra bins now that couldn't before
773-
deduplicate_bins(ebm.bins_)
769+
remove_extra_bins(ebm.term_features_, ebm.bins_)
774770

775771
# dependent attributes (can be re-derrived after serialization)
776772
ebm.n_features_in_ = len(ebm.bins_) # scikit-learn specified name

python/interpret-core/interpret/glassbox/_ebm/_utils.py

+24-24
Original file line number | Diff line number | Diff line change
@@ -256,7 +256,7 @@ def order_terms(term_features, *args):
256256
return ret if len(ret) >= 2 else ret[0]
257257

258258

259-
def remove_unused_higher_bins(term_features, bins):
259+
def remove_extra_bins(term_features, bins):
260260
# many features are not used in pairs, so we can simplify the model
261261
# by removing the extra higher interaction level bins
262262

@@ -267,33 +267,33 @@ def remove_unused_higher_bins(term_features, bins):
267267
highest_levels[feature_idx], len(feature_idxs)
268268
)
269269

270-
for bin_levels, max_level in zip(bins, highest_levels):
271-
del bin_levels[max_level:]
270+
for bin_levels, i in zip(bins, highest_levels):
271+
if i != 0:
272+
if len(bin_levels) == 0:
273+
raise Exception("Empty bin cannot be used in a term.")
272274

275+
i = min(i, len(bin_levels)) - 1
276+
types = set(map(type, bin_levels))
273277

274-
def deduplicate_bins(bins):
275-
# calling this function before calling score_terms allows score_terms to operate more efficiently since it'll
276-
# be able to avoid re-binning data for pairs that have already been processed in mains or other pairs since we
277-
# use the id of the bins to identify feature data that was previously binned
278+
if len(types) != 1:
279+
raise Exception("Inconsistent bin types.")
278280

279-
uniques = {}
280-
for bin_levels in bins:
281-
highest_key = None
282-
highest_idx = -1
283-
for level_idx, feature_bins in enumerate(bin_levels):
284-
if isinstance(feature_bins, dict):
285-
key = frozenset(feature_bins.items())
281+
if next(iter(types)) == dict:
282+
key = frozenset(bin_levels[i].items())
283+
i -= 1
284+
while 0 <= i:
285+
if key != frozenset(bin_levels[i].items()):
286+
break
287+
i -= 1
286288
else:
287-
key = tuple(feature_bins)
288-
if key in uniques:
289-
bin_levels[level_idx] = uniques[key]
290-
else:
291-
uniques[key] = feature_bins
292-
293-
if highest_key != key:
294-
highest_key = key
295-
highest_idx = level_idx
296-
del bin_levels[highest_idx + 1 :]
289+
key = tuple(bin_levels[i])
290+
i -= 1
291+
while 0 <= i:
292+
if key != tuple(bin_levels[i]):
293+
break
294+
i -= 1
295+
i += 2
296+
del bin_levels[i:]
297297

298298

299299
def convert_to_intervals(cuts): # pragma: no cover

python/interpret-core/tests/glassbox/ebm/test_ebm_utils.py

+19-12
Original file line number | Diff line number | Diff line change
@@ -5,32 +5,39 @@
55
convert_categorical_to_continuous,
66
convert_to_cuts,
77
convert_to_intervals,
8-
deduplicate_bins,
8+
remove_extra_bins,
99
make_bag,
1010
)
1111

1212

13-
def test_deduplicate_bins():
13+
def test_remove_extra_bins():
1414
bins = [
1515
[{"a": 1, "b": 2}, {"a": 2, "b": 1}, {"b": 2, "a": 1}, {"b": 2, "a": 1}],
1616
[
17-
np.array([1, 2, 3], dtype=np.float64),
1817
np.array([1, 3, 2], dtype=np.float64),
1918
np.array([1, 2, 3], dtype=np.float64),
19+
np.array([1, 2, 3], dtype=np.float64),
20+
],
21+
[
22+
np.array([9, 8, 7], dtype=np.float64),
2023
],
24+
[{"m": 1, "q": 2}],
25+
[{"r": 7, "t": 8}, {"r": 7, "t": 8}],
26+
[{"one": 1, "two": 2}],
27+
[{"never_used": 1, "never_ever": 2}],
28+
[],
2129
]
2230

23-
deduplicate_bins(bins)
31+
remove_extra_bins([(0, 1, 2, 3, 4), (5,)], bins)
2432

2533
assert len(bins[0]) == 3
26-
assert id(bins[0][0]) != id(bins[0][1])
27-
assert id(bins[0][0]) == id(bins[0][2])
28-
assert id(bins[0][1]) != id(bins[0][2])
29-
30-
assert len(bins[1]) == 3
31-
assert id(bins[1][0]) != id(bins[1][1])
32-
assert id(bins[1][0]) == id(bins[1][2])
33-
assert id(bins[1][1]) != id(bins[1][2])
34+
assert len(bins[1]) == 2
35+
assert len(bins[2]) == 1
36+
assert len(bins[3]) == 1
37+
assert len(bins[4]) == 1
38+
assert len(bins[5]) == 1
39+
assert len(bins[6]) == 0
40+
assert len(bins[7]) == 0
3441

3542

3643
def test_conversion_cut_intervals():

0 commit comments

Comments (0)