8000 FIX bug in SplineTransformer.n_features_out_ (#19577) · scikit-learn/scikit-learn@f0a6f05 · GitHub
[go: up one dir, main page]

Skip to content

Commit f0a6f05

Browse files
authored
FIX bug in SplineTransformer.n_features_out_ (#19577)
1 parent e0f0c7f commit f0a6f05

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

sklearn/preprocessing/_polynomial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def fit(self, X, y=None):
307307
]
308308
self.bsplines_ = bsplines
309309

310-
self.n_features_out_ = n_out - n_features * self.include_bias
310+
self.n_features_out_ = n_out - n_features * (1 - self.include_bias)
311311
return self
312312

313313
def transform(self, X):
@@ -336,7 +336,7 @@ def transform(self, X):
336336

337337
# Note that scipy BSpline returns float64 arrays and converts input
338338
# x=X[:, i] to c-contiguous float64.
339-
n_out = self.n_features_out_ + n_features * self.include_bias
339+
n_out = self.n_features_out_ + n_features * (1 - self.include_bias)
340340
if X.dtype in FLOAT_DTYPES:
341341
dtype = X.dtype
342342
else:

sklearn/preprocessing/tests/test_polynomial.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,19 @@ def test_spline_transformer_kbindiscretizer():
243243
# Though they should be exactly equal, we test approximately with high
244244
# accuracy.
245245
assert_allclose(splines, kbins, rtol=1e-13)
246+
247+
248+
@pytest.mark.parametrize("n_knots", [5, 10])
249+
@pytest.mark.parametrize("include_bias", [True, False])
250+
@pytest.mark.parametrize("degree", [3, 5])
251+
def test_spline_transformer_n_features_out(n_knots, include_bias, degree):
252+
"""Test that transform results in n_features_out_ features."""
253+
splt = SplineTransformer(
254+
n_knots=n_knots,
255+
degree=degree,
256+
include_bias=include_bias
257+
)
258+
X = np.linspace(0, 1, 10)[:, None]
259+
splt.fit(X)
260+
261+
assert splt.transform(X).shape[1] == splt.n_features_out_

0 commit comments

Comments
 (0)
0