8000 Merge pull request #5362 from MechCoder/lasso_fix · scikit-learn/scikit-learn@da9a7cd · GitHub
[go: up one dir, main page]

Skip to cont 8000 ent

Commit da9a7cd

Browse files
committed
Merge pull request #5362 from MechCoder/lasso_fix
[MRG] Lasso and ElasticNet should handle non-integer dtypes for fit_intercept=False
2 parents 281aebd + 1d5b473 commit da9a7cd

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

sklearn/linear_model/coordinate_descent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ def fit(self, X, y, check_input=True):
652652
# We expect X and y to be already float64 Fortran ordered arrays
653653
# when bypassing checks
654654
if check_input:
655+
y = np.asarray(y, dtype=np.float64)
655656
X, y = check_X_y(X, y, accept_sparse='csc', dtype=np.float64,
656657
order='F',
657658
copy=self.copy_X and self.fit_intercept,

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -661,3 +661,16 @@ def test_overrided_gram_matrix():
661661
" to fit intercept, "
662662
"or X was normalized : recomputing Gram matrix.",
663663
clf.fit, X, y)
664+
665+
666+
def test_lasso_non_float_y():
667+
X = [[0, 0], [1, 1], [-1, -1]]
668+
y = [0, 1, 2]
669+
y_float = [0.0, 1.0, 2.0]
670+
671+
for model in [ElasticNet, Lasso]:
672+
clf = model(fit_intercept=False)
673+
clf.fit(X, y)
674+
clf_float = model(fit_intercept=False)
675+
clf_float.fit(X, y_float)
676+
assert_array_equal(clf.coef_, clf_float.coef_)

0 commit comments

Comments
 (0)
0