Skip to content

Commit 8331fca

Browse files
authored
DOC fix deprecation warning in plot_sgdocsvm_vs_ocsvm (scikit-learn#27449)
1 parent 1550432 commit 8331fca

File tree

1 file changed

+94
-52
lines changed

1 file changed

+94
-52
lines changed

examples/linear_model/plot_sgdocsvm_vs_ocsvm.py

Lines changed: 94 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
2020
""" # noqa: E501
2121

22+
# %%
2223
import matplotlib
24+
import matplotlib.lines as mlines
2325
import matplotlib.pyplot as plt
2426
import numpy as np
2527

@@ -44,8 +46,6 @@
4446
# Generate some abnormal novel observations
4547
X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))
4648

47-
xx, yy = np.meshgrid(np.linspace(-4.5, 4.5, 50), np.linspace(-4.5, 4.5, 50))
48-
4949
# OCSVM hyperparameters
5050
nu = 0.05
5151
gamma = 2.0
@@ -60,10 +60,6 @@
6060
n_error_test = y_pred_test[y_pred_test == -1].size
6161
n_error_outliers = y_pred_outliers[y_pred_outliers == 1].size
6262

63-
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
64-
Z = Z.reshape(xx.shape)
65-
66-
6763
# Fit the One-Class SVM using a kernel approximation and SGD
6864
transform = Nystroem(gamma=gamma, random_state=random_state)
6965
clf_sgd = SGDOneClassSVM(
@@ -78,25 +74,59 @@
7874
n_error_test_sgd = y_pred_test_sgd[y_pred_test_sgd == -1].size
7975
n_error_outliers_sgd = y_pred_outliers_sgd[y_pred_outliers_sgd == 1].size
8076

81-
Z_sgd = pipe_sgd.decision_function(np.c_[xx.ravel(), yy.ravel()])
82-
Z_sgd = Z_sgd.reshape(xx.shape)
8377

84-
# plot the level sets of the decision function
85-
plt.figure(figsize=(9, 6))
86-
plt.title("One Class SVM")
87-
plt.contourf(xx, yy, Z, levels=np.linspace(Z.min(), 0, 7), cmap=plt.cm.PuBu)
88-
a = plt.contour(xx, yy, Z, levels=[0], linewidths=2, colors="darkred")
89-
plt.contourf(xx, yy, Z, levels=[0, Z.max()], colors="palevioletred")
78+
# %%
79+
from sklearn.inspection import DecisionBoundaryDisplay
80+
81+
_, ax = plt.subplots(figsize=(9, 6))
82+
83+
xx, yy = np.meshgrid(np.linspace(-4.5, 4.5, 50), np.linspace(-4.5, 4.5, 50))
84+
X = np.concatenate([xx.ravel().reshape(-1, 1), yy.ravel().reshape(-1, 1)], axis=1)
85+
DecisionBoundaryDisplay.from_estimator(
86+
clf,
87+
X,
88+
response_method="decision_function",
89+
plot_method="contourf",
90+
ax=ax,
91+
cmap="PuBu",
92+
)
93+
DecisionBoundaryDisplay.from_estimator(
94+
clf,
95+
X,
96+
response_method="decision_function",
97+
plot_method="contour",
98+
ax=ax,
99+
linewidths=2,
100+
colors="darkred",
101+
levels=[0],
102+
)
103+
DecisionBoundaryDisplay.from_estimator(
104+
clf,
105+
X,
106+
response_method="decision_function",
107+
plot_method="contourf",
108+
ax=ax,
109+
colors="palevioletred",
110+
levels=[0, clf.decision_function(X).max()],
111+
)
90112

91113
s = 20
92114
b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c="white", s=s, edgecolors="k")
93115
b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s, edgecolors="k")
94116
c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s, edgecolors="k")
95-
plt.axis("tight")
96-
plt.xlim((-4.5, 4.5))
97-
plt.ylim((-4.5, 4.5))
98-
plt.legend(
99-
[a.collections[0], b1, b2, c],
117+
118+
ax.set(
119+
title="One-Class SVM",
120+
xlim=(-4.5, 4.5),
121+
ylim=(-4.5, 4.5),
122+
xlabel=(
123+
f"error train: {n_error_train}/{X_train.shape[0]}; "
124+
f"errors novel regular: {n_error_test}/{X_test.shape[0]}; "
125+
f"errors novel abnormal: {n_error_outliers}/{X_outliers.shape[0]}"
126+
),
127+
)
128+
_ = ax.legend(
129+
[mlines.Line2D([], [], color="darkred", label="learned frontier"), b1, b2, c],
100130
[
101131
"learned frontier",
102132
"training observations",
@@ -105,34 +135,57 @@
105135
],
106136
loc="upper left",
107137
)
108-
plt.xlabel(
109-
"error train: %d/%d; errors novel regular: %d/%d; errors novel abnormal: %d/%d"
110-
% (
111-
n_error_train,
112-
X_train.shape[0],
113-
n_error_test,
114-
X_test.shape[0],
115-
n_error_outliers,
116-
X_outliers.shape[0],
117-
)
118-
)
119-
plt.show()
120138

121-
plt.figure(figsize=(9, 6))
122-
plt.title("Online One-Class SVM")
123-
plt.contourf(xx, yy, Z_sgd, levels=np.linspace(Z_sgd.min(), 0, 7), cmap=plt.cm.PuBu)
124-
a = plt.contour(xx, yy, Z_sgd, levels=[0], linewidths=2, colors="darkred")
125-
plt.contourf(xx, yy, Z_sgd, levels=[0, Z_sgd.max()], colors="palevioletred")
139+
# %%
140+
_, ax = plt.subplots(figsize=(9, 6))
141+
142+
xx, yy = np.meshgrid(np.linspace(-4.5, 4.5, 50), np.linspace(-4.5, 4.5, 50))
143+
X = np.concatenate([xx.ravel().reshape(-1, 1), yy.ravel().reshape(-1, 1)], axis=1)
144+
DecisionBoundaryDisplay.from_estimator(
145+
pipe_sgd,
146+
X,
147+
response_method="decision_function",
148+
plot_method="contourf",
149+
ax=ax,
150+
cmap="PuBu",
151+
)
152+
DecisionBoundaryDisplay.from_estimator(
153+
pipe_sgd,
154+
X,
155+
response_method="decision_function",
156+
plot_method="contour",
157+
ax=ax,
158+
linewidths=2,
159+
colors="darkred",
160+
levels=[0],
161+
)
162+
DecisionBoundaryDisplay.from_estimator(
163+
pipe_sgd,
164+
X,
165+
response_method="decision_function",
166+
plot_method="contourf",
167+
ax=ax,
168+
colors="palevioletred",
169+
levels=[0, pipe_sgd.decision_function(X).max()],
170+
)
126171

127172
s = 20
128173
b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c="white", s=s, edgecolors="k")
129174
b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c="blueviolet", s=s, edgecolors="k")
130175
c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c="gold", s=s, edgecolors="k")
131-
plt.axis("tight")
132-
plt.xlim((-4.5, 4.5))
133-
plt.ylim((-4.5, 4.5))
134-
plt.legend(
135-
[a.collections[0], b1, b2, c],
176+
177+
ax.set(
178+
title="Online One-Class SVM",
179+
xlim=(-4.5, 4.5),
180+
ylim=(-4.5, 4.5),
181+
xlabel=(
182+
f"error train: {n_error_train_sgd}/{X_train.shape[0]}; "
183+
f"errors novel regular: {n_error_test_sgd}/{X_test.shape[0]}; "
184+
f"errors novel abnormal: {n_error_outliers_sgd}/{X_outliers.shape[0]}"
185+
),
186+
)
187+
ax.legend(
188+
[mlines.Line2D([], [], color="darkred", label="learned frontier"), b1, b2, c],
136189
[
137190
"learned frontier",
138191
"training observations",
@@ -141,15 +194,4 @@
141194
],
142195
loc="upper left",
143196
)
144-
plt.xlabel(
145-
"error train: %d/%d; errors novel regular: %d/%d; errors novel abnormal: %d/%d"
146-
% (
147-
n_error_train_sgd,
148-
X_train.shape[0],
149-
n_error_test_sgd,
150-
X_test.shape[0],
151-
n_error_outliers_sgd,
152-
X_outliers.shape[0],
153-
)
154-
)
155197
plt.show()

0 commit comments

Comments
 (0)