8000 MAINT add parameters validation for SplineTransformer by kasmith11 · Pull Request #24057 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT add parameters validation for SplineTransformer #24057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions sklearn/preprocessing/_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""
import collections
import numbers
from numbers import Integral
from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

Expand All @@ -15,6 +16,7 @@
from ..utils import check_array
from ..utils.deprecation import deprecated
from ..utils.validation import check_is_fitted, FLOAT_DTYPES, _check_sample_weight
from ..utils._param_validation import Interval, StrOptions
from ..utils.validation import _check_feature_names_in
from ..utils.stats import _weighted_percentile

Expand Down Expand Up @@ -627,6 +629,17 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
[0. , 0. , 0.5 , 0.5 ]])
"""

_parameter_constraints = {
"n_knots": [Interval(Integral, 2, None, closed="left")],
"degree": [Interval(Integral, 0, None, closed="left")],
"knots": [StrOptions({"uniform", "quantile"}), "array-like"],
"extrapolation": [
StrOptions({"error", "constant", "linear", "continue", "periodic"})
],
"include_bias": ["boolean"],
"order": [StrOptions({"C", "F"})],
}

def __init__(
self,
n_knots=5,
Expand Down Expand Up @@ -767,6 +780,8 @@ def fit(self, X, y=None, sample_weight=None):
self : object
Fitted transformer.
"""
self._validate_params()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now some parameters will be validated and you need to remove some extra unnecessary checks. I check and you should have the following diff in fit:

diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py
index 585c21eb2d..0b651e1f2f 100644
--- a/sklearn/preprocessing/_polynomial.py
+++ b/sklearn/preprocessing/_polynomial.py
@@ -794,20 +794,8 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
 
         _, n_features = X.shape
 
-        if not (isinstance(self.degree, numbers.Integral) and self.degree >= 0):
-            raise ValueError(
-                f"degree must be a non-negative integer, got {self.degree}."
-            )
-
-        if isinstance(self.knots, str) and self.knots in [
-            "uniform",
-            "quantile",
-        ]:
-            if not (isinstance(self.n_knots, numbers.Integral) and self.n_knots >= 2):
-                raise ValueError(
-                    f"n_knots must be a positive integer >= 2, got: {self.n_knots}"
-                )
 
+        if isinstance(self.knots, str):
             base_knots = self._get_base_knot_positions(
                 X, n_knots=self.n_knots, knots=self.knots, sample_weight=sample_weight
             )
@@ -820,20 +808,6 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
             elif not np.all(np.diff(base_knots, axis=0) > 0):
                 raise ValueError("knots must be sorted without duplicates.")
 
-        if self.extrapolation not in (
-            "error",
-            "constant",
-            "linear",
-            "continue",
-            "periodic",
-        ):
-            raise ValueError(
-                "extrapolation must be one of 'error', "
-                "'constant', 'linear', 'continue' or 'periodic'."
-            )
-
-        if not isinstance(self.include_bias, (bool, np.bool_)):
-            raise ValueError("include_bias must be bool.")
 
         # number of knots for base interval
         n_knots = base_knots.shape[0]


