8000 [MRG+1] Add DBSCAN support for additional metric params (#8139) · NelleV/scikit-learn@7d270a4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7d270a4

Browse files
naoyakNelleV
authored andcommitted
[MRG+1] Add DBSCAN support for additional metric params (scikit-learn#8139)
* Add DBSCAN support for additional metric params
1 parent aaf471f commit 7d270a4

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

sklearn/cluster/dbscan_.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ._dbscan_inner import dbscan_inner
2121

2222

23-
def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
23+
def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
2424
algorithm='auto', leaf_size=30, p=2, sample_weight=None, n_jobs=1):
2525
"""Perform DBSCAN clustering from vector array or distance matrix.
2626
@@ -50,6 +50,11 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
5050
must be square. X may be a sparse matrix, in which case only "nonzero"
5151
elements may be considered neighbors for DBSCAN.
5252
53+
metric_params : dict, optional
54+
Additional keyword arguments for the metric function.
55+
56+
.. versionadded:: 0.19
57+
5358
algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
5459
The algorithm to be used by the NearestNeighbors module
5560
to compute pointwise distances and find nearest neighbors.
@@ -130,7 +135,8 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
130135
else:
131136
neighbors_model = NearestNeighbors(radius=eps, algorithm=algorithm,
132137
leaf_size=leaf_size,
133-
metric=metric, p=p,
138+
metric=metric,
139+
metric_params=metric_params, p=p,
134140
n_jobs=n_jobs)
135141
neighbors_model.fit(X)
136142
# This has worst case O(n^2) memory complexity
@@ -184,6 +190,11 @@ class DBSCAN(BaseEstimator, ClusterMixin):
184190
.. versionadded:: 0.17
185191
metric *precomputed* to accept precomputed sparse matrix.
186192
193+
metric_params : dict, optional
194+
Additional keyword arguments for the metric function.
195+
196+
.. versionadded:: 0.19
197+
187198
algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
188199
The algorithm to be used by the NearestNeighbors module
189200
to compute pointwise distances and find nearest neighbors.
@@ -237,10 +248,12 @@ class DBSCAN(BaseEstimator, ClusterMixin):
237248
"""
238249

239250
def __init__(self, eps=0.5, min_samples=5, metric='euclidean',
240-
algorithm='auto', leaf_size=30, p=None, n_jobs=1):
251+
metric_params=None, algorithm='auto', leaf_size=30, p=None,
252+
n_jobs=1):
241253
self.eps = eps
242254
self.min_samples = min_samples
243255
self.metric = metric
256+
self.metric_params = metric_params
244257
self.algorithm = algorithm
245258
self.leaf_size = leaf_size
246259
self.p = p

sklearn/cluster/tests/test_dbscan.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,34 @@ def test_dbscan_callable():
133133
assert_equal(n_clusters_2, n_clusters)
134134

135135

136+
def test_dbscan_metric_params():
137+
# Tests that DBSCAN works with the metrics_params argument.
138+
eps = 0.8
139+
min_samples = 10
140+
p = 1
141+
142+
# Compute DBSCAN with metric_params arg
143+
db = DBSCAN(metric='minkowski', metric_params={'p': p}, eps=eps,
144+
min_samples=min_samples, algorithm='ball_tree').fit(X)
145+
core_sample_1, labels_1 = db.core_sample_indices_, db.labels_
146+
147+
# Test that sample labels are the same as passing Minkowski 'p' directly
148+
db = DBSCAN(metric='minkowski', eps=eps, min_samples=min_samples,
149+
algorithm='ball_tree', p=p).fit(X)
150+
core_sample_2, labels_2 = db.core_sample_indices_, db.labels_
151+
152+
assert_array_equal(core_sample_1, core_sample_2)
153+
assert_array_equal(labels_1, labels_2)
154+
155+
# Minkowski with p=1 should be equivalent to Manhattan distance
156+
db = DBSCAN(metric='manhattan', eps=eps, min_samples=min_samples,
157+
algorithm='ball_tree').fit(X)
158+
core_sample_3, labels_3 = db.core_sample_indices_, db.labels_
159+
160+
assert_array_equal(core_sample_1, core_sample_3)
161+
assert_array_equal(labels_1, labels_3)
162+
163+
136164
def test_dbscan_balltree():
137165
# Tests the DBSCAN algorithm with balltree for neighbor calculation.
138166
eps = 0.8

0 commit comments

Comments
 (0)
0