scikit-learn/scikit-learn@819d8ef

Commit 819d8ef

TomDLT authored and ogrisel committed
[MRG] Fix diagonal in DBSCAN with precomputed sparse neighbors graph (#12105)
1 parent 2e2e69d commit 819d8ef

4 files changed, +39 -12 lines changed

doc/whats_new/v0.20.rst

Lines changed: 7 additions & 3 deletions
@@ -47,6 +47,10 @@ Changelog
   avoid pickling errors caused by the serialization of their methods.
   :issue:`12171` by :user:`Thomas Moreau <tomMoral>`

+- |Fix| Fixed a bug in :class:`cluster.DBSCAN` with precomputed sparse neighbors
+  graph, which would add explicitly zeros on the diagonal even when already
+  present. :issue:`12105` by `Tom Dupre la Tour`_.
+
 .. _changes_0_20:

 Version 0.20.0
@@ -663,7 +667,7 @@ Support for Python 3.3 has been officially dropped.

 - |Feature| :func:`metrics.classification_report` now reports all applicable averages on
   the given data, including micro, macro and weighted average as well as samples
-  average for multilabel data. :issue:`11679` by :user:`Alexander Pacha <apacha>`.
+  average for multilabel data. :issue:`11679` by :user:`Alexander Pacha <apacha>`.

 - |Feature| :func:`metrics.average_precision_score` now supports binary
   ``y_true`` other than ``{0, 1}`` or ``{-1, 1}`` through ``pos_label``
@@ -917,7 +921,7 @@ Support for Python 3.3 has been officially dropped.
   keyword arguments on to the pipeline's last estimator, enabling the use of
   parameters such as ``return_std`` in a pipeline with caution.
   :issue:`9304` by :user:`Breno Freitas <brenolf>`.
-
+
 - |API| :class:`pipeline.FeatureUnion` now supports ``'drop'`` as a transformer
   to drop features. :issue:`11144` by :user:`thomasjpfan`.

@@ -1039,7 +1043,7 @@ Support for Python 3.3 has been officially dropped.
 - |API| The NaN marker for the missing values has been changed
   between the :class:`preprocessing.Imputer` and the
   :class:`impute.SimpleImputer`.
-  ``missing_values='NaN'`` should now be
+  ``missing_values='NaN'`` should now be
   ``missing_values=np.nan``. :issue:`11211` by
   :user:`Jeremie du Boisberranger <jeremiedbb>`.

doc/whats_new/v0.21.rst

Lines changed: 5 additions & 1 deletion
@@ -17,7 +17,7 @@ parameters, may produce different models from the previous version. This often
 occurs due to changes in the modelling logic (bug fixes or enhancements), or in
 random sampling procedures.

-- please add class and reason here (see version 0.20 what's new)
+- :class:`cluster.DBSCAN` (bug fix)

 Details are listed in the changelog below.
@@ -48,6 +48,10 @@ Support for Python 3.4 and below has been officially dropped.
   to set and that scales better, by :user:`Shane <espg>` and
   :user:`Adrin Jalali <adrinjalali>`.

+- |Fix| Fixed a bug in :class:`cluster.DBSCAN` with precomputed sparse neighbors
+  graph, which would add explicitly zeros on the diagonal even when already
+  present. :issue:`12105` by `Tom Dupre la Tour`_.
+
 Multiple modules
 ................

sklearn/cluster/dbscan_.py

Lines changed: 8 additions & 6 deletions
@@ -14,6 +14,7 @@

 from ..base import BaseEstimator, ClusterMixin
 from ..utils import check_array, check_consistent_length
+from ..utils.testing import ignore_warnings
 from ..neighbors import NearestNeighbors

 from ._dbscan_inner import dbscan_inner
@@ -142,15 +143,16 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
     if metric == 'precomputed' and sparse.issparse(X):
         neighborhoods = np.empty(X.shape[0], dtype=object)
         X.sum_duplicates()  # XXX: modifies X's internals in-place
+
+        # set the diagonal to explicit values, as a point is its own neighbor
+        with ignore_warnings():
+            X.setdiag(X.diagonal())  # XXX: modifies X's internals in-place
+
         X_mask = X.data <= eps
         masked_indices = X.indices.astype(np.intp, copy=False)[X_mask]
-        masked_indptr = np.concatenate(([0], np.cumsum(X_mask)))[X.indptr[1:]]
+        masked_indptr = np.concatenate(([0], np.cumsum(X_mask)))
+        masked_indptr = masked_indptr[X.indptr[1:-1]]

-        # insert the diagonal: a point is its own neighbor, but 0 distance
-        # means absence from sparse matrix data
-        masked_indices = np.insert(masked_indices, masked_indptr,
-                                   np.arange(X.shape[0]))
-        masked_indptr = masked_indptr[:-1] + np.arange(1, X.shape[0])
         # split into rows
         neighborhoods[:] = np.split(masked_indices, masked_indptr)
     else:
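
The index arithmetic in this hunk can be followed on a toy example. What follows is a minimal pure-Python sketch, not the scikit-learn code: `neighborhoods_new` and `neighborhoods_old` are hypothetical helper names, the CSR matrix is passed as plain lists (`indptr`, `indices`, `data`), and `neighborhoods_old` reflects one reading of the removed lines (row index `i` is always inserted at the end of row `i`). The sketch assumes the diagonal is already stored explicitly, which is what `X.setdiag(X.diagonal())` arranges in the patched code.

```python
from itertools import accumulate


def neighborhoods_new(indptr, indices, data, eps):
    """Neighbor lists per row, assuming the diagonal is stored explicitly."""
    mask = [d <= eps for d in data]
    # kept[k] = number of entries passing the mask among the first k stored
    # entries (plays the role of np.concatenate(([0], np.cumsum(X_mask)))).
    kept = [0] + list(accumulate(int(m) for m in mask))
    masked_indices = [j for j, m in zip(indices, mask) if m]
    # Row boundaries inside the filtered index list.
    bounds = [kept[p] for p in indptr]
    return [masked_indices[a:b] for a, b in zip(bounds[:-1], bounds[1:])]


def neighborhoods_old(indptr, indices, data, eps):
    """Assumed pre-fix behaviour: append i to row i unconditionally."""
    rows = neighborhoods_new(indptr, indices, data, eps)
    return [row + [i] for i, row in enumerate(rows)]


# Hypothetical 3-point graph with explicit zeros on the diagonal:
# d(0, 1) = 0.5 and d(1, 2) = 2.0.
indptr = [0, 2, 5, 7]
indices = [0, 1, 0, 1, 2, 1, 2]
data = [0.0, 0.5, 0.5, 0.0, 2.0, 2.0, 0.0]

new_rows = neighborhoods_new(indptr, indices, data, eps=1.0)
old_rows = neighborhoods_old(indptr, indices, data, eps=1.0)
print(new_rows)
print(old_rows)
```

In this sketch the old path lists each point twice in its own neighborhood whenever the diagonal is already stored, which inflates the neighbor count that is compared against `min_samples`.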

sklearn/cluster/tests/test_dbscan.py

Lines changed: 19 additions & 2 deletions
@@ -81,10 +81,12 @@ def test_dbscan_sparse():
     assert_array_equal(labels_dense, labels_sparse)


-def test_dbscan_sparse_precomputed():
+@pytest.mark.parametrize('include_self', [False, True])
+def test_dbscan_sparse_precomputed(include_self):
     D = pairwise_distances(X)
     nn = NearestNeighbors(radius=.9).fit(X)
-    D_sparse = nn.radius_neighbors_graph(mode='distance')
+    X_ = X if include_self else None
+    D_sparse = nn.radius_neighbors_graph(X=X_, mode='distance')
     # Ensure it is sparse not merely on diagonals:
     assert D_sparse.nnz < D.shape[0] * (D.shape[0] - 1)
     core_sparse, labels_sparse = dbscan(D_sparse,
@@ -97,6 +99,21 @@ def test_dbscan_sparse_precomputed():
     assert_array_equal(labels_dense, labels_sparse)


+@pytest.mark.parametrize('use_sparse', [True, False])
+@pytest.mark.parametrize('metric', ['precomputed', 'minkowski'])
+def test_dbscan_input_not_modified(use_sparse, metric):
+    # test that the input is not modified by dbscan
+    X = np.random.RandomState(0).rand(10, 10)
+    X = sparse.csr_matrix(X) if use_sparse else X
+    X_copy = X.copy()
+    dbscan(X, metric=metric)
+
+    if use_sparse:
+        assert_array_equal(X.toarray(), X_copy.toarray())
+    else:
+        assert_array_equal(X, X_copy)
+
+
 def test_dbscan_no_core_samples():
     rng = np.random.RandomState(0)
     X = rng.rand(40, 10)
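
The new `test_dbscan_input_not_modified` compares `X.toarray()` rather than the stored sparse structure. The sketch below is an illustration under assumptions, not part of the commit: it uses a hypothetical 3x3 distance graph and the standard-library `warnings` module in place of scikit-learn's `ignore_warnings`. It suggests why a dense comparison could still pass even though `X.setdiag(X.diagonal())` modifies the matrix internals in place: the call stores explicit zeros, which changes the number of stored entries but not the dense values.

```python
import warnings

import numpy as np
from scipy import sparse

# Hypothetical 3-point distance graph; only off-diagonal entries are stored.
rows = np.array([0, 1])
cols = np.array([1, 0])
vals = np.array([0.5, 0.5])
X = sparse.csr_matrix((vals, (rows, cols)), shape=(3, 3))

dense_before = X.toarray()
nnz_before = X.nnz  # counts stored entries, including explicit zeros

with warnings.catch_warnings():
    # Changing the sparsity structure of a CSR matrix may emit a
    # SparseEfficiencyWarning; silence it for this illustration.
    warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning)
    X.setdiag(X.diagonal())  # same call as in the patched dbscan_.py

nnz_after = X.nnz
dense_unchanged = np.array_equal(X.toarray(), dense_before)

# A mask on X.data only sees stored entries, so the diagonal zeros
# now take part in a comparison such as `X.data <= eps`.
stored_within_eps = int((X.data <= 0.9).sum())
print(nnz_before, nnz_after, dense_unchanged, stored_within_eps)
```

A stricter check that the input is untouched would also have to compare `indptr`, `indices` and `data`, which goes beyond what the test in this commit asserts.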
