Commit f84199e

DOC remove example OLS with 3D plot (scikit-learn#29967)
1 parent 9671047 commit f84199e

File tree

3 files changed: +81, -129 lines


doc/conf.py

Lines changed: 1 addition & 0 deletions
@@ -488,6 +488,7 @@ def add_js_css_files(app, pagename, templatename, context, doctree):
     "auto_examples/datasets/plot_iris_dataset": (
         "auto_examples/decomposition/plot_pca_iris"
     ),
+    "auto_examples/linear_model/plot_ols_3d": ("auto_examples/linear_model/plot_ols"),
 }
 html_context["redirects"] = redirects
 for old_link in redirects:
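
For readers unfamiliar with this mechanism: the `redirects` mapping pairs an old example path with the page that replaces it, and the `for old_link in redirects:` loop consumes the table during the docs build. The sketch below only illustrates that general pattern; the helper name and HTML template are hypothetical and not scikit-learn's actual implementation.

# Illustrative sketch only: one plausible way a Sphinx conf.py can turn a
# redirects dict into HTML stub pages. write_redirect_stubs and the template
# string are hypothetical names, not scikit-learn's actual code.
from pathlib import Path

REDIRECT_TEMPLATE = """<html>
  <head><meta http-equiv="refresh" content="0; url={new}.html" /></head>
  <body>This page has moved to <a href="{new}.html">{new}</a>.</body>
</html>"""

def write_redirect_stubs(redirects, outdir):
    for old_link, new_link in redirects.items():
        # Write a stub at the old URL that immediately forwards to the new page.
        # A real implementation would also compute the relative URL between them.
        stub = Path(outdir) / f"{old_link}.html"
        stub.parent.mkdir(parents=True, exist_ok=True)
        stub.write_text(REDIRECT_TEMPLATE.format(new=new_link))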

examples/linear_model/plot_ols.py

Lines changed: 80 additions & 46 deletions
@@ -1,63 +1,97 @@
 """
-=========================================================
-Linear Regression Example
-=========================================================
-The example below uses only the first feature of the `diabetes` dataset,
-in order to illustrate the data points within the two-dimensional plot.
-The straight line can be seen in the plot, showing how linear regression
-attempts to draw a straight line that will best minimize the
-residual sum of squares between the observed responses in the dataset,
-and the responses predicted by the linear approximation.
-
-The coefficients, residual sum of squares and the coefficient of
-determination are also calculated.
+==============================
+Ordinary Least Squares Example
+==============================
 
+This example shows how to use the ordinary least squares (OLS) model
+called :class:`~sklearn.linear_model.LinearRegression` in scikit-learn.
+
+For this purpose, we use a single feature from the diabetes dataset and try to
+predict the diabetes progression using this linear model. We therefore load the
+diabetes dataset and split it into training and test sets.
+
+Then, we fit the model on the training set and evaluate its performance on the test
+set and finally visualize the results on the test set.
 """
 
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-import matplotlib.pyplot as plt
-import numpy as np
-
-from sklearn import datasets, linear_model
+# %%
+# Data Loading and Preparation
+# ----------------------------
+#
+# Load the diabetes dataset. For simplicity, we only keep a single feature in the data.
+# Then, we split the data and target into training and test sets.
+from sklearn.datasets import load_diabetes
+from sklearn.model_selection import train_test_split
+
+X, y = load_diabetes(return_X_y=True)
+X = X[:, [2]]  # Use only one feature
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20, shuffle=False)
+
+# %%
+# Linear regression model
+# -----------------------
+#
+# We create a linear regression model and fit it on the training data. Note that by
+# default, an intercept is added to the model. We can control this behavior by setting
+# the `fit_intercept` parameter.
+from sklearn.linear_model import LinearRegression
+
+regressor = LinearRegression().fit(X_train, y_train)
+
+# %%
+# Model evaluation
+# ----------------
+#
+# We evaluate the model's performance on the test set using the mean squared error
+# and the coefficient of determination.
 from sklearn.metrics import mean_squared_error, r2_score
 
-# Load the diabetes dataset
-diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)
-
-# Use only one feature
-diabetes_X = diabetes_X[:, np.newaxis, 2]
+y_pred = regressor.predict(X_test)
 
-# Split the data into training/testing sets
-diabetes_X_train = diabetes_X[:-20]
-diabetes_X_test = diabetes_X[-20:]
+print(f"Mean squared error: {mean_squared_error(y_test, y_pred):.2f}")
+print(f"Coefficient of determination: {r2_score(y_test, y_pred):.2f}")
 
-# Split the targets into training/testing sets
-diabetes_y_train = diabetes_y[:-20]
-diabetes_y_test = diabetes_y[-20:]
-
-# Create linear regression object
-regr = linear_model.LinearRegression()
-
-# Train the model using the training sets
-regr.fit(diabetes_X_train, diabetes_y_train)
+# %%
+# Plotting the results
+# --------------------
+#
+# Finally, we visualize the results on the train and test data.
+import matplotlib.pyplot as plt
 
-# Make predictions using the testing set
-diabetes_y_pred = regr.predict(diabetes_X_test)
+fig, ax = plt.subplots(ncols=2, figsize=(10, 5), sharex=True, sharey=True)
 
-# The coefficients
-print("Coefficients: \n", regr.coef_)
-# The mean squared error
-print("Mean squared error: %.2f" % mean_squared_error(diabetes_y_test, diabetes_y_pred))
-# The coefficient of determination: 1 is perfect prediction
-print("Coefficient of determination: %.2f" % r2_score(diabetes_y_test, diabetes_y_pred))
+ax[0].scatter(X_train, y_train, label="Train data points")
+ax[0].plot(
+    X_train,
+    regressor.predict(X_train),
+    linewidth=3,
+    color="tab:orange",
+    label="Model predictions",
+)
+ax[0].set(xlabel="Feature", ylabel="Target", title="Train set")
+ax[0].legend()
 
-# Plot outputs
-plt.scatter(diabetes_X_test, diabetes_y_test, color="black")
-plt.plot(diabetes_X_test, diabetes_y_pred, color="blue", linewidth=3)
+ax[1].scatter(X_test, y_test, label="Test data points")
+ax[1].plot(X_test, y_pred, linewidth=3, color="tab:orange", label="Model predictions")
+ax[1].set(xlabel="Feature", ylabel="Target", title="Test set")
+ax[1].legend()
 
-plt.xticks(())
-plt.yticks(())
+fig.suptitle("Linear Regression")
 
 plt.show()
+
+# %%
+# Conclusion
+# ----------
+#
+# The trained model corresponds to the estimator that minimizes the mean squared error
+# between the predicted and the true target values on the training data. We therefore
+# obtain an estimator of the conditional mean of the target given the data.
+#
+# Note that in higher dimensions, minimizing only the squared error might lead to
+# overfitting. Therefore, regularization techniques are commonly used to prevent this
+# issue, such as those implemented in :class:`~sklearn.linear_model.Ridge` or
+# :class:`~sklearn.linear_model.Lasso`.
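
A side note on the evaluation step above: the coefficient of determination reported by `r2_score` is R^2 = 1 - SS_res / SS_tot, i.e. one minus the ratio of the residual sum of squares to the total sum of squares around the mean. A small self-contained check with toy numbers (the values are made up for illustration, not taken from the diabetes data):

# Quick check of what r2_score computes: R^2 = 1 - SS_res / SS_tot.
import numpy as np

y_true = np.array([100.0, 150.0, 200.0, 250.0])
y_pred = np.array([110.0, 140.0, 190.0, 260.0])

ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
print(1 - ss_res / ss_tot)  # 0.968, matching sklearn.metrics.r2_score(y_true, y_pred)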

examples/linear_model/plot_ols_3d.py

Lines changed: 0 additions & 83 deletions
This file was deleted.
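
Since the conclusion of the rewritten example points readers to :class:`~sklearn.linear_model.Ridge` and :class:`~sklearn.linear_model.Lasso`, here is a minimal sketch of swapping in Ridge on the full diabetes feature set, where regularization is more relevant. The `alpha` value is an arbitrary assumption for illustration, not a recommendation:

# Minimal sketch: same dataset and split as the example, but with all features
# and an L2-regularized model. alpha=1.0 is arbitrary; tune it in practice
# (e.g. with RidgeCV or cross-validation).
from sklearn.datasets import load_diabetes
from sklearn.linear_model import Ridge
from sklearn.metrics import r2_score
from sklearn.model_selection import train_test_split

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20, shuffle=False)

ridge = Ridge(alpha=1.0).fit(X_train, y_train)  # alpha controls shrinkage strength
print(f"Ridge R^2 on test set: {r2_score(y_test, ridge.predict(X_test)):.2f}")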
