FIX/TST boundary cases in dbscan · scikit-learn/scikit-learn@cdb0577 · GitHub
[go: up one dir, main page]

Skip to content

Commit cdb0577

Browse files
committed
FIX/TST boundary cases in dbscan
1 parent d39fe22 commit cdb0577

File tree

2 files changed

+32
-17
lines changed

2 files changed

+32
-17
lines changed

sklearn/cluster/dbscan_.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
3434
3535
min_samples : int, optional
3636
The number of samples (or total weight) in a neighborhood for a point
37-
to be considered as a core point.
37+
to be considered as a core point. This number is inclusive of the
38+
core point.
3839
3940
metric : string, or callable
4041
The metric to use when calculating distance between instances in a
@@ -122,7 +123,7 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
122123
labels = -np.ones(X.shape[0], dtype=np.int)
123124

124125
# A list of all core samples found.
125-
core_samples = np.flatnonzero(n_neighbors > min_samples)
126+
core_samples = np.flatnonzero(n_neighbors >= min_samples)
126127
index_order = core_samples[random_state.permutation(core_samples.shape[0])]
127128

128129
# label_num is the label given to the new cluster
@@ -170,7 +171,8 @@ class DBSCAN(BaseEstimator, ClusterMixin):
170171
as in the same neighborhood.
171172
min_samples : int, optional
172173
The number of samples (or total weight) in a neighborhood for a point
173-
to be considered as a core point.
174+
to be considered as a core point. This number is inclusive of the
175+
core point.
174176
metric : string, or callable
175177
The metric to use when calculating distance between instances in a
176178
feature array. If metric is a string or callable, it must be one of
@@ -234,8 +236,8 @@ def fit(self, X, y=None, sample_weight=None):
234236
A feature array, or array of distances between samples if
235237
``metric='precomputed'``.
236238
sample_weight : array, shape (n_samples,), optional
237-
Weight of each sample, such that a sample with weight greater
238-
than ``min_samples`` is automatically a core sample; a sample with
239+
Weight of each sample, such that a sample with weight at least
240+
``min_samples`` is automatically a core sample; a sample with
239241
negative weight may inhibit its eps-neighbor from being core.
240242
Note that weights are absolute, and default to 1.
241243
"""
@@ -260,8 +262,8 @@ def fit_predict(self, X, y=None, sample_weight=None):
260262
A feature array, or array of distances between samples if
261263
``metric='precomputed'``.
262264
sample_weight : array, shape (n_samples,), optional
263-
Weight of each sample, such that a sample with weight greater
264-
than ``min_samples`` is automatically a core sample; a sample with
265+
Weight of each sample, such that a sample with weight at least
266+
``min_samples`` is automatically a core sample; a sample with
265267
negative weight may inhibit its eps-neighbor from being core.
266268
Note that weights are absolute, and default to 1.
267269

sklearn/cluster/tests/test_dbscan.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from sklearn.utils.testing import assert_equal
1313
from sklearn.utils.testing import assert_array_equal
1414
from sklearn.utils.testing import assert_raises
15+
from sklearn.utils.testing import assert_in
16+
from sklearn.utils.testing import assert_not_in
1517
from sklearn.cluster.dbscan_ import DBSCAN
1618
from sklearn.cluster.dbscan_ import dbscan
1719
from sklearn.cluster.tests.common import generate_clustered_data
@@ -185,33 +187,44 @@ def test_pickle():
185187
assert_equal(type(pickle.loads(s)), obj.__class__)
186188

187189

190+
def test_boundaries():
191+
# ensure min_samples is inclusive of core point
192+
core, _ = dbscan([[0], [1]], eps=2, min_samples=2)
193+
assert_in(0, core)
194+
# ensure eps is inclusive of circumference
195+
core, _ = dbscan([[0], [1], [1]], eps=1, min_samples=2)
196+
assert_in(0, core)
197+
core, _ = dbscan([[0], [1], [1]], eps=.99, min_samples=2)
198+
assert_not_in(0, core)
199+
200+
188201
def test_weighted_dbscan():
189202
# ensure sample_weight is validated
190203
assert_raises(ValueError, dbscan, [[0], [1]], sample_weight=[2])
191204
assert_raises(ValueError, dbscan, [[0], [1]], sample_weight=[2, 3, 4])
192205

193206
# ensure sample_weight has an effect
194207
assert_array_equal([], dbscan([[0], [1]], sample_weight=None,
195-
min_samples=5)[0])
208+
min_samples=6)[0])
196209
assert_array_equal([], dbscan([[0], [1]], sample_weight=[5, 5],
197-
min_samples=5)[0])
210+
min_samples=6)[0])
198211
assert_array_equal([0], dbscan([[0], [1]], sample_weight=[6, 5],
199-
min_samples=5)[0])
212+
min_samples=6)[0])
200213
assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[6, 6],
201-
min_samples=5)[0])
214+
min_samples=6)[0])
202215

203216
# points within eps of each other:
204217
assert_array_equal([0, 1], dbscan([[0], [1]], eps=1.5,
205-
sample_weight=[5, 1], min_samples=5)[0])
218+
sample_weight=[5, 1], min_samples=6)[0])
206219
# and effect of non-positive and non-integer sample_weight:
207220
assert_array_equal([], dbscan([[0], [1]], sample_weight=[5, 0],
208-
eps=1.5, min_samples=5)[0])
209-
assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[5, 0.1],
210-
eps=1.5, min_samples=5)[0])
221+
eps=1.5, min_samples=6)[0])
222+
assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[5.9, 0.1],
223+
eps=1.5, min_samples=6)[0])
211224
assert_array_equal([0, 1], dbscan([[0], [1]], sample_weight=[6, 0],
212-
eps=1.5, min_samples=5)[0])
225+
eps=1.5, min_samples=6)[0])
213226
assert_array_equal([], dbscan([[0], [1]], sample_weight=[6, -1],
214-
eps=1.5, min_samples=5)[0])
227+
eps=1.5, min_samples=6)[0])
215228

216229
# for non-negative sample_weight, cores should be identical to repetition
217230
rng = np.random.RandomState(42)

0 commit comments

Comments (0)