Commit 1527b1f

Authored by ArturoAmorQ, ogrisel, and lucyleeow
DOC Rework voting classifier example (scikit-learn#30985)
Co-authored-by: ArturoAmorQ <arturo.amor-quiroz@polytechnique.edu>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Lucy Liu <jliu176@gmail.com>
1 parent 4985e69 commit 1527b1f

File tree

4 files changed: +199 -171 lines changed

doc/conf.py

Lines changed: 3 additions & 0 deletions

@@ -491,6 +491,9 @@ def add_js_css_files(app, pagename, templatename, context, doctree):
     "auto_examples/ensemble/plot_forest_importances_faces": (
         "auto_examples/ensemble/plot_forest_importances"
     ),
+    "auto_examples/ensemble/plot_voting_probas": (
+        "auto_examples/ensemble/plot_voting_decision_regions"
+    ),
     "auto_examples/datasets/plot_iris_dataset": (
         "auto_examples/decomposition/plot_pca_iris"
     ),
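Note: judging by its shape, this hunk extends a redirects mapping in doc/conf.py so that links to the retired plot_voting_probas example land on the reworked plot_voting_decision_regions example. Below is a minimal sketch of how such a mapping could be materialized as HTML redirect pages; this is an illustrative scheme only, and the `redirects` name, the `_build/html` output directory, and the meta-refresh markup are assumptions, not scikit-learn's actual doc tooling.

from pathlib import Path

# Hypothetical mapping, mirroring the dict entries added in the hunk above.
redirects = {
    "auto_examples/ensemble/plot_voting_probas": (
        "auto_examples/ensemble/plot_voting_decision_regions"
    ),
}

for old, new in redirects.items():
    out = Path("_build/html") / f"{old}.html"
    out.parent.mkdir(parents=True, exist_ok=True)
    # Write a meta-refresh page so the old URL forwards to the new one.
    out.write_text(
        "<html><head>"
        f'<meta http-equiv="refresh" content="0; url=/{new}.html">'
        "</head></html>"
    )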

doc/modules/ensemble.rst

Lines changed: 9 additions & 32 deletions

@@ -1410,40 +1410,17 @@ classifier 3 w3 * 0.3 w3 * 0.4 w3 * 0.3
 weighted average     0.37       0.4        0.23
 ================ ========== ========== ==========
 
-Here, the predicted class label is 2, since it has the highest average probability. See
-this example on :ref:`Visualising class probabilities in a Voting Classifier
-<sphx_glr_auto_examples_ensemble_plot_voting_probas.py>` for a detailed illustration of
-class probabilities averaged by soft voting.
+Here, the predicted class label is 2, since it has the highest average
+predicted probability. See the example on
+:ref:`sphx_glr_auto_examples_ensemble_plot_voting_decision_regions.py` for a
+demonstration of how the predicted class label can be obtained from the weighted
+average of predicted probabilities.
 
-Also, the following example illustrates how the decision regions may change
-when a soft :class:`VotingClassifier` is used based on a linear Support
-Vector Machine, a Decision Tree, and a K-nearest neighbor classifier::
+The following figure illustrates how the decision regions may change when
+a soft :class:`VotingClassifier` is trained with weights on three linear
+models:
 
-    >>> from sklearn import datasets
-    >>> from sklearn.tree import DecisionTreeClassifier
-    >>> from sklearn.neighbors import KNeighborsClassifier
-    >>> from sklearn.svm import SVC
-    >>> from itertools import product
-    >>> from sklearn.ensemble import VotingClassifier
-
-    >>> # Loading some example data
-    >>> iris = datasets.load_iris()
-    >>> X = iris.data[:, [0, 2]]
-    >>> y = iris.target
-
-    >>> # Training classifiers
-    >>> clf1 = DecisionTreeClassifier(max_depth=4)
-    >>> clf2 = KNeighborsClassifier(n_neighbors=7)
-    >>> clf3 = SVC(kernel='rbf', probability=True)
-    >>> eclf = VotingClassifier(estimators=[('dt', clf1), ('knn', clf2), ('svc', clf3)],
-    ...                         voting='soft', weights=[2, 1, 2])
-
-    >>> clf1 = clf1.fit(X, y)
-    >>> clf2 = clf2.fit(X, y)
-    >>> clf3 = clf3.fit(X, y)
-    >>> eclf = eclf.fit(X, y)
-
-.. figure:: ../auto_examples/ensemble/images/sphx_glr_plot_voting_decision_regions_001.png
+.. figure:: ../auto_examples/ensemble/images/sphx_glr_plot_voting_decision_regions_002.png
    :target: ../auto_examples/ensemble/plot_voting_decision_regions.html
    :align: center
    :scale: 75%
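Note: the hunk context above shows only the last rows of the user guide's soft-voting table. As a minimal sketch of the arithmetic it describes, the snippet below reproduces the "weighted average 0.37 0.4 0.23" row and the resulting class label; the per-classifier probabilities and the equal weights are assumptions chosen to be consistent with the visible rows, not values taken from this diff.

import numpy as np

# Per-classifier predicted probabilities for class 1, class 2 and class 3.
# Assumed values, consistent with the "classifier 3" and "weighted average"
# rows visible in the hunk context above.
probas = np.array(
    [
        [0.2, 0.5, 0.3],  # classifier 1
        [0.6, 0.3, 0.1],  # classifier 2
        [0.3, 0.4, 0.3],  # classifier 3
    ]
)
weights = [1, 1, 1]  # assumption: equal weights w1 = w2 = w3 = 1

avg = np.average(probas, axis=0, weights=weights)
print(avg.round(2))  # [0.37 0.4  0.23]
print(f"predicted class: {np.argmax(avg) + 1}")  # class 2, the highest average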
examples/ensemble/plot_voting_decision_regions.py

Lines changed: 187 additions & 42 deletions
@@ -1,73 +1,218 @@
 """
-==================================================
-Plot the decision boundaries of a VotingClassifier
-==================================================
+===============================================================
+Visualizing the probabilistic predictions of a VotingClassifier
+===============================================================
 
 .. currentmodule:: sklearn
 
-Plot the decision boundaries of a :class:`~ensemble.VotingClassifier` for two
-features of the Iris dataset.
+Plot the predicted class probabilities in a toy dataset predicted by three
+different classifiers and averaged by the :class:`~ensemble.VotingClassifier`.
 
-Plot the class probabilities of the first sample in a toy dataset predicted by
-three different classifiers and averaged by the
-:class:`~ensemble.VotingClassifier`.
+First, three linear classifiers are initialized. Two are spline models with
+interaction terms, one using constant extrapolation and the other using periodic
+extrapolation. The third classifier is a :class:`~kernel_approximation.Nystroem`
+with the default "rbf" kernel.
 
-First, three exemplary classifiers are initialized
-(:class:`~tree.DecisionTreeClassifier`,
-:class:`~neighbors.KNeighborsClassifier`, and :class:`~svm.SVC`) and used to
-initialize a soft-voting :class:`~ensemble.VotingClassifier` with weights `[2,
-1, 2]`, which means that the predicted probabilities of the
-:class:`~tree.DecisionTreeClassifier` and :class:`~svm.SVC` each count 2 times
-as much as the weights of the :class:`~neighbors.KNeighborsClassifier`
-classifier when the averaged probability is calculated.
+In the first part of this example, these three classifiers are used to
+demonstrate soft-voting using :class:`~ensemble.VotingClassifier` with weighted
+average. We set `weights=[2, 1, 3]`, meaning the constant extrapolation spline
+model's predictions are weighted twice as much as the periodic spline model's,
+and the Nystroem model's predictions are weighted three times as much as the
+periodic spline.
+
+The second part demonstrates how soft predictions can be converted into hard
+predictions.
 
 """
 
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
-from itertools import product
+# %%
+# We first generate a noisy XOR dataset, which is a binary classification task.
 
 import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from matplotlib.colors import ListedColormap
+
+n_samples = 500
+rng = np.random.default_rng(0)
+feature_names = ["Feature #0", "Feature #1"]
+common_scatter_plot_params = dict(
+    cmap=ListedColormap(["tab:red", "tab:blue"]),
+    edgecolor="white",
+    linewidth=1,
+)
+
+xor = pd.DataFrame(
+    np.random.RandomState(0).uniform(low=-1, high=1, size=(n_samples, 2)),
+    columns=feature_names,
+)
+noise = rng.normal(loc=0, scale=0.1, size=(n_samples, 2))
+target_xor = np.logical_xor(
+    xor["Feature #0"] + noise[:, 0] > 0, xor["Feature #1"] + noise[:, 1] > 0
+)
+
+X = xor[feature_names]
+y = target_xor.astype(np.int32)
+
+fig, ax = plt.subplots()
+ax.scatter(X["Feature #0"], X["Feature #1"], c=y, **common_scatter_plot_params)
+ax.set_title("The XOR dataset")
+plt.show()
+
+# %%
+# Due to the inherent non-linear separability of the XOR dataset, tree-based
+# models would often be preferred. However, appropriate feature engineering
+# combined with a linear model can yield effective results, with the added
+# benefit of producing better-calibrated probabilities for samples located in
+# the transition regions affected by noise.
+#
+# We define and fit the models on the whole dataset.
 
-from sklearn import datasets
 from sklearn.ensemble import VotingClassifier
-from sklearn.inspection import DecisionBoundaryDisplay
-from sklearn.neighbors import KNeighborsClassifier
-from sklearn.svm import SVC
-from sklearn.tree import DecisionTreeClassifier
-
-# Loading some example data
-iris = datasets.load_iris()
-X = iris.data[:, [0, 2]]
-y = iris.target
-
-# Training classifiers
-clf1 = DecisionTreeClassifier(max_depth=4)
-clf2 = KNeighborsClassifier(n_neighbors=7)
-clf3 = SVC(gamma=0.1, kernel="rbf", probability=True)
+from sklearn.kernel_approximation import Nystroem
+from sklearn.linear_model import LogisticRegression
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import PolynomialFeatures, SplineTransformer, StandardScaler
+
+clf1 = make_pipeline(
+    SplineTransformer(degree=2, n_knots=2),
+    PolynomialFeatures(interaction_only=True),
+    LogisticRegression(C=10),
+)
+clf2 = make_pipeline(
+    SplineTransformer(
+        degree=2,
+        n_knots=4,
+        extrapolation="periodic",
+        include_bias=True,
+    ),
+    PolynomialFeatures(interaction_only=True),
+    LogisticRegression(C=10),
+)
+clf3 = make_pipeline(
+    StandardScaler(),
+    Nystroem(gamma=2, random_state=0),
+    LogisticRegression(C=10),
+)
+weights = [2, 1, 3]
 eclf = VotingClassifier(
-    estimators=[("dt", clf1), ("knn", clf2), ("svc", clf3)],
+    estimators=[
+        ("constant splines model", clf1),
+        ("periodic splines model", clf2),
+        ("nystroem model", clf3),
+    ],
     voting="soft",
-    weights=[2, 1, 2],
+    weights=weights,
 )
 
 clf1.fit(X, y)
 clf2.fit(X, y)
 clf3.fit(X, y)
 eclf.fit(X, y)
 
-# Plotting decision regions
-f, axarr = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(10, 8))
-for idx, clf, tt in zip(
+# %%
+# Finally we use :class:`~inspection.DecisionBoundaryDisplay` to plot the
+# predicted probabilities. By using a diverging colormap (such as `"RdBu"`), we
+# can ensure that darker colors correspond to `predict_proba` close to either 0
+# or 1, and white corresponds to `predict_proba` of 0.5.
+
+from itertools import product
+
+from sklearn.inspection import DecisionBoundaryDisplay
+
+fig, axarr = plt.subplots(2, 2, sharex="col", sharey="row", figsize=(10, 8))
+for idx, clf, title in zip(
     product([0, 1], [0, 1]),
     [clf1, clf2, clf3, eclf],
-    ["Decision Tree (depth=4)", "KNN (k=7)", "Kernel SVM", "Soft Voting"],
+    [
+        "Splines with\nconstant extrapolation",
+        "Splines with\nperiodic extrapolation",
+        "RBF Nystroem",
+        "Soft Voting",
+    ],
 ):
-    DecisionBoundaryDisplay.from_estimator(
-        clf, X, alpha=0.4, ax=axarr[idx[0], idx[1]], response_method="predict"
+    disp = DecisionBoundaryDisplay.from_estimator(
+        clf,
+        X,
+        response_method="predict_proba",
+        plot_method="pcolormesh",
+        cmap="RdBu",
+        alpha=0.8,
+        ax=axarr[idx[0], idx[1]],
+    )
+    axarr[idx[0], idx[1]].scatter(
+        X["Feature #0"],
+        X["Feature #1"],
+        c=y,
+        **common_scatter_plot_params,
     )
-    axarr[idx[0], idx[1]].scatter(X[:, 0], X[:, 1], c=y, s=20, edgecolor="k")
-    axarr[idx[0], idx[1]].set_title(tt)
+    axarr[idx[0], idx[1]].set_title(title)
+    fig.colorbar(disp.surface_, ax=axarr[idx[0], idx[1]], label="Probability estimate")
 
 plt.show()
+
+# %%
+# As a sanity check, we can verify for a given sample that the probability
+# predicted by the :class:`~ensemble.VotingClassifier` is indeed the weighted
+# average of the individual classifiers' soft-predictions.
+#
+# In the case of binary classification such as in the present example, the
+# :term:`predict_proba` arrays contain the probability of belonging to class 0
+# (here in red) as the first entry, and the probability of belonging to class 1
+# (here in blue) as the second entry.
+
+test_sample = pd.DataFrame({"Feature #0": [-0.5], "Feature #1": [1.5]})
+predict_probas = [est.predict_proba(test_sample).ravel() for est in eclf.estimators_]
+for (est_name, _), est_probas in zip(eclf.estimators, predict_probas):
+    print(f"{est_name}'s predicted probabilities: {est_probas}")
+
+# %%
+print(
+    "Weighted average of soft-predictions: "
+    f"{np.dot(weights, predict_probas) / np.sum(weights)}"
+)
+
+# %%
+# We can see that the manual calculation of predicted probabilities above is
+# equivalent to that produced by the `VotingClassifier`:
+
+print(
+    "Predicted probability of VotingClassifier: "
+    f"{eclf.predict_proba(test_sample).ravel()}"
+)
+
+# %%
+# To convert soft predictions into hard predictions when weights are provided,
+# the weighted average predicted probabilities are computed for each class.
+# Then, the final class label is derived from the class label with the
+# highest average probability, which corresponds to the default threshold at
+# `predict_proba=0.5` in the case of binary classification.
+
+print(
+    "Class with the highest weighted average of soft-predictions: "
+    f"{np.argmax(np.dot(weights, predict_probas) / np.sum(weights))}"
+)
+
+# %%
+# This is equivalent to the output of `VotingClassifier`'s `predict` method:
+
+print(f"Predicted class of VotingClassifier: {eclf.predict(test_sample).ravel()}")
+
+# %%
+# Soft votes can be thresholded as for any other probabilistic classifier. This
+# allows you to set a threshold probability at which the positive class will be
+# predicted, instead of simply selecting the class with the highest predicted
+# probability.
+
+from sklearn.model_selection import FixedThresholdClassifier
+
+eclf_other_threshold = FixedThresholdClassifier(
+    eclf, threshold=0.7, response_method="predict_proba"
+).fit(X, y)
+print(
+    "Predicted class of thresholded VotingClassifier: "
+    f"{eclf_other_threshold.predict(test_sample)}"
+)
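Note: the `FixedThresholdClassifier` step at the end of the new example is, in effect, a manual threshold on the positive-class column of `predict_proba`. Below is a minimal sketch of that relationship, reusing `eclf`, `eclf_other_threshold`, `X`, and `test_sample` from the example code above; how scores exactly equal to the threshold are classified is an assumption here.

import numpy as np

# Manually threshold the positive-class probability at 0.7.
manual_pred = (eclf.predict_proba(test_sample)[:, 1] >= 0.7).astype(int)
print(f"Manually thresholded prediction: {manual_pred}")

# Compare against the thresholded ensemble over the training set; the
# agreement is expected to be 1.0 (up to tie handling at exactly 0.7).
manual_all = (eclf.predict_proba(X)[:, 1] >= 0.7).astype(int)
print((manual_all == eclf_other_threshold.predict(X)).mean())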
