8000 TST Extend tests for `scipy.sparse.*array` in `sklearn/cluster/tests/… · REDVM/scikit-learn@283baa7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 283baa7

Browse files
Kislovskiyogrisel
authored andcommitted
TST Extend tests for scipy.sparse.*array in sklearn/cluster/tests/test_bisect_k_means.py (scikit-learn#27099)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 941a474 commit 283baa7

File tree

2 files changed

+18
-17
lines changed

2 files changed

+18
-17
lines changed

sklearn/cluster/tests/test_bisect_k_means.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import numpy as np
22
import pytest
3-
import scipy.sparse as sp
43

54
from sklearn.cluster import BisectingKMeans
65
from sklearn.metrics import v_measure_score
76
from sklearn.utils._testing import assert_allclose, assert_array_equal
7+
from sklearn.utils.fixes import CSR_CONTAINERS
88

99

1010
@pytest.mark.parametrize("bisecting_strategy", ["biggest_inertia", "largest_cluster"])
@@ -33,7 +33,8 @@ def test_three_clusters(bisecting_strategy, init):
3333
assert_allclose(v_measure_score(expected_labels, bisect_means.labels_), 1.0)
3434

3535

36-
def test_sparse():
36+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
37+
def test_sparse(csr_container):
3738
"""Test Bisecting K-Means with sparse data.
3839
3940
Checks if labels and centers are the same between dense and sparse.
@@ -43,7 +44,7 @@ def test_sparse():
4344

4445
X = rng.rand(20, 2)
4546
X[X < 0.8] = 0
46-
X_csr = sp.csr_matrix(X)
47+
X_csr = csr_container(X)
4748

4849
bisect_means = BisectingKMeans(n_clusters=3, random_state=0)
4950

@@ -84,48 +85,48 @@ def test_one_cluster():
8485
assert_allclose(bisect_means.cluster_centers_, X.mean(axis=0).reshape(1, -1))
8586

8687

87-
@pytest.mark.parametrize("is_sparse", [True, False])
88-
def test_fit_predict(is_sparse):
88+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS + [None])
89+
def test_fit_predict(csr_container):
8990
"""Check if labels from fit(X) method are same as from fit(X).predict(X)."""
9091
rng = np.random.RandomState(0)
9192

9293
X = rng.rand(10, 2)
9394

94-
if is_sparse:
95+
if csr_container is not None:
9596
X[X < 0.8] = 0
96-
X = sp.csr_matrix(X)
97+
X = csr_container(X)
9798

9899
bisect_means = BisectingKMeans(n_clusters=3, random_state=0)
99100
bisect_means.fit(X)
100101

101102
assert_array_equal(bisect_means.labels_, bisect_means.predict(X))
102103

103104

104-
@pytest.mark.parametrize("is_sparse", [True, False])
105-
def test_dtype_preserved(is_sparse, global_dtype):
105+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS + [None])
106+
def test_dtype_preserved(csr_container, global_dtype):
106107
"""Check that centers dtype is the same as input data dtype."""
107108
rng = np.random.RandomState(0)
108109
X = rng.rand(10, 2).astype(global_dtype, copy=False)
109110

110-
if is_sparse:
111+
if csr_container is not None:
111112
X[X < 0.8] = 0
112-
X = sp.csr_matrix(X)
113+
X = csr_container(X)
113114

114115
km = BisectingKMeans(n_clusters=3, random_state=0)
115116
km.fit(X)
116117

117118
assert km.cluster_centers_.dtype == global_dtype
118119

119120

120-
@pytest.mark.parametrize("is_sparse", [True, False])
121-
def test_float32_float64_equivalence(is_sparse):
121+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS + [None])
122+
def test_float32_float64_equivalence(csr_container):
122123
"""Check that the results are the same between float32 and float64."""
123124
rng = np.random.RandomState(0)
124125
X = rng.rand(10, 2)
125126

126-
if is_sparse:
127+
if csr_container is not None:
127128
X[X < 0.8] = 0
128-
X = sp.csr_matrix(X)
129+
X = csr_container(X)
129130

130131
km64 = BisectingKMeans(n_clusters=3, random_state=0).fit(X)
131132
km32 = BisectingKMeans(n_clusters=3, random_state=0).fit(X.astype(np.float32))

sklearn/utils/validation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,9 +1002,9 @@ def _check_large_sparse(X, accept_large_sparse=False):
10021002
"""Raise a ValueError if X has 64bit indices and accept_large_sparse=False"""
10031003
if not accept_large_sparse:
10041004
supported_indices = ["int32"]
1005-
if X.getformat() == "coo":
1005+
if X.format == "coo":
10061006
index_keys = ["col", "row"]
1007-
elif X.getformat() in ["csr", "csc", "bsr"]:
1007+
elif X.format in ["csr", "csc", "bsr"]:
10081008
index_keys = ["indices", "indptr"]
10091009
else:
10101010< 39A6 /code>
return

0 commit comments

Comments
 (0)
0