Commit 5af6c0b

commented honest tree
1 parent 3b16b8f commit 5af6c0b

File tree

1 file changed: +21 -22 lines changed

sklearn/tree/_honest_tree.py

Lines changed: 21 additions & 22 deletions
@@ -1,5 +1,18 @@
+# Authors: Haoyin Xu <haoyinxu@gmail.com>
+#          Samuel Carliles <scarlil1@jhu.edu>
+#
 # Adopted from: https://github.com/neurodata/honest-forests
 
+# An honest classification tree implemented by inheriting BaseDecisionTree and
+# including the honesty module. The general idea is that:
+#
+# 1. The interface looks mostly like a regular DecisionTree, and we inherit as
+#    much of the implementation as we can.
+# 2. Rather than actually being our own tree however, we have a target tree for
+#    learning the structure which is just a regular DecisionTree trained on the
+#    structure sample, and an honesty instance which grows the shadow tree described
+#    in the honesty module.
+
 import numpy as np
 from numpy import float32 as DTYPE
 
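The comment block added above describes the core honest-tree idea: one subsample chooses the splits, a disjoint subsample populates the leaves. A rough sketch of that idea using only the public scikit-learn API (not this module's Honesty machinery, and with made-up variable names): fit an ordinary tree on a structure half, then recompute leaf class counts from an honest half.

import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Split the data into a structure sample and an honest sample.
X, y = make_classification(n_samples=200, random_state=0)
X_struct, X_honest, y_struct, y_honest = train_test_split(
    X, y, test_size=0.5, random_state=0
)

# 1. Learn the tree structure on the structure sample only.
structure_tree = DecisionTreeClassifier(max_depth=3, random_state=0)
structure_tree.fit(X_struct, y_struct)

# 2. Repopulate each leaf with class counts from the honest sample, so the
#    leaf estimates are independent of the data that chose the splits.
leaf_ids = structure_tree.apply(X_honest)
honest_counts = {
    leaf: np.bincount(y_honest[leaf_ids == leaf], minlength=2)
    for leaf in np.unique(leaf_ids)
}
print(honest_counts)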

@@ -19,7 +32,7 @@
 import inspect
 
 
-# note to self: max_n_classes is the maximum number of classes observed
+# note: max_n_classes is the maximum number of classes observed
 # in any response variable dimension
 class HonestDecisionTree(BaseDecisionTree):
     _parameter_constraints: dict = {
@@ -55,6 +68,9 @@ def __init__(
         if target_tree_class is not None:
             HonestDecisionTree._target_tree_hack(self, target_tree_class, **target_tree_kwargs)
 
+    # In order to inherit behavior from BaseDecisionTree, we must satisfy a lot of
+    # pythonic introspective attribute assumptions. This was the lowest effort way
+    # that came to mind.
     @staticmethod
     def _target_tree_hack(honest_tree, target_tree_class, **kwargs):
         honest_tree.target_tree_class = target_tree_class
@@ -154,21 +170,6 @@ def fit(
             target_bta.sample_weight
         )
 
-        # # compute the honest sample indices
-        # structure_mask = np.ones(len(target_bta.y), dtype=bool)
-        # structure_mask[self.honest_indices_] = False
-
-        # if target_bta.sample_weight is None:
-        #     sample_weight_leaves = np.ones((len(target_bta.y),), dtype=np.float64)
-        # else:
-        #     sample_weight_leaves = np.array(target_bta.sample_weight)
-        # sample_weight_leaves[structure_mask] = 0
-
-        # # determine the honest indices using the sample weight
-        # nonzero_indices = np.where(sample_weight_leaves > 0)[0]
-        # # sample the structure indices
-        # self.honest_indices_ = nonzero_indices
-
         # create honesty, set up listeners in target tree
         self.honesty = Honesty(
             target_bta.X,
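The commented-out block deleted here described recovering the honest indices from a sample-weight vector whose structure entries are zeroed. For reference, that masking trick in isolation; the names (y, honest_indices, sample_weight) are illustrative stand-ins, not the estimator's attributes.

import numpy as np

y = np.arange(10)                           # stand-in response vector
honest_indices = np.array([1, 3, 5, 7, 9])  # indices reserved for honest estimation

# Zero the weights of structure samples so only honest samples carry weight.
sample_weight = np.ones(len(y), dtype=np.float64)
structure_mask = np.ones(len(y), dtype=bool)
structure_mask[honest_indices] = False
sample_weight[structure_mask] = 0.0

# The honest indices are exactly the ones with nonzero weight.
assert np.array_equal(np.where(sample_weight > 0)[0], honest_indices)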
@@ -200,6 +201,7 @@ def fit(
             check_input=check_input
         )
 
+        # more pythonic introspection minutiae
         setattr(
             self,
             "classes_",
@@ -219,9 +221,9 @@ def fit(
 
             weighted_n_samples += sample_weights_honest[i]
 
+        # more pythonic introspection minutiae
         # fingers crossed sklearn.utils.validation.check_is_fitted doesn't
         # change its behavior
-        #print(f"n_classes = {target_bta.n_classes}")
         self.tree_ = HonestTree(
             self.target_tree.n_features_in_,
             target_bta.n_classes,
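The "fingers crossed" comment refers to scikit-learn's fitted-check protocol: in recent releases, check_is_fitted consults a __sklearn_is_fitted__() hook when an estimator defines one, which is what the setattr("__sklearn_is_fitted__", ...) call later in fit relies on. A toy illustration of that convention (the Toy class is made up for the example):

from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

class Toy(BaseEstimator):
    def fit(self, X, y=None):
        return self

    def __sklearn_is_fitted__(self):
        # check_is_fitted calls this hook before falling back to its
        # trailing-underscore attribute heuristic
        return True

check_is_fitted(Toy())  # no exception raised: the hook reports "fitted"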
@@ -231,9 +233,7 @@ def fit(
         self.honesty.resize_tree(self.tree_, self.honesty.get_node_count())
         self.tree_.node_count = self.honesty.get_node_count()
 
-        #print(f"dishonest node count = {self.target_tree.tree_.node_count}")
-        #print(f"honest node count = {self.tree_.node_count}")
-
+        # Criterion is very stateful, so do all the instantiation and initialization
         criterion = BaseDecisionTree._create_criterion(
             self.target_tree,
             n_outputs=target_bta.y.shape[1],
@@ -250,8 +250,6 @@ def fit(
 
         for i in range(self.honesty.get_node_count()):
             start, end = self.honesty.get_node_range(i)
-            #print(f"setting sample range for node {i}: ({start}, {end})")
-            #print(f"node {i} is leaf: {self.honesty.is_leaf(i)}")
             self.honesty.set_sample_pointers(criterion, start, end)
 
             if missing_values_in_feature_mask is not None:
@@ -262,6 +260,7 @@ def fit(
             if self.honesty.is_leaf(i):
                 self.honesty.node_samples(self.tree_, criterion, i)
 
+        # more pythonic introspection minutiae
         setattr(
             self,
             "__sklearn_is_fitted__",
