Skip to content

Commit bfa0d65

Browse files
authored
MAINT small refactoring in partial_dependence (scikit-learn#30104)
1 parent fd07977 commit bfa0d65

File tree

1 file changed

+16
-40
lines changed

1 file changed

+16
-40
lines changed

sklearn/inspection/_partial_dependence.py

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from ..ensemble._hist_gradient_boosting.gradient_boosting import (
1616
BaseHistGradientBoosting,
1717
)
18-
from ..exceptions import NotFittedError
1918
from ..tree import DecisionTreeRegressor
2019
from ..utils import Bunch, _safe_indexing, check_array
2120
from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_assign
@@ -27,6 +26,7 @@
2726
StrOptions,
2827
validate_params,
2928
)
29+
from ..utils._response import _get_response_values
3030
from ..utils.extmath import cartesian
3131
from ..utils.validation import _check_sample_weight, check_is_fitted
3232
from ._pd_utils import _check_feature_names, _get_feature_index
@@ -261,51 +261,27 @@ def _partial_dependence_brute(
261261
predictions = []
262262
averaged_predictions = []
263263

264-
# define the prediction_method (predict, predict_proba, decision_function).
265-
if is_regressor(est):
266-
prediction_method = est.predict
267-
else:
268-
predict_proba = getattr(est, "predict_proba", None)
269-
decision_function = getattr(est, "decision_function", None)
270-
if response_method == "auto":
271-
# try predict_proba, then decision_function if it doesn't exist
272-
prediction_method = predict_proba or decision_function
273-
else:
274-
prediction_method = (
275-
predict_proba
276-
if response_method == "predict_proba"
277-
else decision_function
278-
)
279-
if prediction_method is None:
280-
if response_method == "auto":
281-
raise ValueError(
282-
"The estimator has no predict_proba and no "
283-
"decision_function method."
284-
)
285-
elif response_method == "predict_proba":
286-
raise ValueError("The estimator has no predict_proba method.")
287-
else:
288-
raise ValueError("The estimator has no decision_function method.")
264+
if response_method == "auto":
265+
response_method = (
266+
"predict" if is_regressor(est) else ["predict_proba", "decision_function"]
267+
)
289268

290269
X_eval = X.copy()
291270
for new_values in grid:
292271
for i, variable in enumerate(features):
293272
_safe_assign(X_eval, new_values[i], column_indexer=variable)
294273

295-
try:
296-
# Note: predictions is of shape
297-
# (n_points,) for non-multioutput regressors
298-
# (n_points, n_tasks) for multioutput regressors
299-
# (n_points, 1) for the regressors in cross_decomposition (I think)
300-
# (n_points, 2) for binary classification
301-
# (n_points, n_classes) for multiclass classification
302-
pred = prediction_method(X_eval)
303-
304-
predictions.append(pred)
305-
# average over samples
306-
averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight))
307-
except NotFittedError as e:
308-
raise ValueError("'estimator' parameter must be a fitted estimator") from e
274+
# Note: predictions is of shape
275+
# (n_points,) for non-multioutput regressors
276+
# (n_points, n_tasks) for multioutput regressors
277+
# (n_points, 1) for the regressors in cross_decomposition (I think)
278+
# (n_points, 2) for binary classification
279+
# (n_points, n_classes) for multiclass classification
280+
pred, _ = _get_response_values(est, X_eval, response_method=response_method)
281+
282+
predictions.append(pred)
283+
# average over samples
284+
averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight))
309285

310286
n_samples = X.shape[0]
311287

0 commit comments

Comments
 (0)