MAINT add parameters validation for SplineTransformer #24057
Conversation
Otherwise, the rest of the PR seems fine.
@@ -767,6 +780,8 @@ def fit(self, X, y=None, sample_weight=None):
         self : object
             Fitted transformer.
         """
+        self._validate_params()
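For context, `self._validate_params()` checks each public parameter against the class-level `_parameter_constraints` declaration before `fit` does any work. A minimal self-contained sketch of that pattern (illustrative only, not scikit-learn's actual implementation; the name `check_params` is hypothetical):

```python
import numbers


def check_params(params, constraints):
    """Check each parameter against its list of allowed specs.

    A spec may be a type, a set of allowed values, or a callable
    predicate; ValueError is raised if no spec accepts the value.
    """
    for name, specs in constraints.items():
        value = params[name]
        satisfied = False
        for spec in specs:
            if isinstance(spec, type):
                satisfied = isinstance(value, spec)
            elif isinstance(spec, set):
                satisfied = value in spec
            else:  # callable predicate
                satisfied = spec(value)
            if satisfied:
                break
        if not satisfied:
            raise ValueError(f"Invalid value for {name!r}: {value!r}")


# Declarative constraints mirroring the manual checks removed below.
constraints = {
    "degree": [lambda v: isinstance(v, numbers.Integral) and v >= 0],
    "extrapolation": [{"error", "constant", "linear", "continue", "periodic"}],
    "include_bias": [bool],
}

# Valid parameters pass silently; invalid ones raise ValueError.
check_params(
    {"degree": 3, "extrapolation": "linear", "include_bias": True}, constraints
)
```

The point of the declarative form is that per-estimator `fit` methods no longer need hand-written `if`/`raise` blocks for each parameter.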
Now that some parameters are validated here, you need to remove the extra, now-unnecessary checks. I checked, and you should have the following diff in fit:
diff --git a/sklearn/preprocessing/_polynomial.py b/sklearn/preprocessing/_polynomial.py
index 585c21eb2d..0b651e1f2f 100644
--- a/sklearn/preprocessing/_polynomial.py
+++ b/sklearn/preprocessing/_polynomial.py
@@ -794,20 +794,8 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
_, n_features = X.shape
- if not (isinstance(self.degree, numbers.Integral) and self.degree >= 0):
- raise ValueError(
- f"degree must be a non-negative integer, got {self.degree}."
- )
-
- if isinstance(self.knots, str) and self.knots in [
- "uniform",
- "quantile",
- ]:
- if not (isinstance(self.n_knots, numbers.Integral) and self.n_knots >= 2):
- raise ValueError(
- f"n_knots must be a positive integer >= 2, got: {self.n_knots}"
- )
+ if isinstance(self.knots, str):
base_knots = self._get_base_knot_positions(
X, n_knots=self.n_knots, knots=self.knots, sample_weight=sample_weight
)
@@ -820,20 +808,6 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
elif not np.all(np.diff(base_knots, axis=0) > 0):
raise ValueError("knots must be sorted without duplicates.")
- if self.extrapolation not in (
- "error",
- "constant",
- "linear",
- "continue",
- "periodic",
- ):
- raise ValueError(
- "extrapolation must be one of 'error', "
- "'constant', 'linear', 'continue' or 'periodic'."
- )
-
- if not isinstance(self.include_bias, (bool, np.bool_)):
- raise ValueError("include_bias must be bool.")
# number of knots for base interval
n_knots = base_knots.shape[0]
Thanks, LGTM.
LGTM
Reference Issues/PRs
References #23462 as well as #22722.
What does this implement/fix? Explain your changes.
Adds _parameter_constraints to SplineTransformer in sklearn/preprocessing/_polynomial.py and removes SplineTransformer from PARAM_VALIDATION_ESTIMATORS_TO_IGNORE in sklearn/tests/test_common.py.
Any other comments?
#23462 mentions that spotting and removing existing validation code and tests is easier with codecov, since the old checks become unreachable code. Can this be done locally, or do I need to submit a pull request first?