Skip to content

Commit 8228d02

Browse files
committed
fix: handle case where scikit model output not provided
1 parent 72097a3 commit 8228d02

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/sasctl/utils/model_info.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,14 +313,22 @@ class SklearnModelInfo(ModelInfo):
313313
}
314314

315315
def __init__(self, model, X, y):
316-
# Ensure input/output is a DataFrame for consistency
317-
X_df = pd.DataFrame(X)
318-
y_df = pd.DataFrame(y)
316+
319317

320318
is_classifier = hasattr(model, "classes_")
321319
is_binary_classifier = is_classifier and len(model.classes_) == 2
322320
is_clusterer = hasattr(model, "cluster_centers_")
323321

322+
if y is None:
323+
if hasattr(model, "predict_proba"):
324+
y = model.predict_proba(X)
325+
else:
326+
y = model.predict(X)
327+
328+
# Ensure input/output is a DataFrame for consistency
329+
X_df = pd.DataFrame(X)
330+
y_df = pd.DataFrame(y)
331+
324332
# If not a classfier or a clustering algorithm and output is a single column, then
325333
# assume its a regression algorithm
326334
is_regressor = not is_classifier and not is_clusterer and (y_df.shape[1] == 1 or "Regress" in type(model).__name__)

0 commit comments

Comments
 (0)