TST Fix missing assert and parametrize k-means tests (#12368) · scikit-learn/scikit-learn@76b1078 · GitHub

Commit 76b1078

jeremiedbb authored and rth committed
TST Fix missing assert and parametrize k-means tests (#12368)
1 parent 1e7cd7d commit 76b1078
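
The consolidation relies on stacking pytest.mark.parametrize decorators: each decorator multiplies the set of generated test cases, so one test body replaces several near-identical functions. A minimal, self-contained sketch of the mechanism (toy data names, not the scikit-learn fixtures):

import pytest

# Toy stand-ins; the real tests use the module-level X (dense) and X_csr (sparse).
DENSE = [[0.0, 0.0], [1.0, 1.0]]
SPARSE_LIKE = [[0.0, 0.0], [1.0, 1.0]]

@pytest.mark.parametrize('data', [DENSE, SPARSE_LIKE], ids=['dense', 'sparse'])
@pytest.mark.parametrize('init', ['random', 'k-means++', 'explicit-centers'])
def test_init_sketch(data, init):
    # 3 init values x 2 data layouts = 6 generated test items; pytest names
    # each one with the parameter ids (e.g. 'dense'/'sparse' from the ids list).
    assert len(data) == 2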

File tree: 1 file changed (+33 -120 lines)

sklearn/cluster/tests/test_k_means.py

Lines changed: 33 additions & 120 deletions
@@ -181,12 +181,6 @@ def _check_fitted_model(km):
                          % km.n_clusters, km.fit, [[0., 1.]])
 
 
-def test_k_means_plus_plus_init():
-    km = KMeans(init="k-means++", n_clusters=n_clusters,
-                random_state=42).fit(X)
-    _check_fitted_model(km)
-
-
 def test_k_means_new_centers():
     # Explore the part of the code where a new center is reassigned
     X = np.array([[0, 0, 1, 1],
@@ -229,24 +223,6 @@ def test_k_means_precompute_distances_flag():
     assert_raises(ValueError, km.fit, X)
 
 
-def test_k_means_plus_plus_init_sparse():
-    km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42)
-    km.fit(X_csr)
-    _check_fitted_model(km)
-
-
-def test_k_means_random_init():
-    km = KMeans(init="random", n_clusters=n_clusters, random_state=42)
-    km.fit(X)
-    _check_fitted_model(km)
-
-
-def test_k_means_random_init_sparse():
-    km = KMeans(init="random", n_clusters=n_clusters, random_state=42)
-    km.fit(X_csr)
-    _check_fitted_model(km)
-
-
 def test_k_means_plus_plus_init_not_precomputed():
     km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42,
                 precompute_distances=False).fit(X)
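
The parametrized replacements below reuse the test module's existing fixtures (X, X_csr, n_clusters, centers). A hedged sketch of how such dense/sparse fixtures are typically built; the blob parameters here are assumptions, not the values used in test_k_means.py:

import numpy as np
import scipy.sparse as sp
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Hypothetical values; the real module defines its own centers/X/X_csr.
centers = np.array([[0.0, 5.0], [1.0, 1.0], [5.0, 0.0]])
n_clusters = centers.shape[0]
X, _ = make_blobs(n_samples=60, centers=centers, cluster_std=0.5,
                  random_state=42)
X_csr = sp.csr_matrix(X)  # sparse (CSR) copy of the same data

# Both layouts go through the same estimator API, which is exactly what the
# dense/sparse parametrization exercises.
KMeans(n_clusters=n_clusters, random_state=42, n_init=1).fit(X_csr)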
@@ -259,10 +235,11 @@ def test_k_means_random_init_not_precomputed():
     _check_fitted_model(km)
 
 
-def test_k_means_perfect_init():
-    km = KMeans(init=centers.copy(), n_clusters=n_clusters, random_state=42,
-                n_init=1)
-    km.fit(X)
+@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse'])
+@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()])
+def test_k_means_init(data, init):
+    km = KMeans(init=init, n_clusters=n_clusters, random_state=42, n_init=1)
+    km.fit(data)
     _check_fitted_model(km)
 
 
@@ -315,13 +292,6 @@ def test_k_means_fortran_aligned_data():
     assert_array_equal(km.labels_, labels)
 
 
-def test_mb_k_means_plus_plus_init_dense_array():
-    mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters,
-                                 random_state=42)
-    mb_k_means.fit(X)
-    _check_fitted_model(mb_k_means)
-
-
 def test_mb_kmeans_verbose():
     mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters,
                                  random_state=42, verbose=1)
@@ -333,49 +303,25 @@ def test_mb_kmeans_verbose():
         sys.stdout = old_stdout
 
 
-def test_mb_k_means_plus_plus_init_sparse_matrix():
-    mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters,
-                                 random_state=42)
-    mb_k_means.fit(X_csr)
-    _check_fitted_model(mb_k_means)
-
-
 def test_minibatch_init_with_large_k():
     mb_k_means = MiniBatchKMeans(init='k-means++', init_size=10, n_clusters=20)
     # Check that a warning is raised, as the number clusters is larger
     # than the init_size
     assert_warns(RuntimeWarning, mb_k_means.fit, X)
 
 
-def test_minibatch_k_means_random_init_dense_array():
-    # increase n_init to make random init stable enough
-    mb_k_means = MiniBatchKMeans(init="random", n_clusters=n_clusters,
-                                 random_state=42, n_init=10).fit(X)
-    _check_fitted_model(mb_k_means)
-
-
-def test_minibatch_k_means_random_init_sparse_csr():
-    # increase n_init to make random init stable enough
-    mb_k_means = MiniBatchKMeans(init="random", n_clusters=n_clusters,
-                                 random_state=42, n_init=10).fit(X_csr)
-    _check_fitted_model(mb_k_means)
-
-
-def test_minibatch_k_means_perfect_init_dense_array():
-    mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters,
-                                 random_state=42, n_init=1).fit(X)
-    _check_fitted_model(mb_k_means)
-
-
 def test_minibatch_k_means_init_multiple_runs_with_explicit_centers():
     mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters,
                                  random_state=42, n_init=10)
     assert_warns(RuntimeWarning, mb_k_means.fit, X)
 
 
-def test_minibatch_k_means_perfect_init_sparse_csr():
-    mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters,
-                                 random_state=42, n_init=1).fit(X_csr)
+@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse'])
+@pytest.mark.parametrize('init', ["random", 'k-means++', centers.copy()])
+def test_minibatch_k_means_init(data, init):
+    mb_k_means = MiniBatchKMeans(init=init, n_clusters=n_clusters,
+                                 random_state=42, n_init=10)
+    mb_k_means.fit(data)
     _check_fitted_model(mb_k_means)
 
 
@@ -585,64 +531,39 @@ def test_predict():
     assert_array_equal(pred, km.labels_)
 
 
-def test_score():
-
-    km1 = KMeans(n_clusters=n_clusters, max_iter=1, random_state=42, n_init=1)
-    s1 = km1.fit(X).score(X)
-    km2 = KMeans(n_clusters=n_clusters, max_iter=10, random_state=42, n_init=1)
-    s2 = km2.fit(X).score(X)
-    assert_greater(s2, s1)
-
+@pytest.mark.parametrize('algo', ['full', 'elkan'])
+def test_score(algo):
+    # Check that fitting k-means with multiple inits gives better score
     km1 = KMeans(n_clusters=n_clusters, max_iter=1, random_state=42, n_init=1,
-                 algorithm='elkan')
+                 algorithm=algo)
     s1 = km1.fit(X).score(X)
     km2 = KMeans(n_clusters=n_clusters, max_iter=10, random_state=42, n_init=1,
-                 algorithm='elkan')
+                 algorithm=algo)
     s2 = km2.fit(X).score(X)
     assert_greater(s2, s1)
 
 
-def test_predict_minibatch_dense_input():
-    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, random_state=40).fit(X)
-
-    # sanity check: predict centroid labels
-    pred = mb_k_means.predict(mb_k_means.cluster_centers_)
-    assert_array_equal(pred, np.arange(n_clusters))
-
-    # sanity check: re-predict labeling for training set samples
-    pred = mb_k_means.predict(X)
-    assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_)
-
-
-def test_predict_minibatch_kmeanspp_init_sparse_input():
-    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init='k-means++',
-                                 n_init=10).fit(X_csr)
+@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse'])
+@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()])
+def test_predict_minibatch(data, init):
+    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init=init,
+                                 n_init=10, random_state=0).fit(data)
 
     # sanity check: re-predict labeling for training set samples
-    assert_array_equal(mb_k_means.predict(X_csr), mb_k_means.labels_)
+    assert_array_equal(mb_k_means.predict(data), mb_k_means.labels_)
 
     # sanity check: predict centroid labels
     pred = mb_k_means.predict(mb_k_means.cluster_centers_)
     assert_array_equal(pred, np.arange(n_clusters))
 
-    # check that models trained on sparse input also works for dense input at
-    # predict time
-    assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_)
-
-
-def test_predict_minibatch_random_init_sparse_input():
-    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init='random',
-                                 n_init=10).fit(X_csr)
-
-    # sanity check: re-predict labeling for training set samples
-    assert_array_equal(mb_k_means.predict(X_csr), mb_k_means.labels_)
-
-    # sanity check: predict centroid labels
-    pred = mb_k_means.predict(mb_k_means.cluster_centers_)
-    assert_array_equal(pred, np.arange(n_clusters))
 
+@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()])
+def test_predict_minibatch_dense_sparse(init):
     # check that models trained on sparse input also works for dense input at
     # predict time
+    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init=init,
+                                 n_init=10, random_state=0).fit(X_csr)
+
     assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_)
 
 
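For context on the parametrized test_score above: KMeans.score returns the negative inertia (the negated sum of squared distances to the nearest centroid), and k-means iterations never increase inertia, so a fit with more iterations should not score worse whichever of 'full' or 'elkan' is used. A standalone sketch of that invariant (toy data, assumed parameters):

from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=60, centers=3, random_state=42)

s1 = KMeans(n_clusters=3, max_iter=1, n_init=1, random_state=42).fit(X).score(X)
s10 = KMeans(n_clusters=3, max_iter=10, n_init=1, random_state=42).fit(X).score(X)

# score is -inertia_, so more iterations from the same init cannot make it worse
assert s10 >= s1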

@@ -694,27 +615,19 @@ def test_fit_transform():
     assert_array_almost_equal(X1, X2)
 
 
-def test_predict_equal_labels():
-    km = KMeans(random_state=13, n_jobs=1, n_init=1, max_iter=1,
-                algorithm='full')
-    km.fit(X)
-    assert_array_equal(km.predict(X), km.labels_)
-
+@pytest.mark.parametrize('algo', ['full', 'elkan'])
+def test_predict_equal_labels(algo):
     km = KMeans(random_state=13, n_jobs=1, n_init=1, max_iter=1,
-                algorithm='elkan')
+                algorithm=algo)
     km.fit(X)
     assert_array_equal(km.predict(X), km.labels_)
 
 
 def test_full_vs_elkan():
+    km1 = KMeans(algorithm='full', random_state=13).fit(X)
+    km2 = KMeans(algorithm='elkan', random_state=13).fit(X)
 
-    km1 = KMeans(algorithm='full', random_state=13)
-    km2 = KMeans(algorithm='elkan', random_state=13)
-
-    km1.fit(X)
-    km2.fit(X)
-
-    homogeneity_score(km1.predict(X), km2.predict(X)) == 1.0
+    assert homogeneity_score(km1.predict(X), km2.predict(X)) == 1.0
 
 
 def test_n_init():
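
The hunk above is the "missing assert" part of the commit title: the old test_full_vs_elkan computed homogeneity_score(...) == 1.0 and discarded the result, so that check could never fail. A minimal sketch of the difference (hypothetical values):

def check_without_assert(score):
    score == 1.0          # result of the comparison is discarded; never fails

def check_with_assert(score):
    assert score == 1.0   # raises AssertionError when the scores disagree

check_without_assert(0.5)     # silently "passes" even though 0.5 != 1.0
check_with_assert(1.0)        # passes
# check_with_assert(0.5)      # would raise AssertionError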
