@@ -109,10 +109,11 @@ def test_parallel(global_dtype):
109
109
110
110
def test_meanshift_predict (global_dtype ):
111
111
# Test MeanShift.predict
112
+ global X
112
113
ms = MeanShift (bandwidth = 1.2 )
113
- Y = X .astype (global_dtype )
114
- labels = ms .fit_predict (Y )
115
- labels2 = ms .predict (Y )
114
+ X = X .astype (global_dtype )
115
+ labels = ms .fit_predict (X )
116
+ labels2 = ms .predict (X )
116
117
assert_array_equal (labels , labels2 )
117
118
118
119
@@ -213,8 +214,8 @@ def test_mean_shift_zero_bandwidth(global_dtype):
213
214
# to no binning.
214
215
ms_binning = MeanShift (bin_seeding = True , bandwidth = None ).fit (X )
215
216
ms_nobinning = MeanShift (bin_seeding = False ).fit (X )
216
- expected_labels = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ], dtype = global_dtype )
217
+ expected_labels = np .array ([0 , 0 , 0 , 1 , 1 , 1 , 2 , 2 ])
217
218
218
- assert v_measure_score (ms_binning .labels_ , expected_labels ) == 1
219
- assert v_measure_score (ms_nobinning .labels_ , expected_labels ) == 1
219
+ assert v_measure_score (ms_binning .labels_ , expected_labels ) == pytest . approx ( 1 )
220
+ assert v_measure_score (ms_nobinning .labels_ , expected_labels ) == pytest . approx ( 1 )
220
221
assert_allclose (ms_binning .cluster_centers_ , ms_nobinning .cluster_centers_ )
0 commit comments