|
68 | 68 | # - ``weighted_n_node_samples[i]``: the weighted number of training samples
|
69 | 69 | # reaching node ``i``
|
70 | 70 | # - ``value[i, j, k]``: the summary of the training samples that reached node i for
|
71 |
| -# output j and class k (for regression tree, class is set to 1). |
| 71 | +# output j and class k (for regression tree, class is set to 1). See below |
| 72 | +# for more information about ``value``. |
72 | 73 | #
|
73 | 74 | # Using the arrays, we can traverse the tree structure to compute various
|
74 | 75 | # properties. Below, we will compute the depth of each node and whether or not
|
|
108 | 109 | if is_leaves[i]:
|
109 | 110 | print(
|
110 | 111 | "{space}node={node} is a leaf node with value={value}.".format(
|
111 |
| - space=node_depth[i] * "\t", node=i, value=values[i] |
| 112 | + space=node_depth[i] * "\t", node=i, value=np.around(values[i], 3) |
112 | 113 | )
|
113 | 114 | )
|
114 | 115 | else:
|
|
122 | 123 | feature=feature[i],
|
123 | 124 | threshold=threshold[i],
|
124 | 125 | right=children_right[i],
|
125 |
| - value=values[i], |
| 126 | + value=np.around(values[i], 3), |
126 | 127 | )
|
127 | 128 | )
|
128 | 129 |
|
129 | 130 | # %%
|
130 | 131 | # What is the values array used here?
|
131 | 132 | # -----------------------------------
|
132 | 133 | # The `tree_.value` array is a 3D array of shape
|
133 |
| -# [``n_nodes``, ``n_classes``, ``n_outputs``] which provides the count of samples |
134 |
| -# reaching a node for each class and for each output. Each node has a ``value`` |
135 |
| -# array which is the number of weighted samples reaching this |
136 |
| -# node for each output and class. |
| 134 | +# [``n_nodes``, ``n_outputs``, ``n_classes``] which provides the proportion of samples |
| 135 | +# reaching a node for each class and for each output. |
| 136 | +# Each node has a ``value`` array which gives, for each output and class, the |
| 137 | +# proportion of the weighted samples reaching this node that belong to that class. |
| 138 | +# |
| 139 | +# One could convert this to the absolute weighted number of samples reaching a node |
| 140 | +# by multiplying this number by `tree_.weighted_n_node_samples[node_idx]` for the |
| 141 | +# given node. Note that sample weights are not used in this example, so the weighted |
| 142 | +# number of samples is the number of samples reaching the node because each sample |
| 143 | +# has a weight of 1 by default. |
137 | 144 | #
|
138 | 145 | # For example, in the above tree built on the iris dataset, the root node has
|
139 |
| -# ``value = [37, 34, 41]``, indicating there are 37 samples |
| 146 | +# ``value = [0.33, 0.304, 0.366]`` indicating that 33% of the samples at the root |
| 147 | +# node are of class 0, 30.4% of class 1, and 36.6% of class 2. One can |
| 148 | +# convert this to the absolute number of samples by multiplying by the number of |
| 149 | +# samples reaching the root node, which is `tree_.weighted_n_node_samples[0]`. |
| 150 | +# Then the root node has ``value = [37, 34, 41]``, indicating there are 37 samples |
140 | 151 | # of class 0, 34 samples of class 1, and 41 samples of class 2 at the root node.
|
| 152 | +# |
141 | 153 | # Traversing the tree, the samples are split and as a result, the ``value`` array
|
142 |
| -# reaching each node changes. The left child of the root node has ``value = [37, 0, 0]`` |
| 154 | +# reaching each node changes. The left child of the root node has ``value = [1., 0, 0]`` |
| 155 | +# (or ``value = [37, 0, 0]`` when converted to the absolute number of samples) |
143 | 156 | # because all 37 samples in the left child node are from class 0.
|
144 | 157 | #
|
145 | 158 | # Note: In this example, `n_outputs=1`, but the tree classifier can also handle
|
|
148 | 161 |
|
149 | 162 | ##############################################################################
|
150 | 163 | # We can compare the above output to the plot of the decision tree.
|
| 164 | +# Here, we show the proportions of samples of each class that reach each |
| 165 | +# node corresponding to the actual elements of the `tree_.value` array. |
151 | 166 |
|
152 |
| -tree.plot_tree(clf) |
| 167 | +tree.plot_tree(clf, proportion=True) |
153 | 168 | plt.show()
|
154 | 169 |
|
155 | 170 | ##############################################################################
|
|
0 commit comments