32
32
random_state = 11 ,
33
33
)
34
34
35
- DTYPES = (np .float64 , np .float32 )
36
-
37
35
38
36
def test_estimate_bandwidth ():
39
37
# Test estimate_bandwidth
40
38
bandwidth = estimate_bandwidth (X , n_samples = 200 )
41
39
assert 0.9 <= bandwidth <= 1.5
42
40
43
41
44
- @pytest .mark .parametrize ("dtype" , DTYPES )
45
- def test_estimate_bandwidth_1sample (dtype ):
42
+ def test_estimate_bandwidth_1sample (global_dtype ):
46
43
# Test estimate_bandwidth when n_samples=1 and quantile<1, so that
47
44
# n_neighbors is set to 1.
48
- bandwidth = estimate_bandwidth (X .astype (dtype ), n_samples = 1 , quantile = 0.3 )
45
+ bandwidth = estimate_bandwidth (X .astype (global_dtype ), n_samples = 1 , quantile = 0.3 )
49
46
assert bandwidth == pytest .approx (0.0 , abs = 1e-5 )
50
47
51
48
52
49
@pytest .mark .parametrize (
53
50
"bandwidth, cluster_all, expected, first_cluster_label" ,
54
51
[(1.2 , True , 3 , 0 ), (1.2 , False , 4 , - 1 )],
55
52
)
56
- @pytest .mark .parametrize ("dtype" , DTYPES )
57
- def test_mean_shift (dtype , bandwidth , cluster_all , expected , first_cluster_label ):
53
+ def test_mean_shift (
54
+ global_dtype , bandwidth , cluster_all , expected , first_cluster_label
55
+ ):
58
56
# Test MeanShift algorithm
59
57
ms = MeanShift (bandwidth = bandwidth , cluster_all = cluster_all )
60
- labels = ms .fit (X .astype(dtype )).labels_
58
+ labels = ms .fit (X .astype (global_dtype )).labels_
61
59
labels_unique = np .unique (labels )
62
60
n_clusters_ = len (labels_unique )
63
61
assert n_clusters_ == expected
@@ -86,8 +84,7 @@ def test_estimate_bandwidth_with_sparse_matrix():
86
84
estimate_bandwidth (X )
87
85
88
86
89
- @pytest .mark .parametrize ("dtype" , DTYPES )
90
- def test_parallel (dtype ):
87
+ def test_parallel (global_dtype ):
91
88
centers = np .array ([[1 , 1 ], [- 1 , - 1 ], [1 , - 1 ]]) + 10
92
89
X , _ = make_blobs (
93
90
n_samples = 50 ,
@@ -98,7 +95,7 @@ def test_parallel(dtype):
98
95
random_state = 11 ,
99
96
)
100
97
101
- X = X .astype (dtype )
98
+ X = X .astype (global_dtype )
102
99
103
100
ms1 = MeanShift (n_jobs = 2 )
104
101
ms1 .fit (X )
@@ -110,11 +107,10 @@ def test_parallel(dtype):
110
107
assert_array_equal (ms1 .labels_ , ms2 .labels_ )
111
108
112
109
113
- @pytest .mark .parametrize ("dtype" , DTYPES )
114
- def test_meanshift_predict (dtype ):
110
+ def test_meanshift_predict (global_dtype ):
115
111
# Test MeanShift.predict
116
112
ms = MeanShift (bandwidth = 1.2 )
117
- Y = X .astype (dtype )
113
+ Y = X .astype (global_dtype )
118
114
labels = ms .fit_predict (Y )
119
115
labels2 = ms .predict (Y )
120
116
assert_array_equal (labels , labels2 )
@@ -137,25 +133,23 @@ def test_unfitted():
137
133
assert not hasattr (ms , "labels_" )
138
134
139
135
140
- @pytest .mark .parametrize ("dtype" , DTYPES )
141
- def test_cluster_intensity_tie (dtype ):
142
- X = np .array ([[1 , 1 ], [2 , 1 ], [1 , 0 ], [4 , 7 ], [3 , 5 ], [3 , 6 ]], dtype = dtype )
136
+ def test_cluster_intensity_tie (global_dtype ):
137
+ X = np .array ([[1 , 1 ], [2 , 1 ], [1 , 0 ], [4 , 7 ], [3 , 5 ], [3 , 6 ]], dtype = global_dtype )
143
138
c1 = MeanShift (bandwidth = 2 ).fit (X )
144
139
145
- X = np .array ([[4 , 7 ], [3 , 5 ], [3 , 6 ], [1 , 1 ], [2 , 1 ], [1 , 0 ]], dtype = dtype )
140
+ X = np .array ([[4 , 7 ], [3 , 5 ], [3 , 6 ], [1 , 1 ], [2 , 1 ], [1 , 0 ]], dtype = global_dtype )
146
141
c2 = MeanShift (bandwidth = 2 ).fit (X )
147
142
assert_array_equal (c1 .labels_ , [1 , 1 , 1 , 0 , 0 , 0 ])
148
143
assert_array_equal (c2 .labels_ , [0 , 0 , 0 , 1 , 1 , 1 ])
149
144
150
145
151
- @pytest .mark .parametrize ("dtype" , DTYPES )
152
- def test_bin_seeds (dtype ):
146
+ def test_bin_seeds (global_dtype ):
153
147
# Test the bin seeding technique which can be used in the mean shift
154
148
# algorithm
155
149
# Data is just 6 points in the plane
156
150
X = np .array (
157
151
[[1.0 , 1.0 ], [1.4 , 1.4 ], [1.8 , 1.2 ], [2.0 , 1.0 ], [2.1 , 1.1 ], [0.0 , 0.0 ]],
158
- dtype = dtype ,
152
+ dtype = global_dtype ,
159
153
)
160
154
161
155
# With a bin coarseness of 1.0 and min_bin_freq of 1, 3 bins should be
@@ -186,7 +180,7 @@ def test_bin_seeds(dtype):
186
180
cluster_std = 0.1 ,
187
181
random_state = 0 ,
188
182
)
189
- X = X .astype (dtype )
183
+ X = X .astype (global_dtype )
190
184
test_bins = get_bin_seeds (X , 1 )
191
185
assert_array_equal (test_bins , [[0 , 0 ], [1 , 1 ]])
192
186
@@ -204,10 +198,9 @@ def test_max_iter(max_iter):
204
198
assert np .allclose (c1 , c2 )
205
199
206
200
207
- @pytest .mark .parametrize ("dtype" , DTYPES )
208
- def test_mean_shift_zero_bandwidth (dtype ):
201
+ def test_mean_shift_zero_bandwidth (global_dtype ):
209
202
# Check that mean shift works when the estimated bandwidth is 0.
210
- X = np .array ([1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 ]).reshape (- 1 , 1 ).astype (dtype )
203
+ X = np .array ([1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 ]).reshape (- 1 , 1 ).astype (global_dtype )
211
204
212
205
# estimate_bandwidth with default args returns 0 on this dataset
213
206
bandwidth = estimate_bandwidth (X )
@@ -220,7 +213,7 @@ def test_mean_shift_zero_bandwidth(dtype):
220
213
# to no binning.
221
214
ms_binning = MeanShift (bin_seeding = True , bandwidth = None ).fit (X )
222
215
ms_nobinning = MeanShift (bin_seeding = False ).fit (X )
223
- expected_labels = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ], dtype = dtype )
216
+ expected_labels = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ], dtype = global_dtype )
224
217
225
218
assert v_measure_score (ms_binning .labels_ , expected_labels ) == 1
226
219
assert v_measure_score (ms_nobinning .labels_ , expected_labels ) == 1
0 commit comments