From 391742c23e50e64fef6d6ce8e5aaba8c8142874c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Dupr=C3=A9=20la=20Tour?= Date: Mon, 20 Sep 2021 12:23:26 -0700 Subject: [PATCH 1/2] FIX improve error message for large sparse matrix input in LogisticRegression --- sklearn/linear_model/_logistic.py | 6 +++--- sklearn/linear_model/tests/test_logistic.py | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 012d14102cae3..08e71edbc69ab 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -663,7 +663,7 @@ def _logistic_regression_path( X, accept_sparse="csr", dtype=np.float64, - accept_large_sparse=solver != "liblinear", + accept_large_sparse=solver not in ["liblinear", "sag", "saga"], ) y = check_array(y, ensure_2d=False, dtype=None) check_consistent_length(X, y) @@ -1511,7 +1511,7 @@ def fit(self, X, y, sample_weight=None): accept_sparse="csr", dtype=_dtype, order="C", - accept_large_sparse=solver != "liblinear", + accept_large_sparse=solver not in ["liblinear", "sag", "saga"], ) check_classification_targets(y) self.classes_ = np.unique(y) @@ -2080,7 +2080,7 @@ def fit(self, X, y, sample_weight=None): accept_sparse="csr", dtype=np.float64, order="C", - accept_large_sparse=solver != "liblinear", + accept_large_sparse=solver not in ["liblinear", "sag", "saga"], ) check_classification_targets(y) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index f900994081b47..7b03a061fbeac 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -2237,3 +2237,21 @@ def test_sample_weight_not_modified(multi_class, class_weight): ) clf.fit(X, y, sample_weight=W) assert_allclose(expected, W) + + +@pytest.mark.parametrize("solver", ["liblinear", "lbfgs", "newton-cg", "sag", "saga"]) +def test_large_sparse_matrix(solver): + # Solvers either 
accept large sparse matrices, or raise helpful error. + + # generate sparse matrix with int64 indices + X = sp.rand(20, 10, format="csr") + for attr in ["indices", "indptr"]: + setattr(X, attr, getattr(X, attr).astype("int64")) + y = np.random.randint(2, size=X.shape[0]) + + if solver in ["liblinear", "sag", "saga"]: + msg = "Only sparse matrices with 32-bit integer indices" + with pytest.raises(ValueError, match=msg): + LogisticRegression(solver=solver).fit(X, y) + else: + LogisticRegression(solver=solver).fit(X, y) From 8463ab88920a1eab17631b54ab93ee43098b3fba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Dupr=C3=A9=20la=20Tour?= Date: Mon, 20 Sep 2021 15:03:43 -0700 Subject: [PATCH 2/2] add whatsnew entry, add pull-request reference --- doc/whats_new/v1.1.rst | 7 +++++++ sklearn/linear_model/tests/test_logistic.py | 1 + 2 files changed, 8 insertions(+) diff --git a/doc/whats_new/v1.1.rst b/doc/whats_new/v1.1.rst index 9c1084e393e8d..3aabed6214771 100644 --- a/doc/whats_new/v1.1.rst +++ b/doc/whats_new/v1.1.rst @@ -38,6 +38,13 @@ Changelog :pr:`123456` by :user:`Joe Bloggs <joeblogs>`. where 123456 is the *pull request* number, not the issue number. +:mod:`sklearn.linear_model` +........................... + +- |Fix| :class:`linear_model.LogisticRegression` now raises a better error + message when the solver does not support sparse matrices with int64 indices. + :pr:`21093` by `Tom Dupre la Tour`_. + :mod:`sklearn.utils` .................... diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 7b03a061fbeac..1171613eb3718 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -2242,6 +2242,7 @@ def test_sample_weight_not_modified(multi_class, class_weight): @pytest.mark.parametrize("solver", ["liblinear", "lbfgs", "newton-cg", "sag", "saga"]) def test_large_sparse_matrix(solver): # Solvers either accept large sparse matrices, or raise helpful error. 
+ # Non-regression test for pull-request #21093. # generate sparse matrix with int64 indices X = sp.rand(20, 10, format="csr")