8000 Assert review comments · scikit-learn/scikit-learn@e3cdaf1 · GitHub
[go: up one dir, main page]

Skip to content

Commit e3cdaf1

Browse files
jjerphanjeremiedbb
andcommitted
Assert review comments
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
1 parent 70cde82 commit e3cdaf1

File tree

1 file changed

+23
-16
lines changed

1 file changed

+23
-16
lines changed

sklearn/cluster/tests/test_mean_shift.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from scipy import sparse
1111

1212
from sklearn.utils._testing import assert_array_equal
13-
from sklearn.utils._testing import assert_array_almost_equal
1413
from sklearn.utils._testing import assert_allclose
1514

1615
from sklearn.cluster import MeanShift
@@ -42,8 +41,12 @@ def test_estimate_bandwidth():
4241
def test_estimate_bandwidth_1sample(global_dtype):
4342
# Test estimate_bandwidth when n_samples=1 and quantile<1, so that
4443
# n_neighbors is set to 1.
45-
bandwidth = estimate_bandwidth(X.astype(global_dtype), n_samples=1, quantile=0.3)
46-
assert bandwidth == pytest.approx(0.0, abs=1e-5)
44+
bandwidth = estimate_bandwidth(
45+
X.astype(global_dtype, copy=False), n_samples=1, quantile=0.3
46+
)
47+
48+
assert bandwidth.dtype == X.dtype
49+
assert_allclose(bandwidth, 0.0, atol=1e-5)
4750

4851

4952
@pytest.mark.parametrize(
@@ -54,14 +57,15 @@ def test_mean_shift(
5457
global_dtype, bandwidth, cluster_all, expected, first_cluster_label
5558
):
5659
# Test MeanShift algorithm
60+
X_ = X.astype(global_dtype, copy=False)
5761
ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
58-
labels = ms.fit(X.astype(global_dtype)).labels_
62+
labels = ms.fit(X_).labels_
5963
labels_unique = np.unique(labels)
6064
10000 n_clusters_ = len(labels_unique)
6165
assert n_clusters_ == expected
6266
assert labels_unique[0] == first_cluster_label
6367

64-
cluster_centers, labels_mean_shift = mean_shift(X, cluster_all=cluster_all)
68+
cluster_centers, labels_mean_shift = mean_shift(X_, cluster_all=cluster_all)
6569
labels_mean_shift_unique = np.unique(labels_mean_shift)
6670
n_clusters_mean_shift = len(labels_mean_shift_unique)
6771
assert n_clusters_mean_shift == expected
@@ -95,25 +99,24 @@ def test_parallel(global_dtype):
9599
random_state=11,
96100
)
97101

98-
X = X.astype(global_dtype)
102+
X = X.astype(global_dtype, copy=False)
99103

100104
ms1 = MeanShift(n_jobs=2)
101105
ms1.fit(X)
102106

103107
ms2 = MeanShift()
104108
ms2.fit(X)
105109

106-
assert_array_almost_equal(ms1.cluster_centers_, ms2.cluster_centers_)
110+
assert_allclose(ms1.cluster_centers_, ms2.cluster_centers_)
107111
assert_array_equal(ms1.labels_, ms2.labels_)
108112

109113

110114
def test_meanshift_predict(global_dtype):
111115
# Test MeanShift.predict
112-
global X
113116
ms = MeanShift(bandwidth=1.2)
114-
X = X.astype(global_dtype)
115-
labels = ms.fit_predict(X)
116-
labels2 = ms.predict(X)
117+
X_ = X.astype(global_dtype, copy=False)
118+
labels = ms.fit_predict(X_)
119+
labels2 = ms.predict(X_)
117120
assert_array_equal(labels, labels2)
118121

119122

@@ -171,7 +174,7 @@ def test_bin_seeds(global_dtype):
171174
# we bail and use the whole data here.
172175
with warnings.catch_warnings(record=True):
173176
test_bins = get_bin_seeds(X, 0.01, 1)
174-
assert_array_almost_equal(test_bins, X)
177+
assert_allclose(test_bins, X)
175178

176179
# tight clusters around [0, 0] and [1, 1], only get two bins
177180
X, _ = make_blobs(
@@ -181,7 +184,7 @@ def test_bin_seeds(global_dtype):
181184
cluster_std=0.1,
182185
random_state=0,
183186
)
184-
X = X.astype(global_dtype)
187+
X = X.astype(global_dtype, copy=False)
185188
test_bins = get_bin_seeds(X, 1)
186189
assert_array_equal(test_bins, [[0, 0], [1, 1]])
187190

@@ -201,7 +204,11 @@ def test_max_iter(max_iter):
201204

202205
def test_mean_shift_zero_bandwidth(global_dtype):
203206
# Check that mean shift works when the estimated bandwidth is 0.
204-
X = np.array([1, 1, 1, 2, 2, 2, 3, 3]).reshape(-1, 1).astype(global_dtype)
207+
X = (
208+
np.array([1, 1, 1, 2, 2, 2, 3, 3])
209+
.reshape(-1, 1)
210+
.astype(global_dtype, copy=False)
211+
)
205212

206213
# estimate_bandwidth with default args returns 0 on this dataset
207214
bandwidth = estimate_bandwidth(X)
@@ -216,6 +223,6 @@ def test_mean_shift_zero_bandwidth(global_dtype):
216223
ms_nobinning = MeanShift(bin_seeding=False).fit(X)
217224
expected_labels = np.array([0, 0, 0, 1, 1, 1, 2, 2])
218225

219-
assert v_measure_score(ms_binning.labels_, expected_labels) == pytest.approx(1)
220-
assert v_measure_score(ms_nobinning.labels_, expected_labels) == pytest.approx(1)
226+
assert_allclose(v_measure_score(ms_binning.labels_, expected_labels), 1)
227+
assert_allclose(v_measure_score(ms_nobinning.labels_, expected_labels), 1)
221228
assert_allclose(ms_binning.cluster_centers_, ms_nobinning.cluster_centers_)

0 commit comments

Comments
 (0)
0