@@ -198,20 +198,27 @@ def _get_binary_metrics(self):
198
198
self .y_true , self .y_pred , pos_label = self .positive_class , average = "binary"
199
199
)
200
200
201
- (
202
- self .metrics ["false_positive_rate" ],
203
- self .metrics ["true_positive_rate" ],
204
- _ ,
205
- ) = metrics .roc_curve (self .y_true , self .y_pred , pos_label = self .positive_class )
206
- self .metrics ["auc" ] = metrics .auc (
207
- self .metrics ["false_positive_rate" ], self .metrics ["true_positive_rate" ]
208
- )
209
-
210
201
if self .y_score is not None :
211
202
if not all (0 >= x >= 1 for x in self .y_score ):
212
203
self .y_score = np .asarray (
213
204
[0 if x < 0 else 1 if x > 1 else x for x in self .y_score ]
214
205
)
206
+ if len (np .asarray (self .y_score ).shape ) > 1 :
207
+ # If the SKLearn classifier doesn't correctly identify the problem as
208
+ # binary classification, y_score may be of shape (n_rows, 2)
209
+ # instead of (n_rows,)
210
+ pos_class_idx = self .classes .index (self .positive_class )
211
+ positive_class_scores = self .y_score [:, pos_class_idx ]
212
+ else :
213
+ positive_class_scores = self .y_score
214
+ (
215
+ self .metrics ["false_positive_rate" ],
216
+ self .metrics ["true_positive_rate" ],
217
+ _ ,
218
+ ) = metrics .roc_curve (y_true = self .y_true , y_score = positive_class_scores , pos_label = self .positive_class )
219
+ self .metrics ["auc" ] = metrics .auc (
220
+ self .metrics ["false_positive_rate" ], self .metrics ["true_positive_rate" ]
221
+ )
215
222
self .y_score = list (self .y_score )
216
223
self .metrics ["youden_j" ] = (
217
224
self .metrics ["true_positive_rate" ] - self .metrics ["false_positive_rate" ]
0 commit comments