Commit 6ea50cc

honest forest fixes, honest tree tests
1 parent 71cacf3 commit 6ea50cc

File tree

3 files changed: +189 −55 lines


sklearn/ensemble/_forest.py

Lines changed: 8 additions & 2 deletions
@@ -2491,6 +2491,10 @@ class labels (multi-output problem).
             Interval(Integral, 1, None, closed="left"),
         ]
 
+    @staticmethod
+    def _generate_sample_indices(tree, random_state, n_samples):
+        return _generate_sample_indices(tree, random_state, n_samples)
+
     def __init__(
         self,
         n_estimators=100,
@@ -2540,15 +2544,17 @@ def __init__(
                 target_tree_kwargs=self.target_tree_kwargs,
                 stratify=stratify,
                 honest_prior=honest_prior,
-                honest_fraction=honest_fraction
+                honest_fraction=honest_fraction,
+                random_state=random_state
             ),
             n_estimators=n_estimators,
             estimator_params=(
                 "target_tree_class",
                 "target_tree_kwargs",
                 "stratify",
                 "honest_prior",
-                "honest_fraction"
+                "honest_fraction",
+                "random_state"
             ),
             # estimator_params=(
             #     "criterion",

sklearn/tree/_honest_tree.py

Lines changed: 3 additions & 2 deletions
@@ -342,8 +342,6 @@ def _init_output_shape(self, X, y, classes=None):
 
 
     def _partition_honest_indices(self, y, sample_weight):
-        rng = np.random.default_rng(self.target_tree.random_state)
-
         # Account for bootstrapping too
         if sample_weight is None:
            structure_weight = np.ones((len(y),), dtype=np.float64)
@@ -353,6 +351,7 @@ def _partition_honest_indices(self, y, sample_weight):
            honest_weight = np.array(sample_weight)
 
        nonzero_indices = np.where(structure_weight > 0)[0]
+
        # sample the structure indices
        if self.stratify:
            ss = StratifiedShuffleSplit(
@@ -362,7 +361,9 @@ def _partition_honest_indices(self, y, sample_weight):
                np.zeros((len(nonzero_indices), 1)), y[nonzero_indices]
            ):
                self.structure_indices_ = nonzero_indices[structure_idx]
+
        else:
+            rng = np.random.default_rng(self.random_state)
            self.structure_indices_ = rng.choice(
                nonzero_indices,
                int((1 - self.honest_fraction) * len(nonzero_indices)),
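
The moved rng construction now seeds from the honest tree's own random_state rather than the target tree's. A standalone sketch of the unstratified split this enables, assuming sampling without replacement; partition_indices and its signature are illustrative stand-ins for the method above, not the actual API:

import numpy as np

def partition_indices(n_samples, honest_fraction, random_state):
    # hypothetical standalone version of the unstratified branch: a seeded
    # Generator makes the structure/honest partition reproducible per estimator
    rng = np.random.default_rng(random_state)
    indices = np.arange(n_samples)
    structure = rng.choice(
        indices, int((1 - honest_fraction) * n_samples), replace=False
    )
    honest = np.setdiff1d(indices, structure)
    return structure, honest

# identical seeds yield identical partitions
s1, h1 = partition_indices(16, 0.5, random_state=1)
s2, h2 = partition_indices(16, 0.5, random_state=1)
assert np.array_equal(s1, s2) and np.array_equal(h1, h2)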

sklearn/tree/tests/test_tree.py

Lines changed: 178 additions & 51 deletions
@@ -198,6 +198,115 @@
 }
 
 
+def make_trunk_classification(
+    n_samples,
+    n_dim,
+    n_informative=1,
+    simulation: str = "trunk",
+    mu_0: float = 0,
+    mu_1: float = 1,
+    rho: int = 0,
+    band_type: str = "ma",
+    return_params: bool = False,
+    mix: float = 0.5,
+    seed=None,
+):
+    if n_dim < n_informative:
+        raise ValueError(
+            f"Number of informative dimensions {n_informative} must not exceed "
+            f"the total number of dimensions, {n_dim}"
+        )
+    rng = np.random.default_rng(seed=seed)
+    rng1 = np.random.default_rng(seed=seed)
+    mu_0 = np.array([mu_0 / np.sqrt(i) for i in range(1, n_informative + 1)])
+    mu_1 = np.array([mu_1 / np.sqrt(i) for i in range(1, n_informative + 1)])
+    if rho != 0:
+        if band_type == "ma":
+            cov = _moving_avg_cov(n_informative, rho)
+        elif band_type == "ar":
+            cov = _autoregressive_cov(n_informative, rho)
+        else:
+            raise ValueError(f'Band type {band_type} must be one of "ma" or "ar".')
+    else:
+        cov = np.identity(n_informative)
+    if mix < 0 or mix > 1:
+        raise ValueError("Mix must be between 0 and 1.")
+    # use the faster Cholesky factorization for large covariance matrices,
+    # and the default SVD otherwise
+    if n_informative > 1000:
+        method = "cholesky"
+    else:
+        method = "svd"
+    if simulation == "trunk":
+        X = np.vstack(
+            (
+                rng.multivariate_normal(mu_0, cov, n_samples // 2, method=method),
+                rng1.multivariate_normal(mu_1, cov, n_samples // 2, method=method),
+            )
+        )
+    elif simulation == "trunk_overlap":
+        mixture_idx = rng.choice(
+            2, n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix]
+        )
+        norm_params = [[mu_0, cov], [mu_1, cov]]
+        X_mixture = np.fromiter(
+            (
+                rng.multivariate_normal(*(norm_params[i]), size=1, method=method)
+                for i in mixture_idx
+            ),
+            dtype=np.dtype((float, n_informative)),
+        )
+        X_mixture_2 = np.fromiter(
+            (
+                rng1.multivariate_normal(*(norm_params[i]), size=1, method=method)
+                for i in mixture_idx
+            ),
+            dtype=np.dtype((float, n_informative)),
+        )
+        X = np.vstack(
+            (
+                X_mixture.reshape(n_samples // 2, n_informative),
+                X_mixture_2.reshape(n_samples // 2, n_informative),
+            )
+        )
+    elif simulation == "trunk_mix":
+        mixture_idx = rng.choice(
+            2, n_samples // 2, replace=True, shuffle=True, p=[mix, 1 - mix]
+        )
+        norm_params = [[mu_0, cov], [mu_1, cov]]
+        X_mixture = np.fromiter(
+            (
+                rng1.multivariate_normal(*(norm_params[i]), size=1, method=method)
+                for i in mixture_idx
+            ),
+            dtype=np.dtype((float, n_informative)),
+        )
+        X = np.vstack(
+            (
+                rng.multivariate_normal(
+                    np.zeros(n_informative), cov, n_samples // 2, method=method
+                ),
+                X_mixture.reshape(n_samples // 2, n_informative),
+            )
+        )
+    else:
+        raise ValueError("Simulation must be: trunk, trunk_overlap, trunk_mix")
+    if n_dim > n_informative:
+        X = np.hstack(
+            (X, rng.normal(loc=0, scale=1, size=(X.shape[0], n_dim - n_informative)))
+        )
+    y = np.concatenate((np.zeros(n_samples // 2), np.ones(n_samples // 2)))
+    if return_params:
+        returns = [X, y]
+        if simulation == "trunk":
+            returns += [[mu_0, mu_1], [cov, cov]]
+        elif simulation == "trunk_overlap":
+            returns += [[np.zeros(n_informative), np.zeros(n_informative)], [cov, cov]]
+        elif simulation == "trunk_mix":
+            returns += [*list(zip(*norm_params)), X_mixture]
+        return returns
+    return X, y
+
+
 def assert_tree_equal(d, s, message):
     assert (
         s.node_count == d.node_count
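
For reference, a minimal usage sketch of the helper added above, under its default "trunk" simulation (the shapes follow from the code; the sample sizes are arbitrary):

import numpy as np

# two Gaussian classes separated along one informative dimension, padded
# with four pure-noise dimensions; the first half of the rows is class 0,
# the second half class 1
X, y = make_trunk_classification(n_samples=100, n_dim=5, n_informative=1, seed=0)
assert X.shape == (100, 5)
assert y.shape == (100,)
assert np.array_equal(np.unique(y), [0.0, 1.0])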
@@ -373,24 +482,17 @@ def test_honest_iris():
         honest_hist, _ = np.histogram(honest, bins=len(uniques))
         if np.array_equal(dishonest_hist, honest_hist):
             leaf_eq.append(i)
-            print(f"node {i}: ")
-            print(f"dishonest: {dishonest.T}")
-            print(f"   honest: {honest.T}")
-            print(f"dishonest_hist: {dishonest_hist}")
-            print(f"   honest_hist: {honest_hist}")
 
     assert len(leaf_eq) != leaf_ct, (
         "Failed with all leaves equal: {0}".format(leaf_eq)
     )
 
     # check accuracy
     score = accuracy_score(hf.target_tree.predict(iris.data), iris.target)
-    print(f"dishonest score: {score}")
     assert score > 0.9, "Failed with {0}, criterion = {1} and dishonest score = {2}".format(
         "DecisionTreeClassifier", criterion, score
     )
     score = accuracy_score(hf.predict(iris.data), iris.target)
-    print(f"honest score: {score}")
     assert score > 0.9, "Failed with {0}, criterion = {1} and honest score = {2}".format(
         "DecisionTreeClassifier", criterion, score
     )
@@ -416,22 +518,75 @@ def test_honest_iris():
     invalid_nodes_json = json.dumps(invalid_nodes_dict, indent=4)
     assert len(invalid_nodes) == 0, "Failed with invalid nodes: {0}".format(invalid_nodes_json)
 
-    #clf = Tree(criterion=criterion, max_features=2, random_state=0)
-    #hf = HonestDecisionTree(clf)
-    #hf.fit(iris.data, iris.target)
-    #score = accuracy_score(clf.predict(iris.data), iris.target)
-    #assert score > 0.5, "Failed with {0}, criterion = {1} and dishonest score = {2}".format(
-    #    name, criterion, score
-    #)
-    #score = accuracy_score(hf.predict(iris.data), iris.target)
-    #assert score > 0.5, "Failed with {0}, criterion = {1} and honest score = {2}".format(
-    #    name, criterion, score
-    #)
-    #ht = HonestyTester(hf)
-    #invalid_nodes = ht.get_invalid_nodes()
-    #invalid_nodes_dict = [node.to_dict() if hasattr(node, 'to_dict') else node for node in invalid_nodes]
-    #invalid_nodes_json = json.dumps(invalid_nodes_dict, indent=4)
-    #assert len(invalid_nodes) == 0, "Failed with invalid nodes: {0}".format(invalid_nodes_json)
+
+
+def test_honest_separation():
+    # Verify that splits are made independently of the honest data set:
+    # eliminate randomness from the training process, run repeated trials
+    # with the honest Y labels shuffled, and verify that the splits do
+    # not change.
+    N_ITER = 100
+    SAMPLE_SIZE = 1024
+    RANDOM_STATE = 1
+    HONEST_PRIOR = "ignore"
+    HONEST_FRACTION = 0.9
+
+    X, y = make_trunk_classification(
+        n_samples=SAMPLE_SIZE,
+        n_dim=1,
+        n_informative=1,
+        seed=0,
+    )
+    X_t = np.concatenate((
+        X[: SAMPLE_SIZE // 2],
+        X[SAMPLE_SIZE // 2 :]
+    ))
+    y_t = np.concatenate((np.zeros(SAMPLE_SIZE // 2), np.ones(SAMPLE_SIZE // 2)))
+
+    tree = HonestDecisionTree(
+        target_tree_class=DecisionTreeClassifier,
+        target_tree_kwargs={
+            "criterion": "gini",
+            "random_state": RANDOM_STATE
+        },
+        honest_prior=HONEST_PRIOR,
+        honest_fraction=HONEST_FRACTION
+    )
+    tree.fit(X_t, y_t.ravel())
+    honest_tree = tree.tree_
+    structure_tree = honest_tree.target_tree
+    old_threshold = structure_tree.threshold.copy()
+    old_y = y_t.copy()
+
+    honest_indices = tree.honest_indices_
+
+    for _ in range(N_ITER):
+        y_perm = y_t.copy()
+        honest_shuffled = honest_indices.copy()
+        np.random.shuffle(honest_shuffled)
+        for i in range(len(honest_indices)):
+            y_perm[honest_indices[i]] = y_t[honest_shuffled[i]]
+
+        assert not np.array_equal(y_t, y_perm)
+        assert not np.array_equal(old_y, y_perm)
+
+        tree = HonestDecisionTree(
+            target_tree_class=DecisionTreeClassifier,
+            target_tree_kwargs={
+                "criterion": "gini",
+                "random_state": RANDOM_STATE
+            },
+            honest_prior=HONEST_PRIOR,
+            honest_fraction=HONEST_FRACTION
+        )
+        tree.fit(X_t, y_perm.ravel())
+        honest_tree = tree.tree_
+        structure_tree = honest_tree.target_tree
+
+        assert np.array_equal(old_threshold, structure_tree.threshold)
+        old_threshold = structure_tree.threshold.copy()
+        old_y = y_perm.copy()
+
 
 @pytest.mark.parametrize("name, Tree", REG_TREES.items())
 @pytest.mark.parametrize("criterion", REG_CRITERIONS)
@@ -467,34 +622,6 @@ def test_diabetes_underfit(name, Tree, criterion, max_depth, metric, max_loss):
     assert 0 < loss < max_loss
 
 
-# @skip_if_32bit
-# @pytest.mark.parametrize("name, Tree", {"DecisionTreeRegressor": DecisionTreeRegressor}.items())
-# @pytest.mark.parametrize(
-#     "criterion, max_depth, metric, max_loss",
-#     [
-#         ("squared_error", 15, mean_squared_error, 60),
-#         ("absolute_error", 20, mean_squared_error, 60),
-#         ("friedman_mse", 15, mean_squared_error, 60),
-#         ("poisson", 15, mean_poisson_deviance, 30),
-#     ],
-# )
-# def test_diabetes_honest_underfit(name, Tree, criterion, max_depth, metric, max_loss):
-#     # check consistency of trees when the depth and the number of features are
-#     # limited
-
-#     reg = Tree(criterion=criterion, max_depth=max_depth, max_features=6, random_state=0)
-#     hon = HonestDecisionTree(reg)
-#     hon.fit(diabetes.data, diabetes.target)
-
-#     loss = metric(diabetes.target, reg.predict(diabetes.data))
-#     print(f"dishonest loss: {loss}")
-#     assert 0 < loss < max_loss
-
-#     hon_loss = metric(diabetes.target, hon.predict(diabetes.data))
-#     print(f"honest loss: {hon_loss}")
-#     assert 0 < hon_loss < max_loss
-
-
 def test_probability():
     # Predict probabilities using DecisionTreeClassifier.
 