TST Extend tests for `scipy.sparse.*array` in `sklearn/datasets/tests/test_svmlight_format.py` · REDVM/scikit-learn@b14fa86 · GitHub
[go: up one dir, main page]

Skip to content

Commit b14fa86

Browse files
Tialo, REDVM
authored and committed
TST Extend tests for scipy.sparse.*array in sklearn/datasets/tests/test_svmlight_format.py (scikit-learn#27220)
1 parent 4f02b60 commit b14fa86

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

sklearn/datasets/tests/test_svmlight_format.py

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
assert_array_equal,
1818
fails_if_pypy,
1919
)
20-
from sklearn.utils.fixes import _open_binary, _path
20+
from sklearn.utils.fixes import CSR_CONTAINERS, _open_binary, _path
2121

2222
TEST_DATA_MODULE = "sklearn.datasets.tests.data"
2323
datafile = "svmlight_classification.txt"
@@ -254,10 +254,11 @@ def test_invalid_filename():
254254
load_svmlight_file("trou pic nic douille")
255255

256256

257-
def test_dump():
257+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
258+
def test_dump(csr_container):
258259
X_sparse, y_dense = _load_svmlight_local_test_file(datafile)
259260
X_dense = X_sparse.toarray()
260-
y_sparse = sp.csr_matrix(y_dense)
261+
y_sparse = csr_container(y_dense)
261262

262263
# slicing a csr_matrix can unsort its .indices, so test that we sort
263264
# those correctly
@@ -323,10 +324,11 @@ def test_dump():
323324
)
324325

325326

326-
def test_dump_multilabel():
327+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
328+
def test_dump_multilabel(csr_container):
327329
X = [[1, 0, 3, 0, 5], [0, 0, 0, 0, 0], [0, 5, 0, 1, 0]]
328330
y_dense = [[0, 1, 0], [1, 0, 1], [1, 1, 0]]
329-
y_sparse = sp.csr_matrix(y_dense)
331+
y_sparse = csr_container(y_dense)
330332
for y in [y_dense, y_sparse]:
331333
f = BytesIO()
332334
dump_svmlight_file(X, y, f, multilabel=True)
@@ -465,9 +467,10 @@ def test_load_with_long_qid():
465467
assert_array_equal(X.toarray(), true_X)
466468

467469

468-
def test_load_zeros():
470+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
471+
def test_load_zeros(csr_container):
469472
f = BytesIO()
470-
true_X = sp.csr_matrix(np.zeros(shape=(3, 4)))
473+
true_X = csr_container(np.zeros(shape=(3, 4)))
471474
true_y = np.array([0, 1, 0])
472475
dump_svmlight_file(true_X, true_y, f)
473476

@@ -481,12 +484,13 @@ def test_load_zeros():
481484
@pytest.mark.parametrize("sparsity", [0, 0.1, 0.5, 0.99, 1])
482485
@pytest.mark.parametrize("n_samples", [13, 101])
483486
@pytest.mark.parametrize("n_features", [2, 7, 41])
484-
def test_load_with_offsets(sparsity, n_samples, n_features):
487+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
488+
def test_load_with_offsets(sparsity, n_samples, n_features, csr_container):
485489
rng = np.random.RandomState(0)
486490
X = rng.uniform(low=0.0, high=1.0, size=(n_samples, n_features))
487491
if sparsity:
488492
X[X < sparsity] = 0.0
489-
X = sp.csr_matrix(X)
493+
X = csr_container(X)
490494
y = rng.randint(low=0, high=2, size=n_samples)
491495

492496
f = BytesIO()
@@ -517,7 +521,8 @@ def test_load_with_offsets(sparsity, n_samples, n_features):
517521
assert_array_almost_equal(X.toarray(), X_concat.toarray())
518522

519523

520-
def test_load_offset_exhaustive_splits():
524+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
525+
def test_load_offset_exhaustive_splits(csr_container):
521526
rng = np.random.RandomState(0)
522527
X = np.array(
523528
[
@@ -530,7 +535,7 @@ def test_load_offset_exhaustive_splits():
530535
[1, 0, 0, 0, 0, 0],
531536
]
532537
)
533-
X = sp.csr_matrix(X)
538+
X = csr_container(X)
534539
n_samples, n_features = X.shape
535540
y = rng.randint(low=0, high=2, size=n_samples)
536541
query_id = np.arange(n_samples) // 2
@@ -564,7 +569,8 @@ def test_load_with_offsets_error():
564569
_load_svmlight_local_test_file(datafile, offset=3, length=3)
565570

566571

567-
def test_multilabel_y_explicit_zeros(tmp_path):
572+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
573+
def test_multilabel_y_explicit_zeros(tmp_path, csr_container):
568574
"""
569575
Ensure that if y contains explicit zeros (i.e. elements of y.data equal to
570576
0) then those explicit zeros are not encoded.
@@ -576,7 +582,7 @@ def test_multilabel_y_explicit_zeros(tmp_path):
576582
indices = np.array([0, 2, 2, 0, 1, 2])
577583
# The first and last element are explicit zeros.
578584
data = np.array([0, 1, 1, 1, 1, 0])
579-
y = sp.csr_matrix((data, indices, indptr), shape=(3, 3))
585+
y = csr_container((data, indices, indptr), shape=(3, 3))
580586
# y as a dense array would look like
581587
# [[0, 0, 1],
582588
# [0, 0, 1],

0 commit comments

Comments
 (0)
0