[MRG+1] Add DBSCAN support for additional metric params by naoyak · Pull Request #8139 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG+1] Add DBSCAN support for additional metric params #8139

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions sklearn/cluster/dbscan_.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from ._dbscan_inner import dbscan_inner


def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
algorithm='auto', leaf_size=30, p=2, sample_weight=None, n_jobs=1):
"""Perform DBSCAN clustering from vector array or distance matrix.

Expand Down Expand Up @@ -50,6 +50,11 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
must be square. X may be a sparse matrix, in which case only "nonzero"
elements may be considered neighbors for DBSCAN.

metric_params : dict, optional
Additional keyword arguments for the metric function.

.. versionadded:: 0.19

algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
The algorithm to be used by the NearestNeighbors module
to compute pointwise distances and find nearest neighbors.
Expand Down Expand Up @@ -130,7 +135,8 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
else:
neighbors_model = NearestNeighbors(radius=eps, algorithm=algorithm,
leaf_size=leaf_size,
metric=metric, p=p,
metric=metric,
metric_params=metric_params, p=p,
n_jobs=n_jobs)
neighbors_model.fit(X)
# This has worst case O(n^2) memory complexity
Expand Down Expand Up @@ -184,6 +190,11 @@ class DBSCAN(BaseEstimator, ClusterMixin):
.. versionadded:: 0.17
metric *precomputed* to accept precomputed sparse matrix.

metric_params : dict, optional
Additional keyword arguments for the metric function.

.. versionadded:: 0.19

algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
The algorithm to be used by the NearestNeighbors module
to compute pointwise distances and find nearest neighbors.
Expand Down Expand Up @@ -237,10 +248,12 @@ class DBSCAN(BaseEstimator, ClusterMixin):
"""

def __init__(self, eps=0.5, min_samples=5, metric='euclidean',
algorithm='auto', leaf_size=30, p=None, n_jobs=1):
metric_params=None, algorithm='auto', leaf_size=30, p=None,
n_jobs=1):
self.eps = eps
self.min_samples = min_samples
self.metric = metric
self.metric_params = metric_params
self.algorithm = algorithm
self.leaf_size = leaf_size
self.p = p
Expand Down
28 changes: 28 additions & 0 deletions sklearn/cluster/tests/test_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,34 @@ def test_dbscan_callable():
assert_equal(n_clusters_2, n_clusters)


def test_dbscan_metric_params():
    """Check that DBSCAN works with the ``metric_params`` argument.

    Three ball-tree DBSCAN models are fitted on the module-level ``X``
    with the same ``eps`` and ``min_samples``:

    1. Minkowski metric with ``p`` supplied through ``metric_params``.
    2. Minkowski metric with ``p`` supplied as the direct ``p`` argument.
    3. Manhattan metric, which Minkowski with ``p=1`` should be
       equivalent to.

    The core sample indices and labels of the first model must match
    those of the other two.
    """
    eps = 0.8
    min_samples = 10
    p = 1

    def fit_and_extract(**metric_kwargs):
        # Only the metric-related keyword arguments differ between runs;
        # eps, min_samples and the ball-tree algorithm are shared.
        model = DBSCAN(eps=eps, min_samples=min_samples,
                       algorithm='ball_tree', **metric_kwargs).fit(X)
        return model.core_sample_indices_, model.labels_

    # Reference run: Minkowski 'p' supplied via metric_params.
    core_sample_1, labels_1 = fit_and_extract(metric='minkowski',
                                              metric_params={'p': p})

    # Passing Minkowski 'p' directly must give the same result.
    core_sample_2, labels_2 = fit_and_extract(metric='minkowski', p=p)

    assert_array_equal(core_sample_1, core_sample_2)
    assert_array_equal(labels_1, labels_2)

    # Minkowski with p=1 should be equivalent to Manhattan distance.
    core_sample_3, labels_3 = fit_and_extract(metric='manhattan')

    assert_array_equal(core_sample_1, core_sample_3)
    assert_array_equal(labels_1, labels_3)


def test_dbscan_balltree():
# Tests the DBSCAN algorithm with balltree for neighbor calculation.
eps = 0.8
Expand Down
0