@@ -1051,20 +1051,25 @@ def check_clustering(name, clusterer_orig):
1051
1051
assert_in (pred .dtype , [np .dtype ('int32' ), np .dtype ('int64' )])
1052
1052
assert_in (pred2 .dtype , [np .dtype ('int32' ), np .dtype ('int64' )])
1053
1053
1054
+ # Add noise to X to test the possible values of the labels
1055
+ rng = np .random .RandomState (7 )
1056
+ X_noise = np .concatenate ([X , rng .uniform (low = - 3 , high = 3 , size = (5 , 2 ))])
1057
+ labels = clusterer .fit_predict (X_noise )
1058
+
1054
1059
# There should be at least one sample in every cluster. Equivalently
1055
1060
# labels_ should contain all the consecutive values between its
1056
1061
# min and its max.
1057
- pred_sorted = np .unique (pred )
1058
- assert_array_equal (pred_sorted , np .arange (pred_sorted [0 ],
1059
- pred_sorted [- 1 ] + 1 ))
1062
+ labels_sorted = np .unique (labels )
1063
+ assert_array_equal (labels_sorted , np .arange (labels_sorted [0 ],
1064
+ labels_sorted [- 1 ] + 1 ))
1060
1065
1061
- # labels_ should be greater than -1
1062
- assert_greater_equal ( pred_sorted [0 ], - 1 )
1063
- # labels_ should be less than n_clusters - 1
1066
+ # Labels are expected to start at 0 (no noise) or -1 (if noise)
1067
+ assert_true ( labels_sorted [0 ] in [ 0 , - 1 ] )
1068
+ # Labels should be less than or equal to n_clusters - 1
1064
1069
if hasattr (clusterer , 'n_clusters' ):
1065
1070
n_clusters = getattr (clusterer , 'n_clusters' )
1066
- assert_greater_equal (n_clusters - 1 , pred_sorted [- 1 ])
1067
- # else labels_ should be less than max(labels_) which is necessarily true
1071
+ assert_greater_equal (n_clusters - 1 , labels_sorted [- 1 ])
1072
+ # else labels should be less than or equal to max(labels_), which is necessarily true
1068
1073
1069
1074
1070
1075
@ignore_warnings (category = DeprecationWarning )
0 commit comments