|
62 | 62 | # Classification and ROC analysis
|
63 | 63 | # -------------------------------
|
64 | 64 | #
|
65 |
| -# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and |
66 |
| -# plot the ROC curves fold-wise. Notice that the baseline to define the chance |
| 65 | +# Here we run :func:`~sklearn.model_selection.cross_validate` on a |
| 66 | +# :class:`~sklearn.svm.SVC` classifier, then use the computed cross-validation results |
| 67 | +# to plot the ROC curves fold-wise. Notice that the baseline to define the chance |
67 | 68 | # level (dashed ROC curve) is a classifier that would always predict the most
|
68 | 69 | # frequent class.
|
69 | 70 |
|
70 | 71 | import matplotlib.pyplot as plt
|
71 | 72 |
|
72 | 73 | from sklearn import svm
|
73 | 74 | from sklearn.metrics import RocCurveDisplay, auc
|
74 |
| -from sklearn.model_selection import StratifiedKFold |
| 75 | +from sklearn.model_selection import StratifiedKFold, cross_validate |
75 | 76 |
|
76 | 77 | n_splits = 6
|
77 | 78 | cv = StratifiedKFold(n_splits=n_splits)
|
78 | 79 | classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)
|
| 80 | +cv_results = cross_validate( |
| 81 | + classifier, X, y, cv=cv, return_estimator=True, return_indices=True |
| 82 | +) |
| 83 | + |
| 84 | +prop_cycle = plt.rcParams["axes.prop_cycle"] |
| 85 | +colors = prop_cycle.by_key()["color"] |
| 86 | +curve_kwargs_list = [ |
| 87 | + dict(alpha=0.3, lw=1, color=colors[fold % len(colors)]) for fold in range(n_splits) |
| 88 | +] |
| 89 | +names = [f"ROC fold {idx}" for idx in range(n_splits)] |
79 | 90 |
|
80 |
| -tprs = [] |
81 |
| -aucs = [] |
82 | 91 | mean_fpr = np.linspace(0, 1, 100)
|
| 92 | +interp_tprs = [] |
| 93 | + |
| 94 | +_, ax = plt.subplots(figsize=(6, 6)) |
| 95 | +viz = RocCurveDisplay.from_cv_results( |
| 96 | + cv_results, |
| 97 | + X, |
| 98 | + y, |
| 99 | + ax=ax, |
| 100 | + name=names, |
| 101 | + curve_kwargs=curve_kwargs_list, |
| 102 | + plot_chance_level=True, |
| 103 | +) |
83 | 104 |
|
84 |
| -fig, ax = plt.subplots(figsize=(6, 6)) |
85 |
| -for fold, (train, test) in enumerate(cv.split(X, y)): |
86 |
| - classifier.fit(X[train], y[train]) |
87 |
| - viz = RocCurveDisplay.from_estimator( |
88 |
| - classifier, |
89 |
| - X[test], |
90 |
| - y[test], |
91 |
| - name=f"ROC fold {fold}", |
92 |
| - curve_kwargs=dict(alpha=0.3, lw=1), |
93 |
| - ax=ax, |
94 |
| - plot_chance_level=(fold == n_splits - 1), |
95 |
| - ) |
96 |
| - interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr) |
| 105 | +for idx in range(n_splits): |
| 106 | + interp_tpr = np.interp(mean_fpr, viz.fpr[idx], viz.tpr[idx]) |
97 | 107 | interp_tpr[0] = 0.0
|
98 |
| - tprs.append(interp_tpr) |
99 |
| - aucs.append(viz.roc_auc) |
| 108 | + interp_tprs.append(interp_tpr) |
100 | 109 |
|
101 |
| -mean_tpr = np.mean(tprs, axis=0) |
| 110 | +mean_tpr = np.mean(interp_tprs, axis=0) |
102 | 111 | mean_tpr[-1] = 1.0
|
103 | 112 | mean_auc = auc(mean_fpr, mean_tpr)
|
104 |
| -std_auc = np.std(aucs) |
| 113 | +std_auc = np.std(viz.roc_auc) |
| 114 | + |
105 | 115 | ax.plot(
|
106 | 116 | mean_fpr,
|
107 | 117 | mean_tpr,
|
|
111 | 121 | alpha=0.8,
|
112 | 122 | )
|
113 | 123 |
|
114 |
| -std_tpr = np.std(tprs, axis=0) |
| 124 | +std_tpr = np.std(interp_tprs, axis=0) |
115 | 125 | tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
|
116 | 126 | tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
|
117 | 127 | ax.fill_between(
|
|
0 commit comments