[MRG + 1] Labels of clustering should start at 0 or -1 if noise (#10015) · scikit-learn/scikit-learn@1f7fa76 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1f7fa76

Browse files
albertcthomas authored and agramfort committed
[MRG + 1] Labels of clustering should start at 0 or -1 if noise (#10015)
* test labels of clustering should start at 0 or -1 if noise
* take into account agramfort's comment
* fix test
1 parent 202b532 commit 1f7fa76

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,20 +1051,25 @@ def check_clustering(name, clusterer_orig):
10511051
assert_in(pred.dtype, [np.dtype('int32'), np.dtype('int64')])
10521052
assert_in(pred2.dtype, [np.dtype('int32'), np.dtype('int64')])
10531053

1054+
# Add noise to X to test the possible values of the labels
1055+
rng = np.random.RandomState(7)
1056+
X_noise = np.concatenate([X, rng.uniform(low=-3, high=3, size=(5, 2))])
1057+
labels = clusterer.fit_predict(X_noise)
1058+
10541059
# There should be at least one sample in every cluster. Equivalently
10551060
# labels_ should contain all the consecutive values between its
10561061
# min and its max.
1057-
pred_sorted = np.unique(pred)
1058-
assert_array_equal(pred_sorted, np.arange(pred_sorted[0],
1059-
pred_sorted[-1] + 1))
1062+
labels_sorted = np.unique(labels)
1063+
assert_array_equal(labels_sorted, np.arange(labels_sorted[0],
1064+
labels_sorted[-1] + 1))
10601065

1061-
# labels_ should be greater than -1
1062-
assert_greater_equal(pred_sorted[0], -1)
1063-
# labels_ should be less than n_clusters - 1
1066+
# Labels are expected to start at 0 (no noise) or -1 (if noise)
1067+
assert_true(labels_sorted[0] in [0, -1])
1068+
# Labels should be less than n_clusters - 1
10641069
if hasattr(clusterer, 'n_clusters'):
10651070
n_clusters = getattr(clusterer, 'n_clusters')
1066-
assert_greater_equal(n_clusters - 1, pred_sorted[-1])
1067-
# else labels_ should be less than max(labels_) which is necessarily true
1071+
assert_greater_equal(n_clusters - 1, labels_sorted[-1])
1072+
# else labels should be less than max(labels_) which is necessarily true
10681073

10691074

10701075
@ignore_warnings(category=DeprecationWarning)

0 commit comments

Comments
 (0)
0