Skip to content

Commit 9082e98

Browse files
ahoslerliudmylaru
andauthored
ODSC-39815: roc curve bug (#189)
Co-authored-by: Liuda Rudenka <liuda.rudenka@oracle.com>
1 parent dd8f2a9 commit 9082e98

File tree

2 files changed

+19
-11
lines changed

2 files changed

+19
-11
lines changed

ads/evaluations/statistical_metrics.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,20 +198,27 @@ def _get_binary_metrics(self):
198198
self.y_true, self.y_pred, pos_label=self.positive_class, average="binary"
199199
)
200200

201-
(
202-
self.metrics["false_positive_rate"],
203-
self.metrics["true_positive_rate"],
204-
_,
205-
) = metrics.roc_curve(self.y_true, self.y_pred, pos_label=self.positive_class)
206-
self.metrics["auc"] = metrics.auc(
207-
self.metrics["false_positive_rate"], self.metrics["true_positive_rate"]
208-
)
209-
210201
if self.y_score is not None:
211202
if not all(0 >= x >= 1 for x in self.y_score):
212203
self.y_score = np.asarray(
213204
[0 if x < 0 else 1 if x > 1 else x for x in self.y_score]
214205
)
206+
if len(np.asarray(self.y_score).shape) > 1:
207+
# If the SKLearn classifier doesn't correctly identify the problem as
208+
# binary classification, y_score may be of shape (n_rows, 2)
209+
# instead of (n_rows,)
210+
pos_class_idx = self.classes.index(self.positive_class)
211+
positive_class_scores = self.y_score[:, pos_class_idx]
212+
else:
213+
positive_class_scores = self.y_score
214+
(
215+
self.metrics["false_positive_rate"],
216+
self.metrics["true_positive_rate"],
217+
_,
218+
) = metrics.roc_curve(y_true=self.y_true, y_score=positive_class_scores, pos_label=self.positive_class)
219+
self.metrics["auc"] = metrics.auc(
220+
self.metrics["false_positive_rate"], self.metrics["true_positive_rate"]
221+
)
215222
self.y_score = list(self.y_score)
216223
self.metrics["youden_j"] = (
217224
self.metrics["true_positive_rate"] - self.metrics["false_positive_rate"]

tests/unitary/with_extras/evaluator/test_evaluations_evaluator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_auc_against_sklearn(self):
159159
rf_clf = RandomForestClassifier(n_estimators=10, random_state=0).fit(
160160
X_train, y_train
161161
)
162-
svc_clf = SVC(kernel="linear", C=1.0, random_state=0).fit(X_train, y_train)
162+
svc_clf = SVC(kernel="linear", C=1.0, random_state=0, probability=True).fit(X_train, y_train)
163163

164164
bin_lr_model = ADSModel.from_estimator(lr_clf, classes=[0, 1])
165165
bin_rf_model = ADSModel.from_estimator(rf_clf, classes=[0, 1])
@@ -178,7 +178,8 @@ def test_auc_against_sklearn(self):
178178
("rf", bin_rf_model),
179179
("svc", svc_model),
180180
]:
181-
fpr, tpr, _ = roc_curve(test.y, model.est.predict(test.X), pos_label=1)
181+
pos_label_idx = model.classes_.index(1)
182+
fpr, tpr, _ = roc_curve(test.y, model.est.predict_proba(test.X)[:,pos_label_idx], pos_label=1)
182183
sklearn_metrics[model_type] = round(auc(fpr, tpr), 4)
183184

184185
assert (

0 commit comments

Comments
 (0)