 """
-==================================================
-Plot the decision boundaries of a VotingClassifier
-==================================================
+===============================================================
+Visualizing the probabilistic predictions of a VotingClassifier
+===============================================================

 .. currentmodule:: sklearn

-Plot the decision boundaries of a :class:`~ensemble.VotingClassifier` for two
-features of the Iris dataset.
+Plot the class probabilities predicted on a toy dataset by three different
+classifiers and averaged by the :class:`~ensemble.VotingClassifier`.

-Plot the class probabilities of the first sample in a toy dataset predicted by
-three different classifiers and averaged by the
-:class:`~ensemble.VotingClassifier`.
+First, three linear classifiers are initialized. Two are spline models with
+interaction terms, one using constant extrapolation and the other using periodic
+extrapolation. The third classifier is a :class:`~kernel_approximation.Nystroem`
+with the default "rbf" kernel.

-First, three exemplary classifiers are initialized
-(:class:`~tree.DecisionTreeClassifier`,
-:class:`~neighbors.KNeighborsClassifier`, and :class:`~svm.SVC`) and used to
-initialize a soft-voting :class:`~ensemble.VotingClassifier` with weights `[2,
-1, 2]`, which means that the predicted probabilities of the
-:class:`~tree.DecisionTreeClassifier` and :class:`~svm.SVC` each count 2 times
-as much as the weights of the :class:`~neighbors.KNeighborsClassifier`
-classifier when the averaged probability is calculated.
+In the first part of this example, these three classifiers are used to
+demonstrate soft-voting using :class:`~ensemble.VotingClassifier` with a
+weighted average. We set `weights=[2, 1, 3]`, meaning the constant extrapolation
+spline model's predictions are weighted twice as much as the periodic spline
+model's, and the Nystroem model's predictions are weighted three times as much
+as the periodic spline model's.
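+
+Concretely (with notation introduced here only for illustration), for each
+class the averaged probability is
+`(2 * p_constant + 1 * p_periodic + 3 * p_nystroem) / 6`, where `p_constant`,
+`p_periodic` and `p_nystroem` are the probabilities predicted by the three
+models above.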
+
+The second part demonstrates how soft predictions can be converted into hard
+predictions.

 """

 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause

-from itertools import product
+# %%
+# We first generate a noisy XOR dataset, which is a binary classification task.

 import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from matplotlib.colors import ListedColormap
+
+n_samples = 500
+rng = np.random.default_rng(0)
+feature_names = ["Feature #0", "Feature #1"]
+common_scatter_plot_params = dict(
+    cmap=ListedColormap(["tab:red", "tab:blue"]),
+    edgecolor="white",
+    linewidth=1,
+)
+
+xor = pd.DataFrame(
+    np.random.RandomState(0).uniform(low=-1, high=1, size=(n_samples, 2)),
+    columns=feature_names,
+)
+noise = rng.normal(loc=0, scale=0.1, size=(n_samples, 2))
+target_xor = np.logical_xor(
+    xor["Feature #0"] + noise[:, 0] > 0, xor["Feature #1"] + noise[:, 1] > 0
+)
+
+X = xor[feature_names]
+y = target_xor.astype(np.int32)
+
+fig, ax = plt.subplots()
+ax.scatter(X["Feature #0"], X["Feature #1"], c=y, **common_scatter_plot_params)
+ax.set_title("The XOR dataset")
+plt.show()
+
+# %%
+# Because the XOR dataset is not linearly separable, tree-based models would
+# often be preferred. However, appropriate feature engineering combined with a
+# linear model can yield effective results, with the added benefit of producing
+# better-calibrated probabilities for samples located in the transition regions
+# affected by noise.
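+#
+# As a rough illustration of the first claim (an optional sketch, reusing the
+# shallow decision tree from the replaced version of this example and not part
+# of the comparison below), we can cross-validate such a baseline on this data:
+
+from sklearn.model_selection import cross_val_score
+from sklearn.tree import DecisionTreeClassifier
+
+tree_cv_scores = cross_val_score(
+    DecisionTreeClassifier(max_depth=4, random_state=0), X, y, cv=5
+)
+print(f"Decision tree mean CV accuracy: {tree_cv_scores.mean():.3f}")
+
+# %%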
+#
+# We define and fit the models on the whole dataset.

-from sklearn import datasets
 from sklearn.ensemble import VotingClassifier
-from sklearn.inspection import DecisionBoundaryDisplay
-from sklearn.neighbors import KNeighborsClassifier
-from sklearn.svm import SVC
-from sklearn.tree import DecisionTreeClassifier
-
-# Loading some example data
-iris = datasets.load_iris()
-X = iris.data[:, [0, 2]]
-y = iris.target
-
-# Training classifiers
-clf1 = DecisionTreeClassifier(max_depth=4)
-clf2 = KNeighborsClassifier(n_neighbors=7)
-clf3 = SVC(gamma=0.1, kernel="rbf", probability=True)
+from sklearn.kernel_approximation import Nystroem
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import PolynomialFeatures, SplineTransformer, StandardScaler
+
+clf1 = make_pipeline(
+    SplineTransformer(degree=2, n_knots=2),
+    PolynomialFeatures(interaction_only=True),
+    LogisticRegression(C=10),
+)
+clf2 = make_pipeline(
+    SplineTransformer(
+        degree=2,
+        n_knots=4,
+        extrapolation="periodic",
+        include_bias=True,
+    ),
+    PolynomialFeatures(interaction_only=True),
+    LogisticRegression(C=10),
+)
+clf3 = make_pipeline(
+    StandardScaler(),
+    Nystroem(gamma=2, random_state=0),
+    LogisticRegression(C=10),
+)
+weights = [2, 1, 3]
 eclf = VotingClassifier(
-    estimators=[("dt", clf1), ("knn", clf2), ("svc", clf3)],
+    estimators=[
+        ("constant splines model", clf1),
+        ("periodic splines model", clf2),
+        ("nystroem model", clf3),
+    ],
     voting="soft",
-    weights=[2, 1, 2],
+    weights=weights,
 )

 clf1.fit(X, y)
 clf2.fit(X, y)
 clf3.fit(X, y)
 eclf.fit(X, y)

-# Plotting decision regions
-f, axarr = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(10, 8))
-for idx, clf, tt in zip(
+# %%
+# Finally we use :class:`~inspection.DecisionBoundaryDisplay` to plot the
+# predicted probabilities. By using a diverging colormap (such as `"RdBu"`), we
+# can ensure that darker colors correspond to `predict_proba` close to either 0
+# or 1, and white corresponds to `predict_proba` of 0.5.
+
+from itertools import product
+
+from sklearn.inspection import DecisionBoundaryDisplay
+
+fig, axarr = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(10, 8))
+for idx, clf, title in zip(
     product([0, 1], [0, 1]),
     [clf1, clf2, clf3, eclf],
-    ["Decision Tree (depth=4)", "KNN (k=7)", "Kernel SVM", "Soft Voting"],
+    [
+        "Splines with\nconstant extrapolation",
+        "Splines with\nperiodic extrapolation",
+        "RBF Nystroem",
+        "Soft Voting",
+    ],
 ):
-    DecisionBoundaryDisplay.from_estimator(
-        clf, X, alpha=0.4, ax=axarr[idx[0], idx[1]], response_method="predict"
+    disp = DecisionBoundaryDisplay.from_estimator(
+        clf,
+        X,
+        response_method="predict_proba",
+        plot_method="pcolormesh",
+        cmap="RdBu",
+        alpha=0.8,
+        ax=axarr[idx[0], idx[1]],
+    )
+    axarr[idx[0], idx[1]].scatter(
+        X["Feature #0"],
+        X["Feature #1"],
+        c=y,
+        **common_scatter_plot_params,
     )
-    axarr[idx[0], idx[1]].scatter(X[:, 0], X[:, 1], c=y, s=20, edgecolor="k")
-    axarr[idx[0], idx[1]].set_title(tt)
+    axarr[idx[0], idx[1]].set_title(title)
+    fig.colorbar(disp.surface_, ax=axarr[idx[0], idx[1]], label="Probability estimate")

 plt.show()
+
+# %%
+# As a sanity check, we can verify for a given sample that the probability
+# predicted by the :class:`~ensemble.VotingClassifier` is indeed the weighted
+# average of the individual classifiers' soft-predictions.
+#
+# In the case of binary classification such as in the present example, the
+# :term:`predict_proba` arrays contain the probability of belonging to class 0
+# (here in red) as the first entry, and the probability of belonging to class 1
+# (here in blue) as the second entry.
+
+test_sample = pd.DataFrame({"Feature #0": [-0.5], "Feature #1": [1.5]})
+predict_probas = [est.predict_proba(test_sample).ravel() for est in eclf.estimators_]
+for (est_name, _), est_probas in zip(eclf.estimators, predict_probas):
+    print(f"{est_name}'s predicted probabilities: {est_probas}")
+
+# %%
+print(
+    "Weighted average of soft-predictions: "
+    f"{np.dot(weights, predict_probas) / np.sum(weights)}"
+)
+
+# %%
+# We can see that the manual calculation of predicted probabilities above
+# matches the one produced by the `VotingClassifier`:
+
+print(
+    "Predicted probability of VotingClassifier: "
+    f"{eclf.predict_proba(test_sample).ravel()}"
+)
+
+# %%
+# To convert soft predictions into hard predictions when weights are provided,
+# the weighted average predicted probabilities are computed for each class. The
+# final class label is then the one with the highest average probability, which
+# corresponds to the default threshold of `predict_proba=0.5` in the case of
+# binary classification.
+
+print(
+    "Class with the highest weighted average of soft-predictions: "
+    f"{np.argmax(np.dot(weights, predict_probas) / np.sum(weights))}"
+)
+
+# %%
+# This is equivalent to the output of `VotingClassifier`'s `predict` method:
+
+print(f"Predicted class of VotingClassifier: {eclf.predict(test_sample).ravel()}")
+
+# %%
+# Soft votes can be thresholded as for any other probabilistic classifier. This
+# allows you to set a threshold probability at which the positive class will be
+# predicted, instead of simply selecting the class with the highest predicted
+# probability.
+
+from sklearn.model_selection import FixedThresholdClassifier
+
+eclf_other_threshold = FixedThresholdClassifier(
+    eclf, threshold=0.7, response_method="predict_proba"
+).fit(X, y)
+print(
+    "Predicted class of thresholded VotingClassifier: "
+    f"{eclf_other_threshold.predict(test_sample)}"
+)
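+
+# %%
+# As a short usage note (a sketch reusing the objects fitted above), printing
+# the averaged probability of the positive class next to both hard predictions
+# makes the effect of raising the threshold from the default 0.5 to 0.7
+# explicit: the positive class is only predicted when its averaged probability
+# clears the stricter cut-off.
+
+proba_positive_class = eclf.predict_proba(test_sample)[:, 1]
+print(f"Averaged probability of the positive class: {proba_positive_class}")
+print(f"Prediction with the default 0.5 threshold: {eclf.predict(test_sample)}")
+print(
+    "Prediction with the 0.7 threshold: "
+    f"{eclf_other_threshold.predict(test_sample)}"
+)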