Skip to content

Commit 492ddad

Browse files
honest forest test added
1 parent 6ea50cc commit 492ddad

File tree

1 file changed

+116
-2
lines changed

1 file changed

+116
-2
lines changed

sklearn/ensemble/tests/test_forest.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
4747
from sklearn.svm import LinearSVC
48+
from sklearn.tree.tests.test_tree import make_trunk_classification
4849
from sklearn.tree._classes import SPARSE_SPLITTERS
4950
from sklearn.utils._testing import (
5051
_convert_container,
@@ -274,7 +275,6 @@ def test_iris_criterion(name, criterion):
274275
@pytest.mark.parametrize("criterion", ("gini", "log_loss"))
275276
def test_honest_forest_iris_criterion(criterion):
276277
# Check consistency on dataset iris.
277-
print("yo")
278278
clf = HonestRandomForestClassifier(
279279
n_estimators=10, criterion=criterion, random_state=1
280280
)
@@ -288,7 +288,121 @@ def test_honest_forest_iris_criterion(criterion):
288288
clf.fit(iris.data, iris.target)
289289
score = clf.score(iris.data, iris.target)
290290
assert score > 0.5, "Failed with criterion %s and score = %f" % (criterion, score)
291-
print("sup")
291+
292+
293+
def test_honest_forest_separation():
294+
# verify that splits by trees in an honest forest are made independent of honest
295+
# Y labels. this can't be done using the shuffle test method used in the tree
296+
# tests because in a forest using stratified sampling, the honest Y labels are
297+
# used to determine the stratification, making it impossible to both shuffle the
298+
# Y labels and keep the honest index selection fixed between trials. thus we must
299+
# use a different method to test forests, which is simply to run two trials,
300+
# shifting the honest X values in the second trial such that any split which
301+
# considered the honest Y labels must move. we also do a third trial moving some
302+
# of the structure X values to verify that moving X's under consideration would
303+
# in fact alter splits, obvious as it may seem.
304+
#
305+
# in order for this test to work, one must ensure that the honest split rejection
306+
# criteria never veto a desired split by the shadow structure tree.
307+
# the lazy way to do this is to make sure there are enough honest observations
308+
# so that there will be enough on either side of any potential structure split.
309+
# thus more dims => more samples
310+
N_TREES = 1
311+
N_DIM = 10
312+
SAMPLE_SIZE = 2098
313+
RANDOM_STATE = 1
314+
HONEST_FRACTION = 0.95
315+
STRATIFY = True
316+
317+
X, y = make_trunk_classification(
318+
n_samples=SAMPLE_SIZE,
319+
n_dim=N_DIM,
320+
n_informative=1,
321+
seed=0,
322+
mu_0=-5,
323+
mu_1=5
324+
)
325+
X_t = np.concatenate((
326+
X[: SAMPLE_SIZE // 2],
327+
X[SAMPLE_SIZE // 2 :]
328+
))
329+
y_t = np.concatenate((
330+
y[: SAMPLE_SIZE // 2],
331+
y[SAMPLE_SIZE // 2 :]
332+
))
333+
334+
335+
def perturb(X, y, indices):
336+
for d in range(N_DIM):
337+
for i in indices:
338+
if y[i] == 0 and np.random.randint(0, 2, 1) > 0:
339+
X[i, d] -= 5
340+
elif np.random.randint(0, 2, 1) > 0:
341+
X[i, d] -= 2
342+
343+
return X, y
344+
345+
346+
class Trial:
347+
def __init__(self, X, y):
348+
self.est = HonestRandomForestClassifier(
349+
n_estimators=N_TREES,
350+
max_samples=1.0,
351+
max_features=0.3,
352+
bootstrap=True,
353+
stratify=STRATIFY,
354+
n_jobs=-2,
355+
random_state=RANDOM_STATE,
356+
honest_prior="ignore",
357+
honest_fraction=HONEST_FRACTION,
358+
)
359+
self.est.fit(X, y)
360+
361+
self.tree = self.est.estimators_[0]
362+
self.honest_tree = self.tree.tree_
363+
self.structure_tree = self.honest_tree.target_tree
364+
self.honest_indices = np.sort(self.tree.honest_indices_)
365+
self.structure_indices = np.sort(self.tree.structure_indices_)
366+
self.threshold = self.honest_tree.target_tree.threshold.copy()
367+
368+
369+
trial_results = []
370+
trial_results.append(Trial(X_t, y_t))
371+
372+
# perturb honest X values; threshold should not change
373+
X_t, y_t = perturb(X_t, y_t, trial_results[0].honest_indices)
374+
375+
trial_results.append(Trial(X_t, y_t))
376+
assert np.array_equal(
377+
trial_results[0].honest_indices,
378+
trial_results[1].honest_indices
379+
)
380+
assert np.array_equal(
381+
trial_results[0].structure_indices,
382+
trial_results[1].structure_indices
383+
)
384+
assert np.array_equal(
385+
trial_results[0].threshold,
386+
trial_results[1].threshold
387+
), f"threshold1 = {trial_results[0].threshold}\nthreshold2 = {trial_results[1].threshold}"
388+
389+
390+
# perturb structure X's; threshold should change
391+
X_t, y_t = perturb(X_t, y_t, trial_results[0].structure_indices)
392+
trial_results.append(Trial(X_t, y_t))
393+
assert np.array_equal(
394+
trial_results[0].honest_indices,
395+
trial_results[2].honest_indices
396+
)
397+
assert np.array_equal(
398+
trial_results[0].structure_indices,
399+
trial_results[2].structure_indices
400+
)
401+
assert not np.array_equal(
402+
trial_results[0].threshold,
403+
trial_results[2].threshold
404+
)
405+
292406

293407
@pytest.mark.parametrize("name", FOREST_REGRESSORS)
294408
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)