
Commit 8cac52f (parent: 15f7cfb)

FEA add ValidationCurveDisplay in model_selection module (scikit-learn#25120)

Authored-by: Guillaume Lemaitre <glemaitre>
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>

File tree: 11 files changed (+999, -234 lines)

doc/modules/classes.rst (1 addition, 0 deletions)

@@ -1247,6 +1247,7 @@ Visualization
    :template: display_only_from_estimator.rst

    model_selection.LearningCurveDisplay
+   model_selection.ValidationCurveDisplay

.. _multiclass_ref:

doc/modules/learning_curve.rst (31 additions, 11 deletions)

@@ -71,7 +71,7 @@ The function :func:`validation_curve` can help in this case::
     >>> import numpy as np
     >>> from sklearn.model_selection import validation_curve
     >>> from sklearn.datasets import load_iris
-    >>> from sklearn.linear_model import Ridge
+    >>> from sklearn.svm import SVC

     >>> np.random.seed(0)
     >>> X, y = load_iris(return_X_y=True)

@@ -80,30 +80,50 @@ The function :func:`validation_curve` can help in this case::
     >>> X, y = X[indices], y[indices]

     >>> train_scores, valid_scores = validation_curve(
-    ...     Ridge(), X, y, param_name="alpha", param_range=np.logspace(-7, 3, 3),
-    ...     cv=5)
+    ...     SVC(kernel="linear"), X, y, param_name="C", param_range=np.logspace(-7, 3, 3),
+    ... )
     >>> train_scores
-    array([[0.93..., 0.94..., 0.92..., 0.91..., 0.92...],
-           [0.93..., 0.94..., 0.92..., 0.91..., 0.92...],
-           [0.51..., 0.52..., 0.49..., 0.47..., 0.49...]])
+    array([[0.90..., 0.94..., 0.91..., 0.89..., 0.92...],
+           [0.9... , 0.92..., 0.93..., 0.92..., 0.93...],
+           [0.97..., 1... , 0.98..., 0.97..., 0.99...]])
     >>> valid_scores
-    array([[0.90..., 0.84..., 0.94..., 0.96..., 0.93...],
-           [0.90..., 0.84..., 0.94..., 0.96..., 0.93...],
-           [0.46..., 0.25..., 0.50..., 0.49..., 0.52...]])
+    array([[0.9..., 0.9... , 0.9... , 0.96..., 0.9... ],
+           [0.9..., 0.83..., 0.96..., 0.96..., 0.93...],
+           [1.... , 0.93..., 1.... , 1.... , 0.9... ]])
+
+If you intend to plot the validation curves only, the class
+:class:`~sklearn.model_selection.ValidationCurveDisplay` is more direct than
+using matplotlib manually on the results of a call to :func:`validation_curve`.
+You can use the method
+:meth:`~sklearn.model_selection.ValidationCurveDisplay.from_estimator` similarly
+to :func:`validation_curve` to generate and plot the validation curve:
+
+.. plot::
+   :context: close-figs
+   :align: center
+
+   from sklearn.datasets import load_iris
+   from sklearn.model_selection import ValidationCurveDisplay
+   from sklearn.svm import SVC
+   from sklearn.utils import shuffle
+   X, y = load_iris(return_X_y=True)
+   X, y = shuffle(X, y, random_state=0)
+   ValidationCurveDisplay.from_estimator(
+       SVC(kernel="linear"), X, y, param_name="C", param_range=np.logspace(-7, 3, 10)
+   )

 If the training score and the validation score are both low, the estimator will
 be underfitting. If the training score is high and the validation score is low,
 the estimator is overfitting and otherwise it is working very well. A low
 training score and a high validation score is usually not possible. Underfitting,
 overfitting, and a working model are shown in the plot below where we vary
-the parameter :math:`\gamma` of an SVM on the digits dataset.
+the parameter `gamma` of an SVM with an RBF kernel on the digits dataset.

 .. figure:: ../auto_examples/model_selection/images/sphx_glr_plot_validation_curve_001.png
    :target: ../auto_examples/model_selection/plot_validation_curve.html
    :align: center
    :scale: 50%

-
 .. _learning_curve:

 Learning curve
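The rewritten doctest above can be exercised outside the docs. A minimal, self-contained sketch (same data shuffling and estimator as the snippet, with the default 5-fold CV) confirming the shape of the returned score arrays:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

np.random.seed(0)
X, y = load_iris(return_X_y=True)
indices = np.arange(y.shape[0])
np.random.shuffle(indices)
X, y = X[indices], y[indices]

# One row per tested value of C, one column per CV fold (default cv=5).
train_scores, valid_scores = validation_curve(
    SVC(kernel="linear"), X, y,
    param_name="C", param_range=np.logspace(-7, 3, 3),
)
print(train_scores.shape, valid_scores.shape)  # (3, 5) (3, 5)
```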

doc/visualizations.rst (1 addition, 0 deletions)

@@ -89,3 +89,4 @@ Display Objects
    metrics.PredictionErrorDisplay
    metrics.RocCurveDisplay
    model_selection.LearningCurveDisplay
+   model_selection.ValidationCurveDisplay

doc/whats_new/v1.3.rst (20 additions, 0 deletions)

@@ -51,6 +51,14 @@ random sampling procedures.
   used each time the kernel is called.
   :pr:`26337` by :user:`Yao Xiao <Charlie-XIAO>`.

+Changed displays
+----------------
+
+- |Enhancement| :class:`model_selection.LearningCurveDisplay` displays both the
+  train and test curves by default. You can set `score_type="test"` to keep the
+  past behaviour.
+  :pr:`25120` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 Changes impacting all modules
 -----------------------------

@@ -548,6 +556,18 @@ Changelog
 :mod:`sklearn.model_selection`
 ..............................

+- |MajorFeature| Added the class :class:`model_selection.ValidationCurveDisplay`
+  that allows easy plotting of validation curves obtained by the function
+  :func:`model_selection.validation_curve`.
+  :pr:`25120` by :user:`Guillaume Lemaitre <glemaitre>`.
+
+- |API| The parameter `log_scale` in the class
+  :class:`model_selection.LearningCurveDisplay` has been deprecated in 1.3 and
+  will be removed in 1.5. The default scale can be overridden by setting it
+  directly on the `ax` object and will be set automatically from the spacing
+  of the data points otherwise.
+  :pr:`25120` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 - |Enhancement| :func:`model_selection.cross_validate` accepts a new parameter
   `return_indices` to return the train-test indices of each cv split.
   :pr:`25659` by :user:`Guillaume Lemaitre <glemaitre>`.

examples/miscellaneous/plot_kernel_ridge_regression.py (1 addition, 0 deletions)

@@ -203,6 +203,7 @@
     "scoring": "neg_mean_squared_error",
     "negate_score": True,
     "score_name": "Mean Squared Error",
+    "score_type": "test",
     "std_display_style": None,
     "ax": ax,
 }

examples/model_selection/plot_validation_curve.py (8 additions, 38 deletions)

@@ -18,53 +18,23 @@

 from sklearn.datasets import load_digits
 from sklearn.svm import SVC
-from sklearn.model_selection import validation_curve
+from sklearn.model_selection import ValidationCurveDisplay

 X, y = load_digits(return_X_y=True)
 subset_mask = np.isin(y, [1, 2])  # binary classification: 1 vs 2
 X, y = X[subset_mask], y[subset_mask]

-param_range = np.logspace(-6, -1, 5)
-train_scores, test_scores = validation_curve(
+disp = ValidationCurveDisplay.from_estimator(
     SVC(),
     X,
     y,
     param_name="gamma",
-    param_range=param_range,
-    scoring="accuracy",
+    param_range=np.logspace(-6, -1, 5),
+    score_type="both",
     n_jobs=2,
+    score_name="Accuracy",
 )
-train_scores_mean = np.mean(train_scores, axis=1)
-train_scores_std = np.std(train_scores, axis=1)
-test_scores_mean = np.mean(test_scores, axis=1)
-test_scores_std = np.std(test_scores, axis=1)
-
-plt.title("Validation Curve with SVM")
-plt.xlabel(r"$\gamma$")
-plt.ylabel("Score")
-plt.ylim(0.0, 1.1)
-lw = 2
-plt.semilogx(
-    param_range, train_scores_mean, label="Training score", color="darkorange", lw=lw
-)
-plt.fill_between(
-    param_range,
-    train_scores_mean - train_scores_std,
-    train_scores_mean + train_scores_std,
-    alpha=0.2,
-    color="darkorange",
-    lw=lw,
-)
-plt.semilogx(
-    param_range, test_scores_mean, label="Cross-validation score", color="navy", lw=lw
-)
-plt.fill_between(
-    param_range,
-    test_scores_mean - test_scores_std,
-    test_scores_mean + test_scores_std,
-    alpha=0.2,
-    color="navy",
-    lw=lw,
-)
-plt.legend(loc="best")
+disp.ax_.set_title("Validation Curve for SVM with an RBF kernel")
+disp.ax_.set_xlabel(r"gamma (inverse radius of the RBF kernel)")
+disp.ax_.set_ylim(0.0, 1.1)
 plt.show()

sklearn/model_selection/__init__.py (2 additions, 0 deletions)

@@ -33,6 +33,7 @@
 from ._search import ParameterSampler

 from ._plot import LearningCurveDisplay
+from ._plot import ValidationCurveDisplay

 if typing.TYPE_CHECKING:
     # Avoid errors in type checkers (e.g. mypy) for experimental estimators.

@@ -74,6 +75,7 @@
     "permutation_test_score",
     "train_test_split",
     "validation_curve",
+    "ValidationCurveDisplay",
 ]
