8000 TST raise ValueError when sparse enet with sample weights · scikit-learn/scikit-learn@aa180c0 · GitHub
[go: up one dir, main page]

Skip to content

Commit aa180c0

Browse files
author
Christian Lorentzen
committed
TST raise ValueError when sparse enet with sample weights
1 parent 7a35928 commit aa180c0

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def test_lasso_cv():
174174
def test_lasso_cv_with_some_model_selection():
175175
from sklearn.pipeline import make_pipeline
176176
from sklearn.preprocessing import StandardScaler
177-
from sklearn.model_selection import StratifiedKFold
177+
from sklearn.model_selection import ShuffleSplit
178178
from sklearn import datasets
179179
from sklearn.linear_model import LassoCV
180180

@@ -184,7 +184,7 @@ def test_lasso_cv_with_some_model_selection():
184184

185185
pipe = make_pipeline(
186186
StandardScaler(),
187-
LassoCV(cv=StratifiedKFold())
187+
LassoCV(cv=ShuffleSplit(random_state=0))
188188
)
189189
pipe.fit(X, y)
190190

@@ -973,3 +973,13 @@ def test_enet_sample_weight_consistency(fit_intercept, alpha, normalize,
973973
X2, y2, sample_weight=None
974974
)
975975
assert_allclose(reg1.coef_, reg2.coef_)
976+
977+
978+
def test_enet_sample_weight_sparse():
979+
reg = ElasticNet()
980+
X = sparse.csc_matrix(np.zeros((3, 2)))
981+
y = np.array([-1, 0, 1])
982+
sw = np.array([1, 2, 3])
983+
with pytest.raises(ValueError, match="Sample weights do not.*support "
984+
"sparse matrices"):
985+
reg.fit(X, y, sample_weight=sw, check_input=True)

0 commit comments

Comments
 (0)
0