Skip to content

Commit 88283ee

Browse files
lucyleeowogrisel
andauthored
ENH Allows plotting max class for multiclass in DecisionBoundaryDisplay (scikit-learn#29797)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 2c0cdd4 commit 88283ee

File tree

4 files changed

+479
-109
lines changed

4 files changed

+479
-109
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :class:`inspection.DecisionBoundaryDisplay` now supports
2+
plotting all classes for multi-class problems when `response_method` is
3+
'decision_function', 'predict_proba' or 'auto'.
4+
By :user:`Lucy Liu <lucyleeow>`

examples/classification/plot_classification_probability.py

Lines changed: 181 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,93 +3,239 @@
33
Plot classification probability
44
===============================
55
6-
Plot the classification probability for different classifiers. We use a 3 class
7-
dataset, and we classify it with a Support Vector classifier, L1 and L2
8-
penalized logistic regression (multinomial multiclass), a One-Vs-Rest version with
9-
logistic regression, and Gaussian process classification.
6+
This example illustrates the use of
7+
:class:`sklearn.inspection.DecisionBoundaryDisplay` to plot the predicted class
8+
probabilities of various classifiers in a 2D feature space, mostly for didactic
9+
purposes.
1010
11-
Linear SVC is not a probabilistic classifier by default but it has a built-in
12-
calibration option enabled in this example (`probability=True`).
13-
14-
The logistic regression with One-Vs-Rest is not a multiclass classifier out of
15-
the box. As a result it has more trouble in separating class 2 and 3 than the
16-
other estimators.
11+
The first three columns shows the predicted probability for varying values of
12+
the two features. Round markers represent the test data that was predicted to
13+
belong to that class.
1714
15+
In the last column, all three classes are represented on each plot; the class
16+
with the highest predicted probability at each point is plotted. The round
17+
markers show the test data and are colored by their true label.
1818
"""
1919

20+
# %%
2021
# Authors: The scikit-learn developers
2122
# SPDX-License-Identifier: BSD-3-Clause
2223

24+
import matplotlib as mpl
2325
import matplotlib.pyplot as plt
2426
import numpy as np
27+
import pandas as pd
2528
from matplotlib import cm
2629

2730
from sklearn import datasets
31+
from sklearn.ensemble import HistGradientBoostingClassifier
2832
from sklearn.gaussian_process import GaussianProcessClassifier
2933
from sklearn.gaussian_process.kernels import RBF
3034
from sklearn.inspection import DecisionBoundaryDisplay
35+
from sklearn.kernel_approximation import Nystroem
3136
from sklearn.linear_model import LogisticRegression
32-
from sklearn.metrics import accuracy_score
33-
from sklearn.multiclass import OneVsRestClassifier
34-
from sklearn.svm import SVC
37+
from sklearn.metrics import accuracy_score, log_loss, roc_auc_score
38+
from sklearn.model_selection import train_test_split
39+
from sklearn.pipeline import make_pipeline
40+
from sklearn.preprocessing import (
41+
KBinsDiscretizer,
42+
PolynomialFeatures,
43+
SplineTransformer,
44+
)
3545

46+
# %%
47+
# Data: 2D projection of the iris dataset
48+
# ---------------------------------------
3649
iris = datasets.load_iris()
3750
X = iris.data[:, 0:2] # we only take the first two features for visualization
3851
y = iris.target
3952

40-
n_features = X.shape[1]
53+
X_train, X_test, y_train, y_test = train_test_split(
54+
X, y, test_size=0.5, random_state=42
55+
)
56+
4157

42-
C = 10
43-
kernel = 1.0 * RBF([1.0, 1.0]) # for GPC
58+
# %%
59+
# Probabilistic classifiers
60+
# -------------------------
61+
#
62+
# We will plot the decision boundaries of several classifiers that have a
63+
# `predict_proba` method. This will allow us to visualize the uncertainty of
64+
# the classifier in regions where it is not certain of its prediction.
4465

45-
# Create different classifiers.
4666
classifiers = {
47-
"L1 logistic": LogisticRegression(C=C, penalty="l1", solver="saga", max_iter=10000),
48-
"L2 logistic (Multinomial)": LogisticRegression(
49-
C=C, penalty="l2", solver="saga", max_iter=10000
67+
"Logistic regression\n(C=0.01)": LogisticRegression(C=0.1),
68+
"Logistic regression\n(C=1)": LogisticRegression(C=100),
69+
"Gaussian Process": GaussianProcessClassifier(kernel=1.0 * RBF([1.0, 1.0])),
70+
"Logistic regression\n(RBF features)": make_pipeline(
71+
Nystroem(kernel="rbf", gamma=5e-1, n_components=50, random_state=1),
72+
LogisticRegression(C=10),
5073
),
51-
"L2 logistic (OvR)": OneVsRestClassifier(
52-
LogisticRegression(C=C, penalty="l2", solver="saga", max_iter=10000)
74+
"Gradient Boosting": HistGradientBoostingClassifier(),
75+
"Logistic regression\n(binned features)": make_pipeline(
76+
KBinsDiscretizer(n_bins=5, quantile_method="averaged_inverted_cdf"),
77+
PolynomialFeatures(interaction_only=True),
78+
LogisticRegression(C=10),
79+
),
80+
"Logistic regression\n(spline features)": make_pipeline(
81+
SplineTransformer(n_knots=5),
82+
PolynomialFeatures(interaction_only=True),
83+
LogisticRegression(C=10),
5384
),
54-
"Linear SVC": SVC(kernel="linear", C=C, probability=True, random_state=0),
55-
"GPC": GaussianProcessClassifier(kernel),
5685
}
5786

87+
# %%
88+
# Plotting the decision boundaries
89+
# --------------------------------
90+
#
91+
# For each classifier, we plot the per-class probabilities on the first three
92+
# columns and the probabilities of the most likely class on the last column.
93+
5894
n_classifiers = len(classifiers)
95+
scatter_kwargs = {
96+
"s": 25,
97+
"marker": "o",
98+
"linewidths": 0.8,
99+
"edgecolor": "k",
100+
"alpha": 0.7,
101+
}
102+
y_unique = np.unique(y)
59103

104+
# Ensure legend not cut off
105+
mpl.rcParams["savefig.bbox"] = "tight"
60106
fig, axes = plt.subplots(
61107
nrows=n_classifiers,
62-
ncols=len(iris.target_names),
63-
figsize=(3 * 2, n_classifiers * 2),
108+
ncols=len(iris.target_names) + 1,
109+
figsize=(4 * 2.2, n_classifiers * 2.2),
64110
)
111+
evaluation_results = []
112+
levels = 100
65113
for classifier_idx, (name, classifier) in enumerate(classifiers.items()):
66-
y_pred = classifier.fit(X, y).predict(X)
67-
accuracy = accuracy_score(y, y_pred)
68-
print(f"Accuracy (train) for {name}: {accuracy:0.1%}")
69-
for label in np.unique(y):
114+
y_pred = classifier.fit(X_train, y_train).predict(X_test)
115+
y_pred_proba = classifier.predict_proba(X_test)
116+
accuracy_test = accuracy_score(y_test, y_pred)
117+
roc_auc_test = roc_auc_score(y_test, y_pred_proba, multi_class="ovr")
118+
log_loss_test = log_loss(y_test, y_pred_proba)
119+
evaluation_results.append(
120+
{
121+
"name": name.replace("\n", " "),
122+
"accuracy": accuracy_test,
123+
"roc_auc": roc_auc_test,
124+
"log_loss": log_loss_test,
125+
}
126+
)
127+
for label in y_unique:
70128
# plot the probability estimate provided by the classifier
71129
disp = DecisionBoundaryDisplay.from_estimator(
72130
classifier,
73-
X,
131+
X_train,
74132
response_method="predict_proba",
75133
class_of_interest=label,
76134
ax=axes[classifier_idx, label],
77135
vmin=0,
78136
vmax=1,
137+
cmap="Blues",
138+
levels=levels,
79139
)
80140
axes[classifier_idx, label].set_title(f"Class {label}")
81141
# plot data predicted to belong to given class
82142
mask_y_pred = y_pred == label
83143
axes[classifier_idx, label].scatter(
84-
X[mask_y_pred, 0], X[mask_y_pred, 1], marker="o", c="w", edgecolor="k"
144+
X_test[mask_y_pred, 0], X_test[mask_y_pred, 1], c="w", **scatter_kwargs
85145
)
146+
86147
axes[classifier_idx, label].set(xticks=(), yticks=())
148+
# add column that shows all classes by plotting class with max 'predict_proba'
149+
max_class_disp = DecisionBoundaryDisplay.from_estimator(
150+
classifier,
151+
X_train,
152+
response_method="predict_proba",
153+
class_of_interest=None,
154+
ax=axes[classifier_idx, len(y_unique)],
155+
vmin=0,
156+
vmax=1,
157+
levels=levels,
158+
)
159+
for label in y_unique:
160+
mask_label = y_test == label
161+
axes[classifier_idx, 3].scatter(
162+
X_test[mask_label, 0],
163+
X_test[mask_label, 1],
164+
c=max_class_disp.multiclass_colors_[[label], :],
165+
**scatter_kwargs,
166+
)
167+
168+
axes[classifier_idx, 3].set(xticks=(), yticks=())
169+
axes[classifier_idx, 3].set_title("Max class")
87170
axes[classifier_idx, 0].set_ylabel(name)
88171

89-
ax = plt.axes([0.15, 0.04, 0.7, 0.02])
172+
# colorbar for single class plots
173+
ax_single = fig.add_axes([0.15, 0.01, 0.5, 0.02])
90174
plt.title("Probability")
91175
_ = plt.colorbar(
92-
cm.ScalarMappable(norm=None, cmap="viridis"), cax=ax, orientation="horizontal"
176+
cm.ScalarMappable(norm=None, cmap=disp.surface_.cmap),
177+
cax=ax_single,
178+
orientation="horizontal",
93179
)
94180

95-
plt.show()
181+
# colorbars for max probability class column
182+
max_class_cmaps = [s.cmap for s in max_class_disp.surface_]
183+
184+
for label in y_unique:
185+
ax_max = fig.add_axes([0.73, (0.06 - (label * 0.04)), 0.16, 0.015])
186+
plt.title(f"Probability class {label}", fontsize=10)
187+
_ = plt.colorbar(
188+
cm.ScalarMappable(norm=None, cmap=max_class_cmaps[label]),
189+
cax=ax_max,
190+
orientation="horizontal",
191+
)
192+
if label in (0, 1):
193+
ax_max.set(xticks=(), yticks=())
194+
195+
196+
# %%
197+
# Quantitative evaluation
198+
# -----------------------
199+
pd.DataFrame(evaluation_results).round(2)
200+
201+
202+
# %%
203+
# Analysis
204+
# --------
205+
#
206+
# The two logistic regression models fitted on the original features display
207+
# linear decision boundaries as expected. For this particular problem, this
208+
# does not seem to be detrimental as both models are competitive with the
209+
# non-linear models when quantitatively evaluated on the test set. We can
210+
# observe that the amount of regularization influences the model confidence:
211+
# lighter colors for the strongly regularized model with a lower value of `C`.
212+
# Regularization also impacts the orientation of decision boundary leading to
213+
# slightly different ROC AUC.
214+
#
215+
# The log-loss on the other hand evaluates both sharpness and calibration and
216+
# as a result strongly favors the weakly regularized logistic-regression model,
217+
# probably because the strongly regularized model is under-confident. This
218+
# could be confirmed by looking at the calibration curve using
219+
# :class:`sklearn.calibration.CalibrationDisplay`.
220+
#
221+
# The logistic regression model with RBF features has a "blobby" decision
222+
# boundary that is non-linear in the original feature space and is quite
223+
# similar to the decision boundary of the Gaussian process classifier which is
224+
# configured to use an RBF kernel.
225+
#
226+
# The logistic regression model fitted on binned features with interactions has
227+
# a decision boundary that is non-linear in the original feature space and is
228+
# quite similar to the decision boundary of the gradient boosting classifier:
229+
# both models favor axis-aligned decisions when extrapolating to unseen region
230+
# of the feature space.
231+
#
232+
# The logistic regression model fitted on spline features with interactions
233+
# has a similar axis-aligned extrapolation behavior but a smoother decision
234+
# boundary in the dense region of the feature space than the two previous
235+
# models.
236+
#
237+
# To conclude, it is interesting to observe that feature engineering for
238+
# logistic regression models can be used to mimic some of the inductive bias of
239+
# various non-linear models. However, for this particular dataset, using the
240+
# raw features is enough to train a competitive model. This would not
241+
# necessarily the case for other datasets.

0 commit comments

Comments
 (0)