10
10
from scipy import sparse
11
11
12
12
from sklearn .utils ._testing import assert_array_equal
13
- from sklearn .utils ._testing import assert_array_almost_equal
14
13
from sklearn .utils ._testing import assert_allclose
15
14
16
15
from sklearn .cluster import MeanShift
@@ -42,8 +41,12 @@ def test_estimate_bandwidth():
42
41
def test_estimate_bandwidth_1sample (global_dtype ):
43
42
# Test estimate_bandwidth when n_samples=1 and quantile<1, so that
44
43
# n_neighbors is set to 1.
45
- bandwidth = estimate_bandwidth (X .astype (global_dtype ), n_samples = 1 , quantile = 0.3 )
46
- assert bandwidth == pytest .approx (0.0 , abs = 1e-5 )
44
+ bandwidth = estimate_bandwidth (
45
+ X .astype (global_dtype , copy = False ), n_samples = 1 , quantile = 0.3
46
+ )
47
+
48
+ assert bandwidth .dtype == X .dtype
49
+ assert_allclose (bandwidth , 0.0 , atol = 1e-5 )
47
50
48
51
49
52
@pytest .mark .parametrize (
@@ -54,14 +57,15 @@ def test_mean_shift(
54
57
global_dtype , bandwidth , cluster_all , expected , first_cluster_label
55
58
):
56
59
# Test MeanShift algorithm
60
+ X_ = X .astype (global_dtype , copy = False )
57
61
ms = MeanShift (bandwidth = bandwidth , cluster_all = cluster_all )
58
- labels = ms .fit (X . astype ( global_dtype ) ).labels_
62
+ labels = ms .fit (X_ ).labels_
59
63
labels_unique = np .unique (labels )
60
64
10000
n_clusters_ = len (labels_unique )
61
65
assert n_clusters_ == expected
62
66
assert labels_unique [0 ] == first_cluster_label
63
67
64
- cluster_centers , labels_mean_shift = mean_shift (X , cluster_all = cluster_all )
68
+ cluster_centers , labels_mean_shift = mean_shift (X_ , cluster_all = cluster_all )
65
69
labels_mean_shift_unique = np .unique (labels_mean_shift )
66
70
n_clusters_mean_shift = len (labels_mean_shift_unique )
67
71
assert n_clusters_mean_shift == expected
@@ -95,25 +99,24 @@ def test_parallel(global_dtype):
95
99
random_state = 11 ,
96
100
)
97
101
98
- X = X .astype (global_dtype )
102
+ X = X .astype (global_dtype , copy = False )
99
103
100
104
ms1 = MeanShift (n_jobs = 2 )
101
105
ms1 .fit (X )
102
106
103
107
ms2 = MeanShift ()
104
108
ms2 .fit (X )
105
109
106
- assert_array_almost_equal (ms1 .cluster_centers_ , ms2 .cluster_centers_ )
110
+ assert_allclose (ms1 .cluster_centers_ , ms2 .cluster_centers_ )
107
111
assert_array_equal (ms1 .labels_ , ms2 .labels_ )
108
112
109
113
110
114
def test_meanshift_predict (global_dtype ):
111
115
# Test MeanShift.predict
112
- global X
113
116
ms = MeanShift (bandwidth = 1.2 )
114
- X = X .astype (global_dtype )
115
- labels = ms .fit_predict (X )
116
- labels2 = ms .predict (X )
117
+ X_ = X .astype (global_dtype , copy = False )
118
+ labels = ms .fit_predict (X_ )
119
+ labels2 = ms .predict (X_ )
117
120
assert_array_equal (labels , labels2 )
118
121
119
122
@@ -171,7 +174,7 @@ def test_bin_seeds(global_dtype):
171
174
# we bail and use the whole data here.
172
175
with warnings .catch_warnings (record = True ):
173
176
test_bins = get_bin_seeds (X , 0.01 , 1 )
174
- assert_array_almost_equal (test_bins , X )
177
+ assert_allclose (test_bins , X )
175
178
176
179
# tight clusters around [0, 0] and [1, 1], only get two bins
177
180
X , _ = make_blobs (
@@ -181,7 +184,7 @@ def test_bin_seeds(global_dtype):
181
184
cluster_std = 0.1 ,
182
185
random_state = 0 ,
183
186
)
184
- X = X .astype (global_dtype )
187
+ X = X .astype (global_dtype , copy = False )
185
188
test_bins = get_bin_seeds (X , 1 )
186
189
assert_array_equal (test_bins , [[0 , 0 ], [1 , 1 ]])
187
190
@@ -201,7 +204,11 @@ def test_max_iter(max_iter):
201
204
202
205
def test_mean_shift_zero_bandwidth (global_dtype ):
203
206
# Check that mean shift works when the estimated bandwidth is 0.
204
- X = np .array ([1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 ]).reshape (- 1 , 1 ).astype (global_dtype )
207
+ X = (
208
+ np .array ([1 , 1 , 1 , 2 , 2 , 2 , 3 , 3 ])
209
+ .reshape (- 1 , 1 )
210
+ .astype (global_dtype , copy = False )
211
+ )
205
212
206
213
# estimate_bandwidth with default args returns 0 on this dataset
207
214
bandwidth = estimate_bandwidth (X )
@@ -216,6 +223,6 @@ def test_mean_shift_zero_bandwidth(global_dtype):
216
223
ms_nobinning = MeanShift (bin_seeding = False ).fit (X )
217
224
expected_labels = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ])
218
225
219
- assert v_measure_score (ms_binning .labels_ , expected_labels ) == pytest . approx ( 1 )
220
- assert v_measure_score (ms_nobinning .labels_ , expected_labels ) == pytest . approx ( 1 )
226
+ assert_allclose ( v_measure_score (ms_binning .labels_ , expected_labels ), 1 )
227
+ assert_allclose ( v_measure_score (ms_nobinning .labels_ , expected_labels ), 1 )
221
228
assert_allclose (ms_binning .cluster_centers_ , ms_nobinning .cluster_centers_ )
0 commit comments