
Commit 325930e

Tialo and adrinjalali authored
DOC Add link to plot_tree_regression.py example (scikit-learn#26962)
Co-authored-by: adrinjalali <adrin.jalali@gmail.com>
1 parent 3cda5b2 commit 325930e

File tree

4 files changed (+115 lines, -89 lines)

doc/modules/tree.rst

Lines changed: 3 additions & 4 deletions
@@ -284,11 +284,11 @@ of shape ``(n_samples, n_outputs)`` then the resulting estimator will:
   ``predict_proba``.
 
 The use of multi-output trees for regression is demonstrated in
-:ref:`sphx_glr_auto_examples_tree_plot_tree_regression_multioutput.py`. In this example, the input
+:ref:`sphx_glr_auto_examples_tree_plot_tree_regression.py`. In this example, the input
 X is a single real value and the outputs Y are the sine and cosine of X.
 
-.. figure:: ../auto_examples/tree/images/sphx_glr_plot_tree_regression_multioutput_001.png
-   :target: ../auto_examples/tree/plot_tree_regression_multioutput.html
+.. figure:: ../auto_examples/tree/images/sphx_glr_plot_tree_regression_002.png
+   :target: ../auto_examples/tree/plot_tree_regression.html
    :scale: 75
    :align: center
 
@@ -304,7 +304,6 @@ the lower half of those faces.
 
 .. rubric:: Examples
 
-* :ref:`sphx_glr_auto_examples_tree_plot_tree_regression_multioutput.py`
 * :ref:`sphx_glr_auto_examples_miscellaneous_plot_multioutput_face_completion.py`
 
 .. rubric:: References
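
The tree.rst paragraph above describes fitting one tree to two outputs at once. A minimal sketch of that multi-output API (not part of this commit; the toy data here is illustrative):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# One real-valued input; two targets per sample (the sine and cosine of X).
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(50, 1), axis=0)
y = np.column_stack([np.sin(X).ravel(), np.cos(X).ravel()])

# A single estimator accepts y of shape (n_samples, n_outputs) directly;
# no wrapper such as MultiOutputRegressor is needed.
regr = DecisionTreeRegressor(max_depth=3).fit(X, y)
print(regr.predict(X[:2]).shape)  # (2, 2): one (sine, cosine) pair per sample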

examples/tree/plot_tree_regression.py

Lines changed: 109 additions & 17 deletions
@@ -1,46 +1,62 @@
 """
-===================================================================
+========================
 Decision Tree Regression
-===================================================================
-
-A 1D regression with decision tree.
-
-The :ref:`decision trees <tree>` is
-used to fit a sine curve with addition noisy observation. As a result, it
-learns local linear regressions approximating the sine curve.
-
-We can see that if the maximum depth of the tree (controlled by the
-`max_depth` parameter) is set too high, the decision trees learn too fine
-details of the training data and learn from the noise, i.e. they overfit.
+========================
+In this example, we demonstrate the effect of changing the maximum depth of a
+decision tree on how it fits to the data. We perform this once on a 1D regression
+task and once on a multi-output regression task.
 """
 
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-# Import the necessary modules and libraries
-import matplotlib.pyplot as plt
+# %%
+# Decision Tree on a 1D Regression Task
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Here we fit a tree on a 1D regression task.
+#
+# The :ref:`decision trees <tree>` is
+# used to fit a sine curve with addition noisy observation. As a result, it
+# learns local linear regressions approximating the sine curve.
+#
+# We can see that if the maximum depth of the tree (controlled by the
+# `max_depth` parameter) is set too high, the decision trees learn too fine
+# details of the training data and learn from the noise, i.e. they overfit.
+#
+# Create a random 1D dataset
+# --------------------------
 import numpy as np
 
-from sklearn.tree import DecisionTreeRegressor
-
-# Create a random dataset
 rng = np.random.RandomState(1)
 X = np.sort(5 * rng.rand(80, 1), axis=0)
 y = np.sin(X).ravel()
 y[::5] += 3 * (0.5 - rng.rand(16))
 
+# %%
 # Fit regression model
+# --------------------
+# Here we fit two models with different maximum depths
+from sklearn.tree import DecisionTreeRegressor
+
 regr_1 = DecisionTreeRegressor(max_depth=2)
 regr_2 = DecisionTreeRegressor(max_depth=5)
 regr_1.fit(X, y)
 regr_2.fit(X, y)
 
+# %%
 # Predict
+# -------
+# Get predictions on the test set
 X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
 y_1 = regr_1.predict(X_test)
 y_2 = regr_2.predict(X_test)
 
+# %%
 # Plot the results
+# ----------------
+import matplotlib.pyplot as plt
+
 plt.figure()
 plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
 plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
@@ -50,3 +66,79 @@
 plt.title("Decision Tree Regression")
 plt.legend()
 plt.show()
+
+# %%
+# As you can see, the model with a depth of 5 (yellow) learns the details of the
+# training data to the point that it overfits to the noise. On the other hand,
+# the model with a depth of 2 (blue) learns the major tendencies in the data well
+# and does not overfit. In real use cases, you need to make sure that the tree
+# is not overfitting the training data, which can be done using cross-validation.
+
+# %%
+# Decision Tree Regression with Multi-Output Targets
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Here the :ref:`decision trees <tree>`
+# is used to predict simultaneously the noisy `x` and `y` observations of a circle
+# given a single underlying feature. As a result, it learns local linear
+# regressions approximating the circle.
+#
+# We can see that if the maximum depth of the tree (controlled by the
+# `max_depth` parameter) is set too high, the decision trees learn too fine
+# details of the training data and learn from the noise, i.e. they overfit.
+
+# %%
+# Create a random dataset
+# -----------------------
+rng = np.random.RandomState(1)
+X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
+y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
+y[::5, :] += 0.5 - rng.rand(20, 2)
+
+# %%
+# Fit regression model
+# --------------------
+regr_1 = DecisionTreeRegressor(max_depth=2)
+regr_2 = DecisionTreeRegressor(max_depth=5)
+regr_3 = DecisionTreeRegressor(max_depth=8)
+regr_1.fit(X, y)
+regr_2.fit(X, y)
+regr_3.fit(X, y)
+
+# %%
+# Predict
+# -------
+# Get predictions on the test set
+X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
+y_1 = regr_1.predict(X_test)
+y_2 = regr_2.predict(X_test)
+y_3 = regr_3.predict(X_test)
+
+# %%
+# Plot the results
+# ----------------
+plt.figure()
+s = 25
+plt.scatter(y[:, 0], y[:, 1], c="yellow", s=s, edgecolor="black", label="data")
+plt.scatter(
+    y_1[:, 0],
+    y_1[:, 1],
+    c="cornflowerblue",
+    s=s,
+    edgecolor="black",
+    label="max_depth=2",
+)
+plt.scatter(y_2[:, 0], y_2[:, 1], c="red", s=s, edgecolor="black", label="max_depth=5")
+plt.scatter(y_3[:, 0], y_3[:, 1], c="blue", s=s, edgecolor="black", label="max_depth=8")
+plt.xlim([-6, 6])
+plt.ylim([-6, 6])
+plt.xlabel("target 1")
+plt.ylabel("target 2")
+plt.title("Multi-output Decision Tree Regression")
+plt.legend(loc="best")
+plt.show()
+
+# %%
+# As you can see, the higher the value of `max_depth`, the more details of the data
+# are caught by the model. However, the model also overfits to the data and is
+# influenced by the noise.
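
The closing comment of the 1D example recommends cross-validation to check that the tree is not overfitting, but the commit stops short of showing it. A minimal sketch of what that could look like, assuming a grid search over `max_depth` (the grid values are an arbitrary choice, not from the commit):

import numpy as np

from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

# Same noisy sine data as the 1D example above.
rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

# 5-fold cross-validation scores each candidate depth on held-out folds,
# so depths that merely memorize the noise rank poorly.
search = GridSearchCV(
    DecisionTreeRegressor(random_state=0),
    param_grid={"max_depth": [2, 3, 5, 8, None]},
    cv=5,
)
search.fit(X, y)
print(search.best_params_)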

examples/tree/plot_tree_regression_multioutput.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

sklearn/tree/_classes.py

Lines changed: 3 additions & 0 deletions
@@ -1126,6 +1126,9 @@ class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree):
         all leaves are pure or until all leaves contain less than
         min_samples_split samples.
 
+        For an example of how ``max_depth`` influences the model, see
+        :ref:`sphx_glr_auto_examples_tree_plot_tree_regression.py`.
+
     min_samples_split : int or float, default=2
         The minimum number of samples required to split an internal node:
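
The docstring that this hunk extends notes that with the default `max_depth=None` the tree is expanded until all leaves are pure. A small illustrative check of that behavior (assumed toy data, not part of the commit; `get_depth()` is the estimator's own introspection method):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + 0.1 * rng.randn(80)

# Unbounded depth keeps splitting until every leaf is pure, so on noisy
# data the tree grows far deeper than one capped at max_depth=2.
unbounded = DecisionTreeRegressor(random_state=0).fit(X, y)  # max_depth=None
capped = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)
print(unbounded.get_depth(), capped.get_depth())  # unbounded depth >> 2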