X = self._validate_data(
X,
reset=True,
Expand All @@ -779,20 +794,7 @@ def fit(self, X, y=None, sample_weight=None):

_, n_features = X.shape

if not (isinstance(self.degree, numbers.Integral) and self.degree >= 0):
raise ValueError(
f"degree must be a non-negative integer, got {self.degree}."
)

if isinstance(self.knots, str) and self.knots in [
"uniform",
"quantile",
]:
if not (isinstance(self.n_knots, numbers.Integral) and self.n_knots >= 2):
raise ValueError(
f"n_knots must be a positive integer >= 2, got: {self.n_knots}"
)

if isinstance(self.knots, str):
base_knots = self._get_base_knot_positions(
X, n_knots=self.n_knots, knots=self.knots, sample_weight=sample_weight
)
Expand All @@ -805,21 +807,6 @@ def fit(self, X, y=None, sample_weight=None):
elif not np.all(np.diff(base_knots, axis=0) > 0):
raise ValueError("knots must be sorted without duplicates.")

if self.extrapolation not in (
"error",
"constant",
"linear",
"continue",
"periodic",
):
raise ValueError(
"extrapolation must be one of 'error', "
"'constant', 'linear', 'continue' or 'periodic'."
)

if not isinstance(self.include_bias, (bool, np.bool_)):
raise ValueError("include_bias must be bool.")

# number of knots for base interval
n_knots = base_knots.shape[0]

Expand Down Expand Up @@ -936,7 +923,6 @@ def transform(self, X):
spl = self.bsplines_[i]

if self.extrapolation in ("continue", "error", "periodic"):

if self.extrapolation == "periodic":
# With periodic extrapolation we map x to the segment
# [spl.t[k], spl.t[n]].
Expand Down
99 changes: 0 additions & 99 deletions sklearn/preprocessing/tests/test_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,80 +28,6 @@ def is_c_contiguous(a):
assert np.isfortran(est(order="F").fit_transform(X))


@pytest.mark.parametrize(
"params, err_msg",
[
({"degree": -1}, "degree must be a non-negative integer"),
({"degree": 2.5}, "degree must be a non-negative integer"),
({"degree": "string"}, "degree must be a non-negative integer"),
({"n_knots": 1}, "n_knots must be a positive integer >= 2."),
({"n_knots": 1}, "n_knots must be a positive integer >= 2."),
({"n_knots": 2.5}, "n_knots must be a positive integer >= 2."),
({"n_knots": "string"}, "n_knots must be a positive integer >= 2."),
({"knots": 1}, "Expected 2D array, got scalar array instead:"),
({"knots": [1, 2]}, "Expected 2D array, got 1D array instead:"),
(
{"knots": [[1]]},
r"Number of knots, knots.shape\[0\], must be >= 2.",
),
(
{"knots": [[1, 5], [2, 6]]},
r"knots.shape\[1\] == n_features is violated.",
),
(
{"knots": [[1], [1], [2]]},
"knots must be sorted without duplicates.",
),
({"knots": [[2], [1]]}, "knots must be sorted without duplicates."),
(
{"extrapolation": None},
"extrapolation must be one of 'error', 'constant', 'linear', "
"'continue' or 'periodic'.",
),
(
{"extrapolation": 1},
"extrapolation must be one of 'error', 'constant', 'linear', "
"'continue' or 'periodic'.",
),
(
{"extrapolation": "string"},
"extrapolation must be one of 'error', 'constant', 'linear', "
"'continue' or 'periodic'.",
),
({"include_bias": None}, "include_bias must be bool."),
({"include_bias": 1}, "include_bias must be bool."),
({"include_bias": "string"}, "include_bias must be bool."),
(
{"extrapolation": "periodic", "n_knots": 3, "degree": 3},
"Periodic splines require degree < n_knots. Got n_knots=3 and degree=3.",
),
(
{"extrapolation": "periodic", "knots": [[0], [1]], "degree": 2},
"Periodic splines require degree < n_knots. Got n_knots=2 and degree=2.",
),
],
)
def test_spline_transformer_input_validation(params, err_msg):
"""Test that we raise errors for invalid input in SplineTransformer."""
X = [[1], [2]]

with pytest.raises(ValueError, match=err_msg):
SplineTransformer(**params).fit(X)


def test_spline_transformer_manual_knot_input():
"""
Test that array-like knot positions in SplineTransformer are accepted.
"""
X = np.arange(20).reshape(10, 2)
knots = [[0.5, 1], [1.5, 2], [5, 10]]
st1 = SplineTransformer(degree=3, knots=knots, n_knots=None).fit(X)
knots = np.asarray(knots)
st2 = SplineTransformer(degree=3, knots=knots, n_knots=None).fit(X)
for i in range(X.shape[1]):
assert_allclose(st1.bsplines_[i].t, st2.bsplines_[i].t)


@pytest.mark.parametrize("extrapolation", ["continue", "periodic"])
def test_spline_transformer_integer_knots(extrapolation):
"""Test that SplineTransformer accepts integer value knot positions."""
Expand Down Expand Up @@ -238,31 +164,6 @@ def test_spline_transformer_get_base_knot_positions(
assert_allclose(base_knots, expected_knots)


@pytest.mark.parametrize(
"knots, n_knots, degree",
[
("uniform", 5, 3),
("uniform", 12, 8),
(
[[-1.0, 0.0], [0, 1.0], [0.1, 2.0], [0.2, 3.0], [0.3, 4.0], [1, 5.0]],
None,
3,
),
],
)
def test_spline_transformer_periodicity_of_extrapolation(knots, n_knots, degree):
"""Test that the SplineTransformer is periodic for multiple features."""
X_1 = np.linspace((-1, 0), (1, 5), 10)
X_2 = np.linspace((1, 5), (3, 10), 10)

splt = SplineTransformer(
knots=knots, n_knots=n_knots, degree=degree, extrapolation="periodic"
)
splt.fit(X_1)

assert_allclose(splt.transform(X_1), splt.transform(X_2))


@pytest.mark.parametrize(["bias", "intercept"], [(True, False), (False, True)])
def test_spline_transformer_periodic_linear_regression(bias, intercept):
"""Test that B-splines fit a periodic curve pretty well."""
Expand Down
1 change: 0 additions & 1 deletion sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"SpectralBiclustering",
"SpectralCoclustering",
"SpectralEmbedding",
"SplineTransformer",
"TransformedTargetRegressor",
]

Expand Down
0