8000 ENH: ElasticNetCV and LassoCV raise ValueError with multitarget outputs · r2k0/scikit-learn@af8904c · GitHub
[go: up one dir, main page]

8000
Skip to content

Commit af8904c

Browse files
MechCoderagramfort
authored andcommitted
ENH: ElasticNetCV and LassoCV raise ValueError with multitarget outputs
1 parent a8089b0 commit af8904c

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

sklearn/linear_model/coordinate_descent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,10 @@ def fit(self, X, y):
916916
copy_X = False
917917

918918
y = np.asarray(y, dtype=np.float64)
919+
920+
if y.ndim > 1:
921+
raise ValueError("For multi-task outputs, fit the linear model "
922+
"per output/task")
919923
if X.shape[0] != y.shape[0]:
920924
raise ValueError("X and y have inconsistent dimensions (%d != %d)"
921925
% (X.shape[0], y.shape[0]))

sklearn/linear_model/tests/test_coordinate_descent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.testing import SkipTest
1515
from sklearn.utils.testing import assert_true
1616
from sklearn.utils.testing import assert_greater
17+
from sklearn.utils.testing import assert_raises
1718

1819
from sklearn.linear_model.coordinate_descent import Lasso, \
1920
LassoCV, ElasticNet, ElasticNetCV, MultiTaskLasso, MultiTaskElasticNet, \
@@ -336,6 +337,13 @@ def test_enet_multitarget():
336337
assert_array_almost_equal(dual_gap[k], estimator.dual_gap_)
337338

338339

340+
def test_multioutput_enetcv_error():
341+
X = np.random.randn(10, 2)
342+
y = np.random.randn(10, 2)
343+
clf = ElasticNetCV()
344+
assert_raises(ValueError, clf.fit, X, y)
345+
346+
339347
if __name__ == '__main__':
340348
import nose
341349
nose.runmodule()

0 commit comments

Comments
 (0)
0