MAINT add parameters validation for SplineTransformer by kasmith11 · Pull Request #24057 · scikit-learn/scikit-learn · GitHub

MAINT add parameters validation for SplineTransformer #24057

Merged
5 commits merged into scikit-learn:main on Aug 26, 2022

Conversation

kasmith11 (Contributor)

Reference Issues/PRs

References #23462 as well as #22722.

What does this implement/fix? Explain your changes.

Adds _parameter_constraints to SplineTransformer in sklearn/preprocessing/_polynomial.py and removes SplineTransformer from PARAM_VALIDATION_ESTIMATORS_TO_IGNORE in sklearn/tests/test_common.py.
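
For illustration, here is a minimal sketch of what such a constraints declaration can look like. It is reconstructed from the hand-written checks this PR removes (see the review diff below), not copied from the merged code, and it relies on the private sklearn.utils._param_validation helpers Interval and StrOptions; the "order" entry is an assumption, since that check is not part of the diff below.

from numbers import Integral

from sklearn.utils._param_validation import Interval, StrOptions

# Sketch of a class-level constraints dict: each constructor parameter maps
# to the list of constraints it must satisfy.
_parameter_constraints = {
    "n_knots": [Interval(Integral, 2, None, closed="left")],  # integer >= 2
    "degree": [Interval(Integral, 0, None, closed="left")],   # non-negative integer
    "knots": [StrOptions({"uniform", "quantile"}), "array-like"],
    "extrapolation": [
        StrOptions({"error", "constant", "linear", "continue", "periodic"})
    ],
    "include_bias": ["boolean"],
    "order": [StrOptions({"C", "F"})],  # assumed here; not shown in the review diff
}

With a dict like this in place and SplineTransformer removed from the ignore list, the common parameter-validation test in sklearn/tests/test_common.py exercises it automatically.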

Any other comments?

#23462 mentions that spotting and removing existing validation/tests is easier with codecov since it becomes unreachable code. Can this be done locally, or do I need to submit a pull request?

@jeremiedbb jeremiedbb added the No Changelog Needed and Validation (related to input validation) labels Jul 30, 2022
@glemaitre glemaitre changed the title towards #23462 - SplineTransformer MAINT add parameters validation for SplineTransformer Aug 24, 2022
@glemaitre glemaitre self-requested a review August 24, 2022 09:11
Member
@glemaitre glemaitre left a comment

Otherwise, the rest of the PR seems fine.

@@ -767,6 +780,8 @@ def fit(self, X, y=None, sample_weight=None):
self : object
Fitted transformer.
"""
self._validate_params()
Member

Now some parameters will be validated automatically, so you need to remove some extra, now-unnecessary checks. I checked, and you should have the following diff in fit:

diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py
index 585c21eb2d..0b651e1f2f 100644
--- a/sklearn/preprocessing/_polynomial.py
+++ b/sklearn/preprocessing/_polynomial.py
@@ -794,20 +794,8 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
 
         _, n_features = X.shape
 
-        if not (isinstance(self.degree, numbers.Integral) and self.degree >= 0):
-            raise ValueError(
-                f"degree must be a non-negative integer, got {self.degree}."
-            )
-
-        if isinstance(self.knots, str) and self.knots in [
-            "uniform",
-            "quantile",
-        ]:
-            if not (isinstance(self.n_knots, numbers.Integral) and self.n_knots >= 2):
-                raise ValueError(
-                    f"n_knots must be a positive integer >= 2, got: {self.n_knots}"
-                )
 
+        if isinstance(self.knots, str):
             base_knots = self._get_base_knot_positions(
                 X, n_knots=self.n_knots, knots=self.knots, sample_weight=sample_weight
             )
@@ -820,20 +808,6 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
             elif not np.all(np.diff(base_knots, axis=0) > 0):
                 raise ValueError("knots must be sorted without duplicates.")
 
-        if self.extrapolation not in (
-            "error",
-            "constant",
-            "linear",
-            "continue",
-            "periodic",
-        ):
-            raise ValueError(
-                "extrapolation must be one of 'error', "
-                "'constant', 'linear', 'continue' or 'periodic'."
-            )
-
-        if not isinstance(self.include_bias, (bool, np.bool_)):
-            raise ValueError("include_bias must be bool.")
 
         # number of knots for base interval
         n_knots = base_knots.shape[0]
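
As a usage-level sanity check (a sketch, not taken from the PR itself): invalid constructor arguments are now rejected by _validate_params() at the start of fit, which is why the hand-written checks above become unreachable.

import numpy as np
from sklearn.preprocessing import SplineTransformer

X = np.arange(6, dtype=float).reshape(-1, 1)

try:
    # degree must be a non-negative integer, so the common validation rejects
    # this before any of the hand-written checks removed above would run.
    SplineTransformer(degree=-1).fit(X)
except ValueError as exc:
    print(exc)  # message names the 'degree' parameter and its valid range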

Member
@glemaitre glemaitre left a comment

Thanks LGTM.

Member
@thomasjpfan thomasjpfan left a comment

LGTM

@thomasjpfan thomasjpfan merged commit c01fad4 into scikit-learn:main Aug 26, 2022
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Sep 12, 2022