Skip to content

Commit 52fa4f5

Browse files
committed
simplify parameters to the boost and interaction detection calls
1 parent ee2ab39 commit 52fa4f5

File tree

2 files changed

+21
-23
lines changed

2 files changed

+21
-23
lines changed

python/interpret-core/interpret/glassbox/_ebm/_boost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def boost(
9797
learning_rate=intercept_learning_rate,
9898
min_samples_leaf=0,
9999
min_hessian=0.0,
100+
# TODO: should reg_alpha and reg_lambda be 0 for intercept?
100101
reg_alpha=reg_alpha,
101102
reg_lambda=reg_lambda,
102103
max_delta_step=0.0,

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,9 +1177,9 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
11771177

11781178
results = parallel(
11791179
delayed(booster)(
1180-
shm_name,
1181-
idx,
1182-
callback,
1180+
shm_name=shm_name,
1181+
bag_idx=idx,
1182+
callback=callback,
11831183
dataset=(
11841184
shared.name if shared.name is not None else shared.dataset
11851185
),
@@ -1188,25 +1188,29 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
11881188
"intercept_learning_rate"
11891189
),
11901190
intercept=bagged_intercept[idx],
1191-
bag=(bag := internal_bags[idx]),
1191+
bag=internal_bags[idx],
11921192
# TODO: instead of making these copies we should
11931193
# put init_score into the native shared dataframe
11941194
init_scores=(
11951195
init_score
11961196
if (
11971197
init_score is None
1198-
or bag is None
1199-
or np.count_nonzero(bag) == len(bag)
1198+
or internal_bags[idx] is None
1199+
or np.count_nonzero(internal_bags[idx])
1200+
== len(internal_bags[idx])
12001201
)
1201-
else init_score[bag != 0]
1202+
else init_score[internal_bags[idx] != 0]
12021203
),
12031204
term_features=term_features,
12041205
smoothing_rounds=smoothing_rounds,
12051206
# if there are no validation samples, turn off early stopping
12061207
# because the validation metric cannot improve each round
12071208
early_stopping_rounds=(
12081209
early_stopping_rounds
1209-
if (bag is not None and (bag < 0).any())
1210+
if (
1211+
internal_bags[idx] is not None
1212+
and (internal_bags[idx] < 0).any()
1213+
)
12101214
else 0
12111215
),
12121216
rng=rngs[idx],
@@ -1287,8 +1291,8 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
12871291
bagged_ranked_interaction = parallel(
12881292
# TODO: the combinations below should be selected from the non-excluded features
12891293
delayed(rank_interactions)(
1290-
shm_name,
1291-
idx,
1294+
shm_name=shm_name,
1295+
bag_idx=idx,
12921296
dataset=(
12931297
shared.name
12941298
if shared.name is not None
@@ -1412,9 +1416,9 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
14121416

14131417
results = parallel(
14141418
delayed(booster)(
1415-
shm_name,
1416-
idx,
1417-
callback,
1419+
shm_name=shm_name,
1420+
bag_idx=idx,
1421+
callback=callback,
14181422
dataset=(
14191423
shared.name
14201424
if shared.name is not None
@@ -1528,9 +1532,9 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
15281532
bagged_intercept += correction
15291533
else:
15301534
exception, intercept_change, _, _, rng = booster(
1531-
None,
1532-
0,
1533-
None,
1535+
shm_name=None,
1536+
bag_idx=0,
1537+
callback=None,
15341538
dataset=shared.dataset,
15351539
intercept_rounds=develop.get_option("n_intercept_rounds_final"),
15361540
intercept_learning_rate=develop.get_option(
@@ -1541,16 +1545,9 @@ def fit(self, X, y, sample_weight=None, bags=None, init_score=None):
15411545
init_scores=scores,
15421546
term_features=[],
15431547
n_inner_bags=0, # overwrite
1544-
min_samples_leaf=0, # overwrite
1545-
min_hessian=0.0, # overwrite
15461548
reg_alpha=0.0, # overwrite
15471549
reg_lambda=0.0, # overwrite
1548-
max_delta_step=0.0, # overwrite
1549-
gain_scale=1.0, # overwrite
1550-
max_leaves=1, # overwrite
1551-
monotone_constraints=None, # overwrite
15521550
smoothing_rounds=0,
1553-
max_rounds=0, # overwrite
15541551
early_stopping_rounds=0,
15551552
rng=rng,
15561553
acceleration=Native.AccelerationFlags_NONE, # overwrite

0 commit comments

Comments
 (0)