8000 FIX Fixes TheilSenRegressor with fit_intercept=False and one feature … · thomasjpfan/scikit-learn@06c710a · GitHub
[go: up one dir, main page]

Skip to content

Commit 06c710a

Browse files
authored
FIX Fixes TheilSenRegressor with fit_intercept=False and one feature (scikit-learn#18121)
1 parent 7dcb1ac commit 06c710a

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

doc/whats_new/v0.24.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,9 @@ Changelog
346346
efficient leave-one-out cross-validation scheme ``cv=None``. :pr:`6624` by
347347
:user:`Marijn van Vliet <wmvanvliet>`.
348348

349+
- |Fix| Fixes bug in :class:`linear_model.TheilSenRegressor` where
350+
`predict` and `score` would fail when `fit_intercept=False` and there was
351+
one feature during fitting. :pr:`18121` by `Thomas Fan`_.
349352

350353
:mod:`sklearn.manifold`
351354
.......................

sklearn/linear_model/_theil_sen.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _spatial_median(X, max_iter=300, tol=1.e-3):
110110
http://users.jyu.fi/~samiayr/pdf/ayramo_eurogen05.pdf
111111
"""
112112
if X.shape[1] == 1:
113-
return 1, np.median(X.ravel())
113+
return 1, np.median(X.ravel(), keepdims=True)
114114

115115
tol **= 2 # We are computing the tol on the squared norm
116116
spatial_median_old = np.mean(X, axis=0)
@@ -125,7 +125,6 @@ def _spatial_median(X, max_iter=300, tol=1.e-3):
125125
warnings.warn("Maximum number of iterations {max_iter} reached in "
126126
"spatial median for TheilSen regressor."
127127
"".format(max_iter=max_iter), ConvergenceWarning)
128-
129128
return n_iter, spatial_median
130129

131130

sklearn/linear_model/tests/test_theil_sen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ def test_theil_sen_1d_no_intercept():
179179
assert_array_almost_equal(theil_sen.coef_, w + c, 1)
180180
assert_almost_equal(theil_sen.intercept_, 0.)
181181

182+
# non-regression test for #18104
183+
theil_sen.score(X, y)
184+
182185

183186
def test_theil_sen_2d():
184187
X, y, w, c = gen_toy_problem_2d()

0 commit comments

Comments
 (0)
0