1
1
import numpy as np
2
2
import pytest
3
- import scipy .sparse as sp
4
3
5
4
from sklearn .cluster import BisectingKMeans
6
5
from sklearn .metrics import v_measure_score
7
6
from sklearn .utils ._testing import assert_allclose , assert_array_equal
7
+ from sklearn .utils .fixes import CSR_CONTAINERS
8
8
9
9
10
10
@pytest .mark .parametrize ("bisecting_strategy" , ["biggest_inertia" , "largest_cluster" ])
@@ -33,7 +33,8 @@ def test_three_clusters(bisecting_strategy, init):
33
33
assert_allclose (v_measure_score (expected_labels , bisect_means .labels_ ), 1.0 )
34
34
35
35
36
- def test_sparse ():
36
+ @pytest .mark .parametrize ("csr_container" , CSR_CONTAINERS )
37
+ def test_sparse (csr_container ):
37
38
"""Test Bisecting K-Means with sparse data.
38
39
39
40
Checks if labels and centers are the same between dense and sparse.
@@ -43,7 +44,7 @@ def test_sparse():
43
44
44
45
X = rng .rand (20 , 2 )
45
46
X [X < 0.8 ] = 0
46
- X_csr = sp . csr_matrix (X )
47
+ X_csr = csr_container (X )
47
48
48
49
bisect_means = BisectingKMeans (n_clusters = 3 , random_state = 0 )
49
50
@@ -84,48 +85,48 @@ def test_one_cluster():
84
85
assert_allclose (bisect_means .cluster_centers_ , X .mean (axis = 0 ).reshape (1 , - 1 ))
85
86
86
87
87
- @pytest .mark .parametrize ("is_sparse " , [ True , False ])
88
- def test_fit_predict (is_sparse ):
88
+ @pytest .mark .parametrize ("csr_container " , CSR_CONTAINERS + [ None ])
89
+ def test_fit_predict (csr_container ):
89
90
"""Check if labels from fit(X) method are same as from fit(X).predict(X)."""
90
91
rng = np .random .RandomState (0 )
91
92
92
93
X = rng .rand (10 , 2 )
93
94
94
- if is_sparse :
95
+ if csr_container is not None :
95
96
X [X < 0.8 ] = 0
96
- X = sp . csr_matrix (X )
97
+ X = csr_container (X )
97
98
98
99
bisect_means = BisectingKMeans (n_clusters = 3 , random_state = 0 )
99
100
bisect_means .fit (X )
100
101
101
102
assert_array_equal (bisect_means .labels_ , bisect_means .predict (X ))
102
103
103
104
104
- @pytest .mark .parametrize ("is_sparse " , [ True , False ])
105
- def test_dtype_preserved (is_sparse , global_dtype ):
105
+ @pytest .mark .parametrize ("csr_container " , CSR_CONTAINERS + [ None ])
106
+ def test_dtype_preserved (csr_container , global_dtype ):
106
107
"""Check that centers dtype is the same as input data dtype."""
107
108
rng = np .random .RandomState (0 )
108
109
X = rng .rand (10 , 2 ).astype (global_dtype , copy = False )
109
110
110
- if is_sparse :
111
+ if csr_container is not None :
111
112
X [X < 0.8 ] = 0
112
- X = sp . csr_matrix (X )
113
+ X = csr_container (X )
113
114
114
115
km = BisectingKMeans (n_clusters = 3 , random_state = 0 )
115
116
km .fit (X )
116
117
117
118
assert km .cluster_centers_ .dtype == global_dtype
118
119
119
120
120
- @pytest .mark .parametrize ("is_sparse " , [ True , False ])
121
- def test_float32_float64_equivalence (is_sparse ):
121
+ @pytest .mark .parametrize ("csr_container " , CSR_CONTAINERS + [ None ])
122
+ def test_float32_float64_equivalence (csr_container ):
122
123
"""Check that the results are the same between float32 and float64."""
123
124
rng = np .random .RandomState (0 )
124
125
X = rng .rand (10 , 2 )
125
126
126
- if is_sparse :
127
+ if csr_container is not None :
127
128
X [X < 0.8 ] = 0
128
- X = sp . csr_matrix (X )
129
+ X = csr_container (X )
129
130
130
131
km64 = BisectingKMeans (n_clusters = 3 , random_state = 0 ).fit (X )
131
132
km32 = BisectingKMeans (n_clusters = 3 , random_state = 0 ).fit (X .astype (np .float32 ))
0 commit comments