 Plot classification probability
 ===============================
 
-Plot the classification probability for different classifiers. We use a 3 class
-dataset, and we classify it with a Support Vector classifier, L1 and L2
-penalized logistic regression (multinomial multiclass), a One-Vs-Rest version with
-logistic regression, and Gaussian process classification.
+This example illustrates the use of
+:class:`sklearn.inspection.DecisionBoundaryDisplay` to plot the predicted class
+probabilities of various classifiers in a 2D feature space, mostly for didactic
+purposes.
 
-Linear SVC is not a probabilistic classifier by default but it has a built-in
-calibration option enabled in this example (`probability=True`).
-
-The logistic regression with One-Vs-Rest is not a multiclass classifier out of
-the box. As a result it has more trouble in separating class 2 and 3 than the
-other estimators.
+The first three columns show the predicted probability for varying values of
+the two features. Round markers represent the test data that was predicted to
+belong to that class.
 
+In the last column, all three classes are represented on each plot; the class
+with the highest predicted probability at each point is plotted. The round
+markers show the test data and are colored by their true label.
 """
 
+# %%
 # Authors: The scikit-learn developers
 # SPDX-License-Identifier: BSD-3-Clause
 
+import matplotlib as mpl
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 from matplotlib import cm
 
 from sklearn import datasets
+from sklearn.ensemble import HistGradientBoostingClassifier
 from sklearn.gaussian_process import GaussianProcessClassifier
 from sklearn.gaussian_process.kernels import RBF
 from sklearn.inspection import DecisionBoundaryDisplay
+from sklearn.kernel_approximation import Nystroem
 from sklearn.linear_model import LogisticRegression
-from sklearn.metrics import accuracy_score
-from sklearn.multiclass import OneVsRestClassifier
-from sklearn.svm import SVC
+from sklearn.metrics import accuracy_score, log_loss, roc_auc_score
+from sklearn.model_selection import train_test_split
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import (
+    KBinsDiscretizer,
+    PolynomialFeatures,
+    SplineTransformer,
+)
 
+# %%
+# Data: 2D projection of the iris dataset
+# ---------------------------------------
 iris = datasets.load_iris()
 X = iris.data[:, 0:2]  # we only take the first two features for visualization
 y = iris.target
 
-n_features = X.shape[1]
+X_train, X_test, y_train, y_test = train_test_split(
+    X, y, test_size=0.5, random_state=42
+)
+
 
-C = 10
-kernel = 1.0 * RBF([1.0, 1.0])  # for GPC
+# %%
+# Probabilistic classifiers
+# -------------------------
+#
+# We will plot the decision boundaries of several classifiers that have a
+# `predict_proba` method. This will allow us to visualize the uncertainty of
+# the classifier in regions where it is not certain of its prediction.
 
-# Create different classifiers.
 classifiers = {
-    "L1 logistic": LogisticRegression(C=C, penalty="l1", solver="saga", max_iter=10000),
-    "L2 logistic (Multinomial)": LogisticRegression(
-        C=C, penalty="l2", solver="saga", max_iter=10000
+    "Logistic regression\n(C=0.1)": LogisticRegression(C=0.1),
+    "Logistic regression\n(C=100)": LogisticRegression(C=100),
+    "Gaussian Process": GaussianProcessClassifier(kernel=1.0 * RBF([1.0, 1.0])),
+    "Logistic regression\n(RBF features)": make_pipeline(
+        Nystroem(kernel="rbf", gamma=5e-1, n_components=50, random_state=1),
+        LogisticRegression(C=10),
     ),
-    "L2 logistic (OvR)": OneVsRestClassifier(
-        LogisticRegression(C=C, penalty="l2", solver="saga", max_iter=10000)
+    "Gradient Boosting": HistGradientBoostingClassifier(),
+    "Logistic regression\n(binned features)": make_pipeline(
+        KBinsDiscretizer(n_bins=5, quantile_method="averaged_inverted_cdf"),
+        PolynomialFeatures(interaction_only=True),
+        LogisticRegression(C=10),
+    ),
+    "Logistic regression\n(spline features)": make_pipeline(
+        SplineTransformer(n_knots=5),
+        PolynomialFeatures(interaction_only=True),
+        LogisticRegression(C=10),
     ),
-    "Linear SVC": SVC(kernel="linear", C=C, probability=True, random_state=0),
-    "GPC": GaussianProcessClassifier(kernel),
 }
 
+# %%
+# Plotting the decision boundaries
+# --------------------------------
+#
+# For each classifier, we plot the per-class probabilities in the first three
+# columns and the probabilities of the most likely class in the last column.
+
 n_classifiers = len(classifiers)
+scatter_kwargs = {
+    "s": 25,
+    "marker": "o",
+    "linewidths": 0.8,
+    "edgecolor": "k",
+    "alpha": 0.7,
+}
+y_unique = np.unique(y)
 
+# Ensure legend not cut off
+mpl.rcParams["savefig.bbox"] = "tight"
 fig, axes = plt.subplots(
     nrows=n_classifiers,
-    ncols=len(iris.target_names),
-    figsize=(3 * 2, n_classifiers * 2),
+    ncols=len(iris.target_names) + 1,
+    figsize=(4 * 2.2, n_classifiers * 2.2),
 )
+evaluation_results = []
+levels = 100
 for classifier_idx, (name, classifier) in enumerate(classifiers.items()):
-    y_pred = classifier.fit(X, y).predict(X)
-    accuracy = accuracy_score(y, y_pred)
-    print(f"Accuracy (train) for {name}: {accuracy:0.1%}")
-    for label in np.unique(y):
+    y_pred = classifier.fit(X_train, y_train).predict(X_test)
+    y_pred_proba = classifier.predict_proba(X_test)
+    accuracy_test = accuracy_score(y_test, y_pred)
+    roc_auc_test = roc_auc_score(y_test, y_pred_proba, multi_class="ovr")
+    log_loss_test = log_loss(y_test, y_pred_proba)
+    evaluation_results.append(
+        {
+            "name": name.replace("\n", " "),
+            "accuracy": accuracy_test,
+            "roc_auc": roc_auc_test,
+            "log_loss": log_loss_test,
+        }
+    )
+    for label in y_unique:
         # plot the probability estimate provided by the classifier
         disp = DecisionBoundaryDisplay.from_estimator(
             classifier,
-            X,
+            X_train,
             response_method="predict_proba",
             class_of_interest=label,
             ax=axes[classifier_idx, label],
             vmin=0,
             vmax=1,
+            cmap="Blues",
+            levels=levels,
         )
         axes[classifier_idx, label].set_title(f"Class {label}")
         # plot data predicted to belong to given class
         mask_y_pred = y_pred == label
         axes[classifier_idx, label].scatter(
-            X[mask_y_pred, 0], X[mask_y_pred, 1], marker="o", c="w", edgecolor="k"
+            X_test[mask_y_pred, 0], X_test[mask_y_pred, 1], c="w", **scatter_kwargs
         )
+
         axes[classifier_idx, label].set(xticks=(), yticks=())
+    # add column that shows all classes by plotting class with max 'predict_proba'
+    max_class_disp = DecisionBoundaryDisplay.from_estimator(
+        classifier,
+        X_train,
+        response_method="predict_proba",
+        class_of_interest=None,
+        ax=axes[classifier_idx, len(y_unique)],
+        vmin=0,
+        vmax=1,
+        levels=levels,
+    )
+    for label in y_unique:
+        mask_label = y_test == label
+        axes[classifier_idx, 3].scatter(
+            X_test[mask_label, 0],
+            X_test[mask_label, 1],
+            c=max_class_disp.multiclass_colors_[[label], :],
+            **scatter_kwargs,
+        )
+
+    axes[classifier_idx, 3].set(xticks=(), yticks=())
+    axes[classifier_idx, 3].set_title("Max class")
     axes[classifier_idx, 0].set_ylabel(name)
 
-ax = plt.axes([0.15, 0.04, 0.7, 0.02])
+# colorbar for single class plots
+ax_single = fig.add_axes([0.15, 0.01, 0.5, 0.02])
 plt.title("Probability")
 _ = plt.colorbar(
-    cm.ScalarMappable(norm=None, cmap="viridis"), cax=ax, orientation="horizontal"
+    cm.ScalarMappable(norm=None, cmap=disp.surface_.cmap),
+    cax=ax_single,
+    orientation="horizontal",
 )
 
-plt.show()
+# colorbars for max probability class column
+max_class_cmaps = [s.cmap for s in max_class_disp.surface_]
+
+for label in y_unique:
+    ax_max = fig.add_axes([0.73, (0.06 - (label * 0.04)), 0.16, 0.015])
+    plt.title(f"Probability class {label}", fontsize=10)
+    _ = plt.colorbar(
+        cm.ScalarMappable(norm=None, cmap=max_class_cmaps[label]),
+        cax=ax_max,
+        orientation="horizontal",
+    )
+    if label in (0, 1):
+        ax_max.set(xticks=(), yticks=())
+
+
+# %%
+# Quantitative evaluation
+# -----------------------
+pd.DataFrame(evaluation_results).round(2)
+
+
+# %%
+# Analysis
+# --------
+#
+# The two logistic regression models fitted on the original features display
+# linear decision boundaries, as expected. For this particular problem, this
+# does not seem to be detrimental as both models are competitive with the
+# non-linear models when quantitatively evaluated on the test set. We can
+# observe that the amount of regularization influences the model confidence:
+# lighter colors for the strongly regularized model with a lower value of `C`.
+# Regularization also impacts the orientation of the decision boundary, leading
+# to a slightly different ROC AUC.
+#
+# The log-loss, on the other hand, evaluates both sharpness and calibration and
+# as a result strongly favors the weakly regularized logistic regression model,
+# probably because the strongly regularized model is under-confident. This
+# could be confirmed by looking at the calibration curve using
+# :class:`sklearn.calibration.CalibrationDisplay`, as sketched at the end of
+# this example.
+#
+# The logistic regression model with RBF features has a "blobby" decision
+# boundary that is non-linear in the original feature space and is quite
+# similar to the decision boundary of the Gaussian process classifier, which is
+# configured to use an RBF kernel.
+#
+# The logistic regression model fitted on binned features with interactions has
+# a decision boundary that is non-linear in the original feature space and is
+# quite similar to the decision boundary of the gradient boosting classifier:
+# both models favor axis-aligned decisions when extrapolating to unseen regions
+# of the feature space.
+#
+# The logistic regression model fitted on spline features with interactions has
+# a similar axis-aligned extrapolation behavior but a smoother decision
+# boundary in the dense region of the feature space than the two previous
+# models.
+#
+# To conclude, it is interesting to observe that feature engineering for
+# logistic regression models can be used to mimic some of the inductive bias of
+# various non-linear models. However, for this particular dataset, using the
+# raw features is enough to train a competitive model. This would not
+# necessarily be the case for other datasets.
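+
+
+# %%
+# Checking calibration (illustrative sketch)
+# ------------------------------------------
+#
+# A minimal sketch of the check suggested above, not a prescribed recipe:
+# :class:`~sklearn.calibration.CalibrationDisplay` expects binary targets, so
+# we binarize each class against the rest and draw one-vs-rest calibration
+# curves for a strongly regularized logistic regression model. The choice of
+# `n_bins=5` is arbitrary given the small test set.
+from sklearn.calibration import CalibrationDisplay
+
+# refit the strongly regularized model on the training split
+clf = LogisticRegression(C=0.1).fit(X_train, y_train)
+proba = clf.predict_proba(X_test)
+
+fig_cal, ax_cal = plt.subplots(figsize=(5, 4))
+for label in y_unique:
+    CalibrationDisplay.from_predictions(
+        y_test == label,  # one-vs-rest binarization of the true labels
+        proba[:, label],  # predicted probability of the current class
+        n_bins=5,
+        name=f"class {label}",
+        ax=ax_cal,
+    )
+_ = ax_cal.set_title("One-vs-rest calibration curves (C=0.1)")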