FIX improve error message for large sparse matrix input in LogisticRegression (#21093) · scikit-learn/scikit-learn@762eca5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 762eca5

Browse files
authored
FIX improve error message for large sparse matrix input in LogisticRegression (#21093)
1 parent 2eabb45 commit 762eca5

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

doc/whats_new/v1.1.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@ Changelog
3838
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
3939
where 123456 is the *pull request* number, not the issue number.
4040
41+
:mod:`sklearn.linear_model`
42+
...........................
43+
44+
- |Fix| :class:`linear_model.LogisticRegression` now raises a better error
45+
message when the solver does not support sparse matrices with int64 indices.
46+
:pr:`21093` by `Tom Dupre la Tour`_.
47+
4148
:mod:`sklearn.utils`
4249
....................
4350

sklearn/linear_model/_logistic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,7 @@ def _logistic_regression_path(
663663
X,
664664
accept_sparse="csr",
665665
dtype=np.float64,
666-
accept_large_sparse=solver != "liblinear",
666+
accept_large_sparse=solver not in ["liblinear", "sag", "saga"],
667667
)
668668
y = check_array(y, ensure_2d=False, dtype=None)
669669
check_consistent_length(X, y)
@@ -1511,7 +1511,7 @@ def fit(self, X, y, sample_weight=None):
15111511
accept_sparse="csr",
15121512
dtype=_dtype,
15131513
order="C",
1514-
accept_large_sparse=solver != "liblinear",
1514+
accept_large_sparse=solver not in ["liblinear", "sag", "saga"],
15151515
)
15161516
check_classification_targets(y)
15171517
self.classes_ = np.unique(y)
@@ -2080,7 +2080,7 @@ def fit(self, X, y, sample_weight=None):
20802080
accept_sparse="csr",
20812081
dtype=np.float64,
20822082
order="C",
2083-
accept_large_sparse=solver != "liblinear",
2083+
accept_large_sparse=solver not in ["liblinear", "sag", "saga"],
20842084
)
20852085
check_classification_targets(y)
20862086

sklearn/linear_model/tests/test_logistic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2237,3 +2237,22 @@ def test_sample_weight_not_modified(multi_class, class_weight):
22372237
)
22382238
clf.fit(X, y, sample_weight=W)
22392239
assert_allclose(expected, W)
2240+
2241+
2242+
@pytest.mark.parametrize("solver", ["liblinear", "lbfgs", "newton-cg", "sag", "saga"])
2243+
def test_large_sparse_matrix(solver):
2244+
# Solvers either accept large sparse matrices, or raise helpful error.
2245+
# Non-regression test for pull-request #21093.
2246+
2247+
# generate sparse matrix with int64 indices
2248+
X = sp.rand(20, 10, format="csr")
2249+
for attr in ["indices", "indptr"]:
2250+
setattr(X, attr, getattr(X, attr).astype("int64"))
2251+
y = np.random.randint(2, size=X.shape[0])
2252+
2253+
if solver in ["liblinear", "sag", "saga"]:
2254+
msg = "Only sparse matrices with 32-bit integer indices"
2255+
with pytest.raises(ValueError, match=msg):
2256+
LogisticRegression(solver=solver).fit(X, y)
2257+
else:
2258+
LogisticRegression(solver=solver).fit(X, y)

0 commit comments

Comments
 (0)
0