diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py index d90411a0c8bfa..2590919401ee0 100644 --- a/sklearn/preprocessing/_polynomial.py +++ b/sklearn/preprocessing/_polynomial.py @@ -3,6 +3,7 @@ """ import collections import numbers +from numbers import Integral from itertools import chain, combinations from itertools import combinations_with_replacement as combinations_w_r @@ -15,6 +16,7 @@ from ..utils import check_array from ..utils.deprecation import deprecated from ..utils.validation import check_is_fitted, FLOAT_DTYPES, _check_sample_weight +from ..utils._param_validation import Interval, StrOptions from ..utils.validation import _check_feature_names_in from ..utils.stats import _weighted_percentile @@ -627,6 +629,17 @@ class SplineTransformer(TransformerMixin, BaseEstimator): [0. , 0. , 0.5 , 0.5 ]]) """ + _parameter_constraints = { + "n_knots": [Interval(Integral, 2, None, closed="left")], + "degree": [Interval(Integral, 0, None, closed="left")], + "knots": [StrOptions({"uniform", "quantile"}), "array-like"], + "extrapolation": [ + StrOptions({"error", "constant", "linear", "continue", "periodic"}) + ], + "include_bias": ["boolean"], + "order": [StrOptions({"C", "F"})], + } + def __init__( self, n_knots=5, @@ -767,6 +780,8 @@ def fit(self, X, y=None, sample_weight=None): self : object Fitted transformer. """ + self._validate_params() + X = self._validate_data( X, reset=True, @@ -779,20 +794,7 @@ def fit(self, X, y=None, sample_weight=None): _, n_features = X.shape - if not (isinstance(self.degree, numbers.Integral) and self.degree >= 0): - raise ValueError( - f"degree must be a non-negative integer, got {self.degree}." - ) - - if isinstance(self.knots, str) and self.knots in [ - "uniform", - "quantile", - ]: - if not (isinstance(self.n_knots, numbers.Integral) and self.n_knots >= 2): - raise ValueError( - f"n_knots must be a positive integer >= 2, got: {self.n_knots}" - ) - + if isinstance(self.knots, str): base_knots = self._get_base_knot_positions( X, n_knots=self.n_knots, knots=self.knots, sample_weight=sample_weight ) @@ -805,21 +807,6 @@ def fit(self, X, y=None, sample_weight=None): elif not np.all(np.diff(base_knots, axis=0) > 0): raise ValueError("knots must be sorted without duplicates.") - if self.extrapolation not in ( - "error", - "constant", - "linear", - "continue", - "periodic", - ): - raise ValueError( - "extrapolation must be one of 'error', " - "'constant', 'linear', 'continue' or 'periodic'." - ) - - if not isinstance(self.include_bias, (bool, np.bool_)): - raise ValueError("include_bias must be bool.") - # number of knots for base interval n_knots = base_knots.shape[0] @@ -936,7 +923,6 @@ def transform(self, X): spl = self.bsplines_[i] if self.extrapolation in ("continue", "error", "periodic"): - if self.extrapolation == "periodic": # With periodic extrapolation we map x to the segment # [spl.t[k], spl.t[n]]. diff --git a/sklearn/preprocessing/tests/test_polynomial.py b/sklearn/preprocessing/tests/test_polynomial.py index 0ab8f0f335f43..7ef420796dd44 100644 --- a/sklearn/preprocessing/tests/test_polynomial.py +++ b/sklearn/preprocessing/tests/test_polynomial.py @@ -28,80 +28,6 @@ def is_c_contiguous(a): assert np.isfortran(est(order="F").fit_transform(X)) -@pytest.mark.parametrize( - "params, err_msg", - [ - ({"degree": -1}, "degree must be a non-negative integer"), - ({"degree": 2.5}, "degree must be a non-negative integer"), - ({"degree": "string"}, "degree must be a non-negative integer"), - ({"n_knots": 1}, "n_knots must be a positive integer >= 2."), - ({"n_knots": 1}, "n_knots must be a positive integer >= 2."), - ({"n_knots": 2.5}, "n_knots must be a positive integer >= 2."), - ({"n_knots": "string"}, "n_knots must be a positive integer >= 2."), - ({"knots": 1}, "Expected 2D array, got scalar array instead:"), - ({"knots": [1, 2]}, "Expected 2D array, got 1D array instead:"), - ( - {"knots": [[1]]}, - r"Number of knots, knots.shape\[0\], must be >= 2.", - ), - ( - {"knots": [[1, 5], [2, 6]]}, - r"knots.shape\[1\] == n_features is violated.", - ), - ( - {"knots": [[1], [1], [2]]}, - "knots must be sorted without duplicates.", - ), - ({"knots": [[2], [1]]}, "knots must be sorted without duplicates."), - ( - {"extrapolation": None}, - "extrapolation must be one of 'error', 'constant', 'linear', " - "'continue' or 'periodic'.", - ), - ( - {"extrapolation": 1}, - "extrapolation must be one of 'error', 'constant', 'linear', " - "'continue' or 'periodic'.", - ), - ( - {"extrapolation": "string"}, - "extrapolation must be one of 'error', 'constant', 'linear', " - "'continue' or 'periodic'.", - ), - ({"include_bias": None}, "include_bias must be bool."), - ({"include_bias": 1}, "include_bias must be bool."), - ({"include_bias": "string"}, "include_bias must be bool."), - ( - {"extrapolation": "periodic", "n_knots": 3, "degree": 3}, - "Periodic splines require degree < n_knots. Got n_knots=3 and degree=3.", - ), - ( - {"extrapolation": "periodic", "knots": [[0], [1]], "degree": 2}, - "Periodic splines require degree < n_knots. Got n_knots=2 and degree=2.", - ), - ], -) -def test_spline_transformer_input_validation(params, err_msg): - """Test that we raise errors for invalid input in SplineTransformer.""" - X = [[1], [2]] - - with pytest.raises(ValueError, match=err_msg): - SplineTransformer(**params).fit(X) - - -def test_spline_transformer_manual_knot_input(): - """ - Test that array-like knot positions in SplineTransformer are accepted. - """ - X = np.arange(20).reshape(10, 2) - knots = [[0.5, 1], [1.5, 2], [5, 10]] - st1 = SplineTransformer(degree=3, knots=knots, n_knots=None).fit(X) - knots = np.asarray(knots) - st2 = SplineTransformer(degree=3, knots=knots, n_knots=None).fit(X) - for i in range(X.shape[1]): - assert_allclose(st1.bsplines_[i].t, st2.bsplines_[i].t) - - @pytest.mark.parametrize("extrapolation", ["continue", "periodic"]) def test_spline_transformer_integer_knots(extrapolation): """Test that SplineTransformer accepts integer value knot positions.""" @@ -238,31 +164,6 @@ def test_spline_transformer_get_base_knot_positions( assert_allclose(base_knots, expected_knots) -@pytest.mark.parametrize( - "knots, n_knots, degree", - [ - ("uniform", 5, 3), - ("uniform", 12, 8), - ( - [[-1.0, 0.0], [0, 1.0], [0.1, 2.0], [0.2, 3.0], [0.3, 4.0], [1, 5.0]], - None, - 3, - ), - ], -) -def test_spline_transformer_periodicity_of_extrapolation(knots, n_knots, degree): - """Test that the SplineTransformer is periodic for multiple features.""" - X_1 = np.linspace((-1, 0), (1, 5), 10) - X_2 = np.linspace((1, 5), (3, 10), 10) - - splt = SplineTransformer( - knots=knots, n_knots=n_knots, degree=degree, extrapolation="periodic" - ) - splt.fit(X_1) - - assert_allclose(splt.transform(X_1), splt.transform(X_2)) - - @pytest.mark.parametrize(["bias", "intercept"], [(True, False), (False, True)]) def test_spline_transformer_periodic_linear_regression(bias, intercept): """Test that B-splines fit a periodic curve pretty well.""" diff --git a/sklearn/tests/test_common.py b/sklearn/tests/test_common.py index 9d7c53113bcf6..ab90a0dda301c 100644 --- a/sklearn/tests/test_common.py +++ b/sklearn/tests/test_common.py @@ -517,7 +517,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator): "SpectralBiclustering", "SpectralCoclustering", "SpectralEmbedding", - "SplineTransformer", "TransformedTargetRegressor", ]