"""
========================
Decision Tree Regression
========================
In this example, we demonstrate the effect of changing the maximum depth of a
decision tree on how it fits the data. We perform this once on a 1D regression
task and once on a multi-output regression task.
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

# %%
# Decision Tree on a 1D Regression Task
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Here we fit a tree on a 1D regression task.
#
# A :ref:`decision tree <tree>` is used to fit a sine curve with additional
# noisy observations. As a result, it learns local, piecewise-constant
# approximations of the sine curve.
#
# We can see that if the maximum depth of the tree (controlled by the
# `max_depth` parameter) is set too high, the tree learns overly fine details
# of the training data and fits the noise, i.e. it overfits.
#
# Create a random 1D dataset
# --------------------------
import numpy as np

rng = np.random.RandomState(1)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))

# %%
# Fit regression model
# --------------------
# Here we fit two models with different maximum depths.
from sklearn.tree import DecisionTreeRegressor

regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)

# %%
# Predict
# -------
# Get predictions on the test set
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)

# %%
# Plot the results
# ----------------
import matplotlib.pyplot as plt

plt.figure()
plt.scatter(X, y, s=20, edgecolor="black", c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue", label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()

# %%
# As you can see, the model with a depth of 5 (yellow) learns the details of the
# training data to the point that it overfits the noise. On the other hand, the
# model with a depth of 2 (blue) learns the major tendencies in the data well and
# does not overfit. In real use cases, you need to make sure that the tree is not
# overfitting the training data, which can be done using cross-validation.
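
# %%
# For instance, a minimal sketch of such a check is a cross-validated grid search
# over `max_depth` (this snippet is only illustrative and is not needed for the
# plots in this example):
from sklearn.model_selection import GridSearchCV

param_grid = {"max_depth": list(range(1, 11))}
grid_search = GridSearchCV(DecisionTreeRegressor(), param_grid, cv=5)
grid_search.fit(X, y)
print("Best max_depth found by cross-validation:", grid_search.best_params_)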

# %%
# Decision Tree Regression with Multi-Output Targets
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# Here a :ref:`decision tree <tree>` is used to simultaneously predict the noisy
# `x` and `y` observations of a circle given a single underlying feature. As a
# result, it learns local, piecewise-constant approximations of the circle.
#
# We can see that if the maximum depth of the tree (controlled by the
# `max_depth` parameter) is set too high, the tree learns overly fine details
# of the training data and fits the noise, i.e. it overfits.

# %%
# Create a random dataset
# -----------------------
rng = np.random.RandomState(1)
X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
y[::5, :] += 0.5 - rng.rand(20, 2)

# %%
# Fit regression model
# --------------------
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_3 = DecisionTreeRegressor(max_depth=8)
regr_1.fit(X, y)
regr_2.fit(X, y)
regr_3.fit(X, y)

# %%
# Predict
# -------
# Get predictions on the test set
X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
y_3 = regr_3.predict(X_test)

# %%
# Plot the results
# ----------------
plt.figure()
s = 25
plt.scatter(y[:, 0], y[:, 1], c="yellow", s=s, edgecolor="black", label="data")
plt.scatter(
    y_1[:, 0],
    y_1[:, 1],
    c="cornflowerblue",
    s=s,
    edgecolor="black",
    label="max_depth=2",
)
plt.scatter(y_2[:, 0], y_2[:, 1], c="red", s=s, edgecolor="black", label="max_depth=5")
plt.scatter(y_3[:, 0], y_3[:, 1], c="blue", s=s, edgecolor="black", label="max_depth=8")
plt.xlim([-6, 6])
plt.ylim([-6, 6])
plt.xlabel("target 1")
plt.ylabel("target 2")
plt.title("Multi-output Decision Tree Regression")
plt.legend(loc="best")
plt.show()

# %%
# As you can see, the higher the value of `max_depth`, the more details of the
# data the model captures. However, beyond a certain depth the model also fits
# the noise in the training data, i.e. it overfits.
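
# %%
# One way to see this more concretely is to compare, for instance, the number of
# leaves and the training-set R^2 of the three trees: the deeper trees have more
# leaves and a training score closer to 1, meaning they also model the noise in
# the training observations.
for depth, regr in zip([2, 5, 8], [regr_1, regr_2, regr_3]):
    print(
        f"max_depth={depth}: {regr.get_n_leaves()} leaves, "
        f"training R^2 = {regr.score(X, y):.3f}"
    )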