17
17
assert_array_equal ,
18
18
fails_if_pypy ,
19
19
)
20
- from sklearn .utils .fixes import _open_binary , _path
20
+ from sklearn .utils .fixes import CSR_CONTAINERS , _open_binary , _path
21
21
22
22
TEST_DATA_MODULE = "sklearn.datasets.tests.data"
23
23
datafile = "svmlight_classification.txt"
@@ -254,10 +254,11 @@ def test_invalid_filename():
254
254
load_svmlight_file ("trou pic nic douille" )
255
255
256
256
257
- def test_dump ():
257
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
258
+ def test_dump (csr_container ):
258
259
X_sparse , y_dense = _load_svmlight_local_test_file (datafile )
259
260
X_dense = X_sparse .toarray ()
260
- y_sparse = sp . csr_matrix (y_dense )
261
+ y_sparse = csr_container (y_dense )
261
262
262
263
# slicing a csr_matrix can unsort its .indices, so test that we sort
263
264
# those correctly
@@ -323,10 +324,11 @@ def test_dump():
323
324
)
324
325
325
326
326
- def test_dump_multilabel ():
327
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
328
+ def test_dump_multilabel (csr_container ):
327
329
X = [[1 , 0 , 3 , 0 , 5 ], [0 , 0 , 0 , 0 , 0 ], [0 , 5 , 0 , 1 , 0 ]]
328
330
y_dense = [[0 , 1 , 0 ], [1 , 0 , 1 ], [1 , 1 , 0 ]]
329
- y_sparse = sp . csr_matrix (y_dense )
331
+ y_sparse = csr_container (y_dense )
330
332
for y in [y_dense , y_sparse ]:
331
333
f = BytesIO ()
332
334
dump_svmlight_file (X , y , f , multilabel = True )
@@ -465,9 +467,10 @@ def test_load_with_long_qid():
465
467
assert_array_equal (X .toarray (), true_X )
466
468
467
469
468
- def test_load_zeros ():
470
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
471
+ def test_load_zeros (csr_container ):
469
472
f = BytesIO ()
470
- true_X = sp . csr_matrix (np .zeros (shape = (3 , 4 )))
473
+ true_X = csr_container (np .zeros (shape = (3 , 4 )))
471
474
true_y = np .array ([0 , 1 , 0 ])
472
475
dump_svmlight_file (true_X , true_y , f )
473
476
@@ -481,12 +484,13 @@ def test_load_zeros():
481
484
@pytest .mark .parametrize ("sparsity" , [0 , 0.1 , 0.5 , 0.99 , 1 ])
482
485
@pytest .mark .parametrize ("n_samples" , [13 , 101 ])
483
486
@pytest .mark .parametrize ("n_features" , [2 , 7 , 41 ])
484
- def test_load_with_offsets (sparsity , n_samples , n_features ):
487
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
488
+ def test_load_with_offsets (sparsity , n_samples , n_features , csr_container ):
485
489
rng = np .random .RandomState (0 )
486
490
X = rng .uniform (low = 0.0 , high = 1.0 , size = (n_samples , n_features ))
487
491
if sparsity :
488
492
X [X < sparsity ] = 0.0
489
- X = sp . csr_matrix (X )
493
+ X = csr_container (X )
490
494
y = rng .randint (low = 0 , high = 2 , size = n_samples )
491
495
492
496
f = BytesIO ()
@@ -517,7 +521,8 @@ def test_load_with_offsets(sparsity, n_samples, n_features):
517
521
assert_array_almost_equal (X .toarray (), X_concat .toarray ())
518
522
519
523
520
- def test_load_offset_exhaustive_splits ():
524
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
525
+ def test_load_offset_exhaustive_splits (csr_container ):
521
526
rng = np .random .RandomState (0 )
522
527
X = np .array (
523
528
[
@@ -530,7 +535,7 @@ def test_load_offset_exhaustive_splits():
530
535
[1 , 0 , 0 , 0 , 0 , 0 ],
531
536
]
532
537
)
533
- X = sp . csr_matrix (X )
538
+ X = csr_container (X )
534
539
n_samples , n_features = X .shape
535
540
y = rng .randint (low = 0 , high = 2 , size = n_samples )
536
541
query_id = np .arange (n_samples ) // 2
@@ -564,7 +569,8 @@ def test_load_with_offsets_error():
564
569
_load_svmlight_local_test_file (datafile , offset = 3 , length = 3 )
565
570
566
571
567
- def test_multilabel_y_explicit_zeros (tmp_path ):
572
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
573
+ def test_multilabel_y_explicit_zeros (tmp_path , csr_container ):
568
574
"""
569
575
Ensure that if y contains explicit zeros (i.e. elements of y.data equal to
570
576
0) then those explicit zeros are not encoded.
@@ -576,7 +582,7 @@ def test_multilabel_y_explicit_zeros(tmp_path):
576
582
indices = np .array ([0 , 2 , 2 , 0 , 1 , 2 ])
577
583
# The first and last element are explicit zeros.
578
584
data = np .array ([0 , 1 , 1 , 1 , 1 , 0 ])
579
- y = sp . csr_matrix ((data , indices , indptr ), shape = (3 , 3 ))
585
+ y = csr_container ((data , indices , indptr ), shape = (3 , 3 ))
580
586
# y as a dense array would look like
581
587
# [[0, 0, 1],
582
588
# [0, 0, 1],
0 commit comments