MNT Make dbscan call DBSCAN.fit and not the opposite (#14994) · scikit-learn/scikit-learn@c56bce4 · GitHub
[go: up one dir, main page]

Skip to content

Commit c56bce4

Browse files
NicolasHug authored and TomDLT committed
MNT Make dbscan call DBSCAN.fit and not the opposite (#14994)
1 parent ef97ab2 commit c56bce4

File tree

1 file changed

+63
-55
lines changed

1 file changed

+63
-55
lines changed

sklearn/cluster/dbscan_.py

Lines changed: 63 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -136,58 +136,12 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
136136
DBSCAN revisited, revisited: why and how you should (still) use DBSCAN.
137137
ACM Transactions on Database Systems (TODS), 42(3), 19.
138138
"""
139-
if not eps > 0.0:
140-
raise ValueError("eps must be positive.")
141-
142-
X = check_array(X, accept_sparse='csr')
143-
if sample_weight is not None:
144-
sample_weight = np.asarray(sample_weight)
145-
check_consistent_length(X, sample_weight)
146-
147-
# Calculate neighborhood for all samples. This leaves the original point
148-
# in, which needs to be considered later (i.e. point i is in the
149-
# neighborhood of point i. While True, its useless information)
150-
if metric == 'precomputed' and sparse.issparse(X):
151-
neighborhoods = np.empty(X.shape[0], dtype=object)
152-
X.sum_duplicates() # XXX: modifies X's internals in-place
153-
154-
# set the diagonal to explicit values, as a point is its own neighbor
155-
with warnings.catch_warnings():
156-
warnings.simplefilter('ignore', sparse.SparseEfficiencyWarning)
157-
X.setdiag(X.diagonal()) # XXX: modifies X's internals in-place
158-
159-
X_mask = X.data <= eps
160-
masked_indices = X.indices.astype(np.intp, copy=False)[X_mask]
161-
masked_indptr = np.concatenate(([0], np.cumsum(X_mask)))
162-
masked_indptr = masked_indptr[X.indptr[1:-1]]
163-
164-
# split into rows
165-
neighborhoods[:] = np.split(masked_indices, masked_indptr)
166-
else:
167-
neighbors_model = NearestNeighbors(radius=eps, algorithm=algorithm,
168-
leaf_size=leaf_size,
169-
metric=metric,
170-
metric_params=metric_params, p=p,
171-
n_jobs=n_jobs)
172-
neighbors_model.fit(X)
173-
# This has worst case O(n^2) memory complexity
174-
neighborhoods = neighbors_model.radius_neighbors(X, eps,
175-
return_distance=False)
176-
177-
if sample_weight is None:
178-
n_neighbors = np.array([len(neighbors)
179-
for neighbors in neighborhoods])
180-
else:
181-
n_neighbors = np.array([np.sum(sample_weight[neighbors])
182-
for neighbors in neighborhoods])
183-
184-
# Initially, all samples are noise.
185-
labels = np.full(X.shape[0], -1, dtype=np.intp)
186-
187-
# A list of all core samples found.
188-
core_samples = np.asarray(n_neighbors >= min_samples, dtype=np.uint8)
189-
dbscan_inner(core_samples, neighborhoods, labels)
190-
return np.where(core_samples)[0], labels
139+
140+
est = DBSCAN(eps=eps, min_samples=min_samples, metric=metric,
141+
metric_params=metric_params, algorithm=algorithm,
142+
leaf_size=leaf_size, p=p, n_jobs=n_jobs)
143+
est.fit(X, sample_weight=sample_weight)
144+
return est.core_sample_indices_, est.labels_
191145

192146

193147
class DBSCAN(ClusterMixin, BaseEstimator):
@@ -353,9 +307,63 @@ def fit(self, X, y=None, sample_weight=None):
353307
354308
"""
355309
X = check_array(X, accept_sparse='csr')
356-
clust = dbscan(X, sample_weight=sample_weight,
357-
**self.get_params())
358-
self.core_sample_indices_, self.labels_ = clust
310+
311+
if not self.eps > 0.0:
312+
raise ValueError("eps must be positive.")
313+
314+
if sample_weight is not None:
315+
sample_weight = np.asarray(sample_weight)
316+
check_consistent_length(X, sample_weight)
317+
318+
# Calculate neighborhood for all samples. This leaves the original
319+
# point in, which needs to be considered later (i.e. point i is in the
320+
# neighborhood of point i. While True, its useless information)
321+
if self.metric == 'precomputed' and sparse.issparse(X):
322+
neighborhoods = np.empty(X.shape[0], dtype=object)
323+
X.sum_duplicates() # XXX: modifies X's internals in-place
324+
325+
# set the diagonal to explicit values, as a point is its own
326+
# neighbor
327+
with warnings.catch_warnings():
328+
warnings.simplefilter('ignore', sparse.SparseEfficiencyWarning)
329+
X.setdiag(X.diagonal()) # XXX: modifies X's internals in-place
330+
331+
X_mask = X.data <= self.eps
332+
masked_indices = X.indices.astype(np.intp, copy=False)[X_mask]
333+
masked_indptr = np.concatenate(([0], np.cumsum(X_mask)))
334+
masked_indptr = masked_indptr[X.indptr[1:-1]]
335+
336+
# split into rows
337+
neighborhoods[:] = np.split(masked_indices, masked_indptr)
338+
else:
339+
neighbors_model = NearestNeighbors(
340+
radius=self.eps, algorithm=self.algorithm,
341+
leaf_size=self.leaf_size, metric=self.metric,
342+
metric_params=self.metric_params, p=self.p, n_jobs=self.n_jobs
343+
)
344+
neighbors_model.fit(X)
345+
# This has worst case O(n^2) memory complexity
346+
neighborhoods = neighbors_model.radius_neighbors(
347+
X, self.eps, return_distance=False)
348+
349+
if sample_weight is None:
350+
n_neighbors = np.array([len(neighbors)
351+
for neighbors in neighborhoods])
352+
else:
353+
n_neighbors = np.array([np.sum(sample_weight[neighbors])
354+
for neighbors in neighborhoods])
355+
356+
# Initially, all samples are noise.
357+
labels = np.full(X.shape[0], -1, dtype=np.intp)
358+
359+
# A list of all core samples found.
360+
core_samples = np.asarray(n_neighbors >= self.min_samples,
361+
dtype=np.uint8)
362+
dbscan_inner(core_samples, neighborhoods, labels)
363+
364+
self.core_sample_indices_ = np.where(core_samples)[0]
365+
self.labels_ = labels
366+
359367
if len(self.core_sample_indices_):
360368
# fix for scipy sparse indexing issue
361369
self.components_ = X[self.core_sample_indices_].copy()

0 commit comments

Comments
 (0)
0