diff --git a/sklearn/inspection/_partial_dependence.py b/sklearn/inspection/_partial_dependence.py index a1a9b6915a17a..5c878b12d3eeb 100644 --- a/sklearn/inspection/_partial_dependence.py +++ b/sklearn/inspection/_partial_dependence.py @@ -22,6 +22,13 @@ from ..utils import _get_column_indices from ..utils.validation import check_is_fitted from ..utils import Bunch +from ..utils._param_validation import ( + HasMethods, + Integral, + Interval, + StrOptions, + validate_params, +) from ..tree import DecisionTreeRegressor from ..ensemble import RandomForestRegressor from ..exceptions import NotFittedError @@ -223,6 +230,24 @@ def _partial_dependence_brute(est, grid, features, X, response_method): return averaged_predictions, predictions +@validate_params( + { + "estimator": [ + HasMethods(["fit", "predict"]), + HasMethods(["fit", "predict_proba"]), + HasMethods(["fit", "decision_function"]), + ], + "X": ["array-like", "sparse matrix"], + "features": ["array-like", Integral, str], + "categorical_features": ["array-like", None], + "feature_names": ["array-like", None], + "response_method": [StrOptions({"auto", "predict_proba", "decision_function"})], + "percentiles": [tuple], + "grid_resolution": [Interval(Integral, 1, None, closed="left")], + "method": [StrOptions({"auto", "recursion", "brute"})], + "kind": [StrOptions({"average", "individual", "both"})], + } +) def partial_dependence( estimator, X, @@ -268,13 +293,13 @@ def partial_dependence( :term:`predict_proba`, or :term:`decision_function`. Multioutput-multiclass classifiers are not supported. - X : {array-like or dataframe} of shape (n_samples, n_features) + X : {array-like, sparse matrix or dataframe} of shape (n_samples, n_features) ``X`` is used to generate a grid of values for the target ``features`` (where the partial dependence will be evaluated), and also to generate values for the complement features when the `method` is 'brute'. - features : array-like of {int, str} + features : array-like of {int, str, bool} or int or str The feature (e.g. `[0]`) or pair of interacting features (e.g. `[(0, 1)]`) for which the partial dependency should be computed. @@ -425,27 +450,12 @@ def partial_dependence( if not (hasattr(X, "__array__") or sparse.issparse(X)): X = check_array(X, force_all_finite="allow-nan", dtype=object) - accepted_responses = ("auto", "predict_proba", "decision_function") - if response_method not in accepted_responses: - raise ValueError( - "response_method {} is invalid. Accepted response_method names " - "are {}.".format(response_method, ", ".join(accepted_responses)) - ) - if is_regressor(estimator) and response_method != "auto": raise ValueError( "The response_method parameter is ignored for regressors and " "must be 'auto'." ) - accepted_methods = ("brute", "recursion", "auto") - if method not in accepted_methods: - raise ValueError( - "method {} is invalid. Accepted method names are {}.".format( - method, ", ".join(accepted_methods) - ) - ) - if kind != "average": if method == "recursion": raise ValueError( diff --git a/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py b/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py index 8e55d44a435bd..3968cef3b832d 100644 --- a/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py +++ b/sklearn/inspection/_plot/tests/test_plot_partial_dependence.py @@ -611,16 +611,6 @@ def test_plot_partial_dependence_dataframe(pyplot, clf_diabetes, diabetes): {"features": [1], "categorical_features": [1], "kind": "individual"}, "It is not possible to display individual effects", ), - ( - dummy_classification_data, - {"features": [1], "kind": "foo"}, - "Values provided to `kind` must be one of", - ), - ( - dummy_classification_data, - {"features": [0, 1], "kind": ["foo", "individual"]}, - "Values provided to `kind` must be one of", - ), ], ) def test_plot_partial_dependence_error(pyplot, data, params, err_msg): diff --git a/sklearn/inspection/tests/test_partial_dependence.py b/sklearn/inspection/tests/test_partial_dependence.py index a42f80da301d1..71c3f1d61cc14 100644 --- a/sklearn/inspection/tests/test_partial_dependence.py +++ b/sklearn/inspection/tests/test_partial_dependence.py @@ -510,31 +510,6 @@ def fit(self, X, y): {"features": [0], "response_method": "predict_proba", "method": "auto"}, "'recursion' method, the response_method must be 'decision_function'", ), - ( - GradientBoostingClassifier(random_state=0), - {"features": [0], "response_method": "blahblah"}, - "response_method blahblah is invalid. Accepted response_method", - ), - ( - NoPredictProbaNoDecisionFunction(), - {"features": [0], "response_method": "auto"}, - "The estimator has no predict_proba and no decision_function method", - ), - ( - NoPredictProbaNoDecisionFunction(), - {"features": [0], "response_method": "predict_proba"}, - "The estimator has no predict_proba method.", - ), - ( - NoPredictProbaNoDecisionFunction(), - {"features": [0], "response_method": "decision_function"}, - "The estimator has no decision_function method.", - ), - ( - LinearRegression(), - {"features": [0], "method": "blahblah"}, - "blahblah is invalid. Accepted method names are brute, recursion, auto", - ), ( LinearRegression(), {"features": [0], "method": "recursion", "kind": "individual"}, @@ -560,24 +535,6 @@ def test_partial_dependence_error(estimator, params, err_msg): partial_dependence(estimator, X, **params) -@pytest.mark.parametrize( - "with_dataframe, err_msg", - [ - (True, "Only array-like or scalar are supported"), - (False, "Only array-like or scalar are supported"), - ], -) -def test_partial_dependence_slice_error(with_dataframe, err_msg): - X, y = make_classification(random_state=0) - if with_dataframe: - pd = pytest.importorskip("pandas") - X = pd.DataFrame(X) - estimator = LogisticRegression().fit(X, y) - - with pytest.raises(TypeError, match=err_msg): - partial_dependence(estimator, X, features=slice(0, 2, 1)) - - @pytest.mark.parametrize( "estimator", [LinearRegression(), GradientBoostingClassifier(random_state=0)] ) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index d46ae07821ac2..e0b4920be6d36 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -165,6 +165,7 @@ def _check_function_param_validation( "sklearn.feature_selection.mutual_info_classif", "sklearn.feature_selection.mutual_info_regression", "sklearn.feature_selection.r_regression", + "sklearn.inspection.partial_dependence", "sklearn.inspection.permutation_importance", "sklearn.linear_model.orthogonal_mp", "sklearn.metrics.accuracy_score",