File tree Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Expand file tree Collapse file tree 1 file changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -313,14 +313,22 @@ class SklearnModelInfo(ModelInfo):
313
313
}
314
314
315
315
def __init__ (self , model , X , y ):
316
- # Ensure input/output is a DataFrame for consistency
317
- X_df = pd .DataFrame (X )
318
- y_df = pd .DataFrame (y )
316
+
319
317
320
318
is_classifier = hasattr (model , "classes_" )
321
319
is_binary_classifier = is_classifier and len (model .classes_ ) == 2
322
320
is_clusterer = hasattr (model , "cluster_centers_" )
323
321
322
+ if y is None :
323
+ if hasattr (model , "predict_proba" ):
324
+ y = model .predict_proba (X )
325
+ else :
326
+ y = model .predict (X )
327
+
328
+ # Ensure input/output is a DataFrame for consistency
329
+ X_df = pd .DataFrame (X )
330
+ y_df = pd .DataFrame (y )
331
+
324
332
# If not a classfier or a clustering algorithm and output is a single column, then
325
333
# assume its a regression algorithm
326
334
is_regressor = not is_classifier and not is_clusterer and (y_df .shape [1 ] == 1 or "Regress" in type (model ).__name__ )
You can’t perform that action at this time.
0 commit comments