8000 MAINT small refactoring in partial_dependence (#30104) · scikit-learn/scikit-learn@bfa0d65 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit bfa0d65

Browse files
authored
MAINT small refactoring in partial_dependence (#30104)
1 parent fd07977 commit bfa0d65

File tree

1 file changed

+16
-40
lines changed

1 file changed

+16
-40
lines changed

sklearn/inspection/_partial_dependence.py

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from ..ensemble._hist_gradient_boosting.gradient_boosting import (
1616
BaseHistGradientBoosting,
1717
)
18-
from ..exceptions import NotFittedError
1918
from ..tree import DecisionTreeRegressor
2019
from ..utils import Bunch, _safe_indexing, check_array
2120
from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_assign
@@ -27,6 +26,7 @@
2726
StrOptions,
2827
validate_params,
2928
)
29+
from ..utils._response import _get_response_values
3030
from ..utils.extmath import cartesian
3131
from ..utils.validation import _check_sample_weight, check_is_fitted
3232
from ._pd_utils import _check_feature_names, _get_feature_index
@@ -261,51 +261,27 @@ def _partial_dependence_brute(
261261
predictions = []
262262
averaged_predictions = []
263263

264-
# define the prediction_method (predict, predict_proba, decision_function).
265-
if is_regressor(est):
266-
prediction_method = est.predict
267-
else:
268-
predict_proba = getattr(est, "predict_proba", None)
269-
decision_function = getattr(est, "decision_function", None)
270-
if response_method == "auto":
271-
# try predict_proba, then decision_function if it doesn't exist
272-
prediction_method = predict_proba or decision_function
273-
else:
274-
prediction_method = (
275-
predict_proba
276-
if response_method == "predict_proba"
277-
else decision_function
278-
)
279-
if prediction_method is None:
280-
if response_method == "auto":
281-
raise ValueError(
282-
"The estimator has no predict_proba and no "
283-
"decision_function method."
284-
)
285-
elif response_method == "predict_proba":
286-
raise ValueError("The estimator has no predict_proba method.")
287-
else:
288-
raise ValueError("The estimator has no decision_function method.")
264+
if response_method == "auto":
265+
response_method = (
266+
"predict" if is_regressor(est) else ["predict_proba", "decision_function"]
267+
)
289268

290269
X_eval = X.copy()
291270
for new_values in grid:
292271
for i, variable in enumerate(features):
293272
_safe_assign(X_eval, new_values[i], column_indexer=variable)
294273

295-
try:
296-
# Note: predictions is of shape
297-
# (n_points,) for non-multioutput regressors
298-
# (n_points, n_tasks) for multioutput regressors
299-
# (n_points, 1) for the regressors in cross_decomposition (I think)
300-
# (n_points, 2) for binary classification
301-
# (n_points, n_classes) for multiclass classification
302-
pred = prediction_method(X_eval)
303-
304-
predictions.append(pred)
305-
# average over samples
306-
averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight))
307-
except NotFittedError as e:
308-
raise ValueError("'estimator' parameter must be a fitted estimator") from e
274+
# Note: predictions is of shape
275+
# (n_points,) for non-multioutput regressors
276+
# (n_points, n_tasks) for multioutput regressors
277+
# (n_points, 1) for the regressors in cross_decomposition (I think)
278+
# (n_points, 2) for binary classification
279+
# (n_points, n_classes) for multiclass classification
280+
pred, _ = _get_response_values(est, X_eval, response_method=response_method)
281+
282+
predictions.append(pred)
283+
# average over samples
284+
averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight))
309285

310286
n_samples = X.shape[0]
311287

0 commit comments

Comments
 (0)
0