
Commit 610d4f7

Authored by adam2392
DOC Fix tree explanation of tree_.value in example (scikit-learn#29331)
Co-authored-by: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent faf197a commit 610d4f7

File tree

2 files changed: +34 -10 lines changed

doc/whats_new/v1.4.rst

Lines changed: 9 additions & 0 deletions
@@ -29,6 +29,15 @@ Version 1.4.1
 
 **February 2024**
 
+Changed models
+--------------
+
+- |API| The `tree_.value` attribute in :class:`tree.DecisionTreeClassifier`,
+  :class:`tree.DecisionTreeRegressor`, :class:`tree.ExtraTreeClassifier` and
+  :class:`tree.ExtraTreeRegressor` changed from a weighted absolute count
+  of samples to a weighted fraction of the total number of samples.
+  :pr:`27639` by :user:`Samuel Ronsin <samronsin>`.
+
 Metadata Routing
 ----------------
 
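The semantics change described in the changelog entry above can be sketched with plain Python. The numbers below are hypothetical stand-ins mirroring the iris example later in this commit: `value_fraction` plays the role of the new (>= 1.4) per-node class fractions stored in `tree_.value`, and `weighted_n_node_samples` plays the role of `tree_.weighted_n_node_samples`.

```python
# Per-node class fractions as stored in scikit-learn >= 1.4
# (hypothetical numbers matching the iris example, n_outputs=1):
value_fraction = [[0.33, 0.304, 0.366],   # root node
                  [1.0, 0.0, 0.0]]        # left child (all class 0)

# Weighted number of samples reaching each node
# (with unit sample weights this is just the sample count):
weighted_n_node_samples = [112.0, 37.0]

# Recover the pre-1.4 representation, i.e. absolute weighted
# counts per class, by scaling each node's fractions:
value_count = [
    [round(frac * n) for frac in node]
    for node, n in zip(value_fraction, weighted_n_node_samples)
]
print(value_count)  # [[37, 34, 41], [37, 0, 0]]
```

Going the other way (counts to fractions) is a division by the same per-node totals, which is why the change is purely a rescaling of the stored values.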

examples/tree/plot_unveil_tree_structure.py

Lines changed: 25 additions & 10 deletions
@@ -68,7 +68,8 @@
 # - ``weighted_n_node_samples[i]``: the weighted number of training samples
 #   reaching node ``i``
 # - ``value[i, j, k]``: the summary of the training samples that reached node i for
-#   output j and class k (for regression tree, class is set to 1).
+#   output j and class k (for regression tree, class is set to 1). See below
+#   for more information about ``value``.
 #
 # Using the arrays, we can traverse the tree structure to compute various
 # properties. Below, we will compute the depth of each node and whether or not
@@ -108,7 +109,7 @@
     if is_leaves[i]:
         print(
             "{space}node={node} is a leaf node with value={value}.".format(
-                space=node_depth[i] * "\t", node=i, value=values[i]
+                space=node_depth[i] * "\t", node=i, value=np.around(values[i], 3)
             )
         )
     else:
@@ -122,24 +123,36 @@
                 feature=feature[i],
                 threshold=threshold[i],
                 right=children_right[i],
-                value=values[i],
+                value=np.around(values[i], 3),
             )
         )
 
 # %%
 # What is the values array used here?
 # -----------------------------------
-# The `tree_.value` array is a 3D array of shape
-# [``n_nodes``, ``n_outputs``, ``n_classes``] which provides the count of samples
-# reaching a node for each class and for each output. Each node has a ``value``
-# array which is the number of weighted samples reaching this
-# node for each output and class.
+# The `tree_.value` array is a 3D array of shape
+# [``n_nodes``, ``n_outputs``, ``n_classes``] which provides the proportion of
+# samples reaching a node for each class and for each output.
+# Each node has a ``value`` array which is the proportion of weighted samples
+# reaching this node for each output and class, normalized over the samples
+# reaching that node.
+#
+# One can convert this to the absolute weighted number of samples reaching a
+# node by multiplying it by `tree_.weighted_n_node_samples[node_idx]` for the
+# given node. Note that sample weights are not used in this example, so the
+# weighted number of samples is the number of samples reaching the node,
+# because each sample has a weight of 1 by default.
 #
 # For example, in the above tree built on the iris dataset, the root node has
-# ``value = [37, 34, 41]``, indicating there are 37 samples
+# ``value = [0.33, 0.304, 0.366]``, indicating that 33% of the samples reaching
+# the root node are of class 0, 30.4% of class 1, and 36.6% of class 2. One can
+# convert this to the absolute number of samples by multiplying by the number
+# of samples reaching the root node, `tree_.weighted_n_node_samples[0]`.
+# Then the root node has ``value = [37, 34, 41]``, indicating there are 37 samples
 # of class 0, 34 samples of class 1, and 41 samples of class 2 at the root node.
+#
 # Traversing the tree, the samples are split and as a result, the ``value`` array
-# reaching each node changes. The left child of the root node has ``value = [37, 0, 0]``
+# reaching each node changes. The left child of the root node has ``value = [1., 0, 0]``
+# (or ``value = [37, 0, 0]`` when converted to the absolute number of samples)
 # because all 37 samples in the left child node are from class 0.
 #
 # Note: In this example, `n_outputs=1`, but the tree classifier can also handle
@@ -148,8 +161,10 @@
 
 ##############################################################################
 # We can compare the above output to the plot of the decision tree.
+# Here, we show the proportions of samples of each class that reach each
+# node, corresponding to the actual elements of the `tree_.value` array.
 
-tree.plot_tree(clf)
+tree.plot_tree(clf, proportion=True)
 plt.show()
 
 ##############################################################################
