TST Use global_dtype · scikit-learn/scikit-learn@ab34666 · GitHub
[go: up one dir, main page]

Skip to content

Commit ab34666

Browse files
committed
TST Use global_dtype
1 parent b31cccb commit ab34666

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

sklearn/cluster/tests/test_mean_shift.py

Lines changed: 19 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -32,32 +32,30 @@
3232
random_state=11,
3333
)
3434

35-
DTYPES = (np.float64, np.float32)
36-
3735

3836
def test_estimate_bandwidth():
3937
# Test estimate_bandwidth
4038
bandwidth = estimate_bandwidth(X, n_samples=200)
4139
assert 0.9 <= bandwidth <= 1.5
4240

4341

44-
@pytest.mark.parametrize("dtype", DTYPES)
45-
def test_estimate_bandwidth_1sample(dtype):
42+
def test_estimate_bandwidth_1sample(global_dtype):
4643
# Test estimate_bandwidth when n_samples=1 and quantile<1, so that
4744
# n_neighbors is set to 1.
48-
bandwidth = estimate_bandwidth(X.astype(dtype), n_samples=1, quantile=0.3)
45+
bandwidth = estimate_bandwidth(X.astype(global_dtype), n_samples=1, quantile=0.3)
4946
assert bandwidth == pytest.approx(0.0, abs=1e-5)
5047

5148

5249
@pytest.mark.parametrize(
5350
"bandwidth, cluster_all, expected, first_cluster_label",
5451
[(1.2, True, 3, 0), (1.2, False, 4, -1)],
5552
)
56-
@pytest.mark.parametrize("dtype", DTYPES)
57-
def test_mean_shift(dtype, bandwidth, cluster_all, expected, first_cluster_label):
53+
def test_mean_shift(
54+
global_dtype, bandwidth, cluster_all, expected, first_cluster_label
55+
):
5856
# Test MeanShift algorithm
5957
ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
60-
labels = ms.fit(X.astype(dtype)).labels_
58+
labels = ms.fit(X.astype(global_dtype)).labels_
6159
labels_unique = np.unique(labels)
6260
n_clusters_ = len(labels_unique)
6361
assert n_clusters_ == expected
@@ -86,8 +84,7 @@ def test_estimate_bandwidth_with_sparse_matrix():
8684
estimate_bandwidth(X)
8785

8886

89-
@pytest.mark.parametrize("dtype", DTYPES)
90-
def test_parallel(dtype):
87+
def test_parallel(global_dtype):
9188
centers = np.array([[1, 1], [-1, -1], [1, -1]]) + 10
9289
X, _ = make_blobs(
9390
n_samples=50,
@@ -98,7 +95,7 @@ def test_parallel(dtype):
9895
random_state=11,
9996
)
10097

101-
X = X.astype(dtype)
98+
X = X.astype(global_dtype)
10299

103100
ms1 = MeanShift(n_jobs=2)
104101
ms1.fit(X)
@@ -110,11 +107,10 @@ def test_parallel(dtype):
110107
assert_array_equal(ms1.labels_, ms2.labels_)
111108

112109

113-
@pytest.mark.parametrize("dtype", DTYPES)
114-
def test_meanshift_predict(dtype):
110+
def test_meanshift_predict(global_dtype):
115111
# Test MeanShift.predict
116112
ms = MeanShift(bandwidth=1.2)
117-
Y = X.astype(dtype)
113+
Y = X.astype(global_dtype)
118114
labels = ms.fit_predict(Y)
119115
labels2 = ms.predict(Y)
120116
assert_array_equal(labels, labels2)
@@ -137,25 +133,23 @@ def test_unfitted():
137133
assert not hasattr(ms, "labels_")
138134

139135

140-
@pytest.mark.parametrize("dtype", DTYPES)
141-
def test_cluster_intensity_tie(dtype):
142-
X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]], dtype=dtype)
136+
def test_cluster_intensity_tie(global_dtype):
137+
X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]], dtype=global_dtype)
143138
c1 = MeanShift(bandwidth=2).fit(X)
144139

145-
X = np.array([[4, 7], [3, 5], [3, 6], [1, 1], [2, 1], [1, 0]], dtype=dtype)
140+
X = np.array([[4, 7], [3, 5], [3, 6], [1, 1], [2, 1], [1, 0]], dtype=global_dtype)
146141
c2 = MeanShift(bandwidth=2).fit(X)
147142
assert_array_equal(c1.labels_, [1, 1, 1, 0, 0, 0])
148143
assert_array_equal(c2.labels_, [0, 0, 0, 1, 1, 1])
149144

150145

151-
@pytest.mark.parametrize("dtype", DTYPES)
152-
def test_bin_seeds(dtype):
146+
def test_bin_seeds(global_dtype):
153147
# Test the bin seeding technique which can be used in the mean shift
154148
# algorithm
155149
# Data is just 6 points in the plane
156150
X = np.array(
157151
[[1.0, 1.0], [1.4, 1.4], [1.8, 1.2], [2.0, 1.0], [2.1, 1.1], [0.0, 0.0]],
158-
dtype=dtype,
152+
dtype=global_dtype,
159153
)
160154

161155
# With a bin coarseness of 1.0 and min_bin_freq of 1, 3 bins should be
@@ -186,7 +180,7 @@ def test_bin_seeds(dtype):
186180
cluster_std=0.1,
187181
random_state=0,
188182
)
189-
X = X.astype(dtype)
183+
X = X.astype(global_dtype)
190184
test_bins = get_bin_seeds(X, 1)
191185
assert_array_equal(test_bins, [[0, 0], [1, 1]])
192186

@@ -204,10 +198,9 @@ def test_max_iter(max_iter):
204198
assert np.allclose(c1, c2)
205199

206200

207-
@pytest.mark.parametrize("dtype", DTYPES)
208-
def test_mean_shift_zero_bandwidth(dtype):
201+
def test_mean_shift_zero_bandwidth(global_dtype):
209202
# Check that mean shift works when the estimated bandwidth is 0.
210-
X = np.array([1, 1, 1, 2, 2, 2, 3, 3]).reshape(-1, 1).astype(dtype)
203+
X = np.array([1, 1, 1, 2, 2, 2, 3, 3]).reshape(-1, 1).astype(global_dtype)
211204

212205
# estimate_bandwidth with default args returns 0 on this dataset
213206
bandwidth = estimate_bandwidth(X)
@@ -220,7 +213,7 @@ def test_mean_shift_zero_bandwidth(dtype):
220213
# to no binning.
221214
ms_binning = MeanShift(bin_seeding=True, bandwidth=None).fit(X)
222215
ms_nobinning = MeanShift(bin_seeding=False).fit(X)
223-
expected_labels = np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=dtype)
216+
expected_labels = np.array([0, 0, 0, 1, 1, 1, 2, 2], dtype=global_dtype)
224217

225218
assert v_measure_score(ms_binning.labels_, expected_labels) == 1
226219
assert v_measure_score(ms_nobinning.labels_, expected_labels) == 1

0 commit comments

Comments
 (0)
0