Revert "TST Fix missing assert and parametrize k-means tests (#12368)" · xhluca/scikit-learn@7012a16 · GitHub

Commit 7012a16

Author: Xing
Revert "TST Fix missing assert and parametrize k-means tests (scikit-learn#12368)"
This reverts commit 347c272.
1 parent 65e8ea0 · commit 7012a16

File tree: 1 file changed, +120 -33 lines changed


sklearn/cluster/tests/test_k_means.py

Lines changed: 120 additions & 33 deletions
@@ -181,6 +181,12 @@ def _check_fitted_model(km):
                          % km.n_clusters, km.fit, [[0., 1.]])
 
 
+def test_k_means_plus_plus_init():
+    km = KMeans(init="k-means++", n_clusters=n_clusters,
+                random_state=42).fit(X)
+    _check_fitted_model(km)
+
+
 def test_k_means_new_centers():
     # Explore the part of the code where a new center is reassigned
     X = np.array([[0, 0, 1, 1],
@@ -223,6 +229,24 @@ def test_k_means_precompute_distances_flag():
     assert_raises(ValueError, km.fit, X)
 
 
+def test_k_means_plus_plus_init_sparse():
+    km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42)
+    km.fit(X_csr)
+    _check_fitted_model(km)
+
+
+def test_k_means_random_init():
+    km = KMeans(init="random", n_clusters=n_clusters, random_state=42)
+    km.fit(X)
+    _check_fitted_model(km)
+
+
+def test_k_means_random_init_sparse():
+    km = KMeans(init="random", n_clusters=n_clusters, random_state=42)
+    km.fit(X_csr)
+    _check_fitted_model(km)
+
+
 def test_k_means_plus_plus_init_not_precomputed():
     km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42,
                 precompute_distances=False).fit(X)
@@ -235,11 +259,10 @@ def test_k_means_random_init_not_precomputed():
     _check_fitted_model(km)
 
 
-@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse'])
-@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()])
-def test_k_means_init(data, init):
-    km = KMeans(init=init, n_clusters=n_clusters, random_state=42, n_init=1)
-    km.fit(data)
+def test_k_means_perfect_init():
+    km = KMeans(init=centers.copy(), n_clusters=n_clusters, random_state=42,
+                n_init=1)
+    km.fit(X)
     _check_fitted_model(km)
 
 
@@ -292,6 +315,13 @@ def test_k_means_fortran_aligned_data():
     assert_array_equal(km.labels_, labels)
 
 
+def test_mb_k_means_plus_plus_init_dense_array():
+    mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters,
+                                 random_state=42)
+    mb_k_means.fit(X)
+    _check_fitted_model(mb_k_means)
+
+
 def test_mb_kmeans_verbose():
     mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters,
                                  random_state=42, verbose=1)
@@ -303,25 +333,49 @@ def test_mb_kmeans_verbose():
         sys.stdout = old_stdout
 
 
+def test_mb_k_means_plus_plus_init_sparse_matrix():
+    mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters,
+                                 random_state=42)
+    mb_k_means.fit(X_csr)
+    _check_fitted_model(mb_k_means)
+
+
 def test_minibatch_init_with_large_k():
     mb_k_means = MiniBatchKMeans(init='k-means++', init_size=10, n_clusters=20)
     # Check that a warning is raised, as the number clusters is larger
     # than the init_size
     assert_warns(RuntimeWarning, mb_k_means.fit, X)
 
 
+def test_minibatch_k_means_random_init_dense_array():
+    # increase n_init to make random init stable enough
+    mb_k_means = MiniBatchKMeans(init="random", n_clusters=n_clusters,
+                                 random_state=42, n_init=10).fit(X)
+    _check_fitted_model(mb_k_means)
+
+
+def test_minibatch_k_means_random_init_sparse_csr():
+    # increase n_init to make random init stable enough
+    mb_k_means = MiniBatchKMeans(init="random", n_clusters=n_clusters,
+                                 random_state=42, n_init=10).fit(X_csr)
+    _check_fitted_model(mb_k_means)
+
+
+def test_minibatch_k_means_perfect_init_dense_array():
+    mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters,
+                                 random_state=42, n_init=1).fit(X)
+    _check_fitted_model(mb_k_means)
+
+
 def test_minibatch_k_means_init_multiple_runs_with_explicit_centers():
     mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters,
                                  random_state=42, n_init=10)
     assert_warns(RuntimeWarning, mb_k_means.fit, X)
 
 
-@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse'])
-@pytest.mark.parametrize('init', ["random", 'k-means++', centers.copy()])
-def test_minibatch_k_means_init(data, init):
-    mb_k_means = MiniBatchKMeans(init=init, n_clusters=n_clusters,
-                                 random_state=42, n_init=10)
-    mb_k_means.fit(data)
+def test_minibatch_k_means_perfect_init_sparse_csr():
+    mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters,
+                                 random_state=42, n_init=1).fit(X_csr)
     _check_fitted_model(mb_k_means)
 
 
@@ -531,39 +585,64 @@ def test_predict():
     assert_array_equal(pred, km.labels_)
 
 
-@pytest.mark.parametrize('algo', ['full', 'elkan'])
-def test_score(algo):
-    # Check that fitting k-means with multiple inits gives better score
+def test_score():
+
+    km1 = KMeans(n_clusters=n_clusters, max_iter=1, random_state=42, n_init=1)
+    s1 = km1.fit(X).score(X)
+    km2 = KMeans(n_clusters=n_clusters, max_iter=10, random_state=42, n_init=1)
+    s2 = km2.fit(X).score(X)
+    assert_greater(s2, s1)
+
     km1 = KMeans(n_clusters=n_clusters, max_iter=1, random_state=42, n_init=1,
-                 algorithm=algo)
+                 algorithm='elkan')
     s1 = km1.fit(X).score(X)
     km2 = KMeans(n_clusters=n_clusters, max_iter=10, random_state=42, n_init=1,
-                 algorithm=algo)
+                 algorithm='elkan')
     s2 = km2.fit(X).score(X)
     assert_greater(s2, s1)
 
 
-@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse'])
-@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()])
-def test_predict_minibatch(data, init):
-    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init=init,
-                                 n_init=10, random_state=0).fit(data)
+def test_predict_minibatch_dense_input():
+    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, random_state=40).fit(X)
+
+    # sanity check: predict centroid labels
+    pred = mb_k_means.predict(mb_k_means.cluster_centers_)
+    assert_array_equal(pred, np.arange(n_clusters))
+
+    # sanity check: re-predict labeling for training set samples
+    pred = mb_k_means.predict(X)
+    assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_)
+
+
+def test_predict_minibatch_kmeanspp_init_sparse_input():
+    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init='k-means++',
+                                 n_init=10).fit(X_csr)
 
     # sanity check: re-predict labeling for training set samples
-    assert_array_equal(mb_k_means.predict(data), mb_k_means.labels_)
+    assert_array_equal(mb_k_means.predict(X_csr), mb_k_means.labels_)
 
     # sanity check: predict centroid labels
     pred = mb_k_means.predict(mb_k_means.cluster_centers_)
     assert_array_equal(pred, np.arange(n_clusters))
 
-
-@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()])
-def test_predict_minibatch_dense_sparse(init):
     # check that models trained on sparse input also works for dense input at
     # predict time
-    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init=init,
-                                 n_init=10, random_state=0).fit(X_csr)
+    assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_)
 
+
+def test_predict_minibatch_random_init_sparse_input():
+    mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init='random',
+                                 n_init=10).fit(X_csr)
+
+    # sanity check: re-predict labeling for training set samples
+    assert_array_equal(mb_k_means.predict(X_csr), mb_k_means.labels_)
+
+    # sanity check: predict centroid labels
+    pred = mb_k_means.predict(mb_k_means.cluster_centers_)
+    assert_array_equal(pred, np.arange(n_clusters))
+
+    # check that models trained on sparse input also works for dense input at
+    # predict time
     assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_)
 
 
@@ -615,19 +694,27 @@ def test_fit_transform():
     assert_array_almost_equal(X1, X2)
 
 
-@pytest.mark.parametrize('algo', ['full', 'elkan'])
-def test_predict_equal_labels(algo):
+def test_predict_equal_labels():
+    km = KMeans(random_state=13, n_jobs=1, n_init=1, max_iter=1,
+                algorithm='full')
+    km.fit(X)
+    assert_array_equal(km.predict(X), km.labels_)
+
     km = KMeans(random_state=13, n_jobs=1, n_init=1, max_iter=1,
-                algorithm=algo)
+                algorithm='elkan')
     km.fit(X)
     assert_array_equal(km.predict(X), km.labels_)
 
 
 def test_full_vs_elkan():
-    km1 = KMeans(algorithm='full', random_state=13).fit(X)
-    km2 = KMeans(algorithm='elkan', random_state=13).fit(X)
 
-    assert homogeneity_score(km1.predict(X), km2.predict(X)) == 1.0
+    km1 = KMeans(algorithm='full', random_state=13)
+    km2 = KMeans(algorithm='elkan', random_state=13)
+
+    km1.fit(X)
+    km2.fit(X)
+
+    homogeneity_score(km1.predict(X), km2.predict(X)) == 1.0
 
 
 def test_n_init():