|
15 | 15 | from ..ensemble._hist_gradient_boosting.gradient_boosting import (
|
16 | 16 | BaseHistGradientBoosting,
|
17 | 17 | )
|
18 |
| -from ..exceptions import NotFittedError |
19 | 18 | from ..tree import DecisionTreeRegressor
|
20 | 19 | from ..utils import Bunch, _safe_indexing, check_array
|
21 | 20 | from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_assign
|
|
27 | 26 | StrOptions,
|
28 | 27 | validate_params,
|
29 | 28 | )
|
| 29 | +from ..utils._response import _get_response_values |
30 | 30 | from ..utils.extmath import cartesian
|
31 | 31 | from ..utils.validation import _check_sample_weight, check_is_fitted
|
32 | 32 | from ._pd_utils import _check_feature_names, _get_feature_index
|
@@ -261,51 +261,27 @@ def _partial_dependence_brute(
|
261 | 261 | predictions = []
|
262 | 262 | averaged_predictions = []
|
263 | 263 |
|
264 |
| - # define the prediction_method (predict, predict_proba, decision_function). |
265 |
| - if is_regressor(est): |
266 |
| - prediction_method = est.predict |
267 |
| - else: |
268 |
| - predict_proba = getattr(est, "predict_proba", None) |
269 |
| - decision_function = getattr(est, "decision_function", None) |
270 |
| - if response_method == "auto": |
271 |
| - # try predict_proba, then decision_function if it doesn't exist |
272 |
| - prediction_method = predict_proba or decision_function |
273 |
| - else: |
274 |
| - prediction_method = ( |
275 |
| - predict_proba |
276 |
| - if response_method == "predict_proba" |
277 |
| - else decision_function |
278 |
| - ) |
279 |
| - if prediction_method is None: |
280 |
| - if response_method == "auto": |
281 |
| - raise ValueError( |
282 |
| - "The estimator has no predict_proba and no " |
283 |
| - "decision_function method." |
284 |
| - ) |
285 |
| - elif response_method == "predict_proba": |
286 |
| - raise ValueError("The estimator has no predict_proba method.") |
287 |
| - else: |
288 |
| - raise ValueError("The estimator has no decision_function method.") |
| 264 | + if response_method == "auto": |
| 265 | + response_method = ( |
| 266 | + "predict" if is_regressor(est) else ["predict_proba", "decision_function"] |
| 267 | + ) |
289 | 268 |
|
290 | 269 | X_eval = X.copy()
|
291 | 270 | for new_values in grid:
|
292 | 271 | for i, variable in enumerate(features):
|
293 | 272 | _safe_assign(X_eval, new_values[i], column_indexer=variable)
|
294 | 273 |
|
295 |
| - try: |
296 |
| - # Note: predictions is of shape |
297 |
| - # (n_points,) for non-multioutput regressors |
298 |
| - # (n_points, n_tasks) for multioutput regressors |
299 |
| - # (n_points, 1) for the regressors in cross_decomposition (I think) |
300 |
| - # (n_points, 2) for binary classification |
301 |
| - # (n_points, n_classes) for multiclass classification |
302 |
| - pred = prediction_method(X_eval) |
303 |
| - |
304 |
| - predictions.append(pred) |
305 |
| - # average over samples |
306 |
| - averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight)) |
307 |
| - except NotFittedError as e: |
308 |
| - raise ValueError("'estimator' parameter must be a fitted estimator") from e |
| 274 | + # Note: predictions is of shape |
| 275 | + # (n_points,) for non-multioutput regressors |
| 276 | + # (n_points, n_tasks) for multioutput regressors |
| 277 | + # (n_points, 1) for the regressors in cross_decomposition (I think) |
| 278 | + # (n_points, 2) for binary classification |
| 279 | + # (n_points, n_classes) for multiclass classification |
| 280 | + pred, _ = _get_response_values(est, X_eval, response_method=response_method) |
| 281 | + |
| 282 | + predictions.append(pred) |
| 283 | + # average over samples |
| 284 | + averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight)) |
309 | 285 |
|
310 | 286 | n_samples = X.shape[0]
|
311 | 287 |
|
|
0 commit comments