8000 Allow `n_knots=None` if knots are explicitly specified in `SplineTran… · scikit-learn/scikit-learn@7781256 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7781256

Browse files
mlondschienogrisel
andauthored
Allow n_knots=None if knots are explicitly specified in SplineTransformer (#20191)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 800aee6 commit 7781256

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

sklearn/preprocessing/_polynomial.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,8 @@ class SplineTransformer(TransformerMixin, BaseEstimator):
350350
----------
351351
n_knots : int, default=5
352352
Number of knots of the splines if `knots` equals one of
353-
{'uniform', 'quantile'}. Must be larger or equal 2.
353+
{'uniform', 'quantile'}. Must be larger or equal 2. Ignored if `knots`
354+
is array-like.
354355
355356
degree : int, default=3
356357
The polynomial degree of the spline basis. Must be a non-negative
@@ -546,15 +547,17 @@ def fit(self, X, y=None):
546547
):
547548
raise ValueError("degree must be a non-negative integer.")
548549

549-
if not (
550-
isinstance(self.n_knots, numbers.Integral) and self.n_knots >= 2
551-
):
552-
raise ValueError("n_knots must be a positive integer >= 2.")
553-
554550
if isinstance(self.knots, str) and self.knots in [
555551
"uniform",
556552
"quantile",
557553
]:
554+
if not (
555+
isinstance(self.n_knots, numbers.Integral)
556+
and self.n_knots >= 2
557+
):
558+
raise ValueError("n_knots must be a positive integer >= 2, "
559+
f"got: {self.n_knots}")
560+
558561
base_knots = self._get_base_knot_positions(
559562
X, n_knots=self.n_knots, knots=self.knots
560563
)

sklearn/preprocessing/tests/test_polynomial.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,9 @@ def test_spline_transformer_manual_knot_input():
9696
"""
9797
X = np.arange(20).reshape(10, 2)
9898
knots = [[0.5, 1], [1.5, 2], [5, 10]]
99-
st1 = SplineTransformer(degree=3, knots=knots).fit(X)
99+
st1 = SplineTransformer(degree=3, knots=knots, n_knots=None).fit(X)
100100
knots = np.asarray(knots)
101-
st2 = SplineTransformer(degree=3, knots=knots).fit(X)
101+
st2 = SplineTransformer(degree=3, knots=knots, n_knots=None).fit(X)
102102
for i in range(X.shape[1]):
103103
assert_allclose(st1.bsplines_[i].t, st2.bsplines_[i].t)
104104

@@ -216,7 +216,7 @@ def test_spline_transformer_linear_regression(bias, intercept):
216216
("uniform", 12, 8),
217217
(
218218
[[-1.0, 0.0], [0, 1.0], [0.1, 2.0], [0.2, 3.0], [0.3, 4.0], [1, 5.0]],
219-
100, # this gets ignored.
219+
None,
220220
3
221221
)
222222
])

0 commit comments

Comments
 (0)
0