diff --git a/sklearn/cluster/dbscan_.py b/sklearn/cluster/dbscan_.py
index ed79546d73eb9..a02db3feafb00 100644
--- a/sklearn/cluster/dbscan_.py
+++ b/sklearn/cluster/dbscan_.py
@@ -20,7 +20,7 @@
 from ._dbscan_inner import dbscan_inner
 
 
-def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
+def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
            algorithm='auto', leaf_size=30, p=2, sample_weight=None, n_jobs=1):
     """Perform DBSCAN clustering from vector array or distance matrix.
@@ -50,6 +50,11 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
         must be square. X may be a sparse matrix, in which case only "nonzero"
         elements may be considered neighbors for DBSCAN.
 
+    metric_params : dict, optional
+        Additional keyword arguments for the metric function.
+
+        .. versionadded:: 0.19
+
     algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
         The algorithm to be used by the NearestNeighbors module
         to compute pointwise distances and find nearest neighbors.
@@ -130,7 +135,8 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski',
     else:
         neighbors_model = NearestNeighbors(radius=eps, algorithm=algorithm,
                                            leaf_size=leaf_size,
-                                           metric=metric, p=p,
+                                           metric=metric,
+                                           metric_params=metric_params, p=p,
                                            n_jobs=n_jobs)
     neighbors_model.fit(X)
     # This has worst case O(n^2) memory complexity
@@ -184,6 +190,11 @@ class DBSCAN(BaseEstimator, ClusterMixin):
         .. versionadded:: 0.17
            metric *precomputed* to accept precomputed sparse matrix.
 
+    metric_params : dict, optional
+        Additional keyword arguments for the metric function.
+
+        .. versionadded:: 0.19
+
     algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional
         The algorithm to be used by the NearestNeighbors module
         to compute pointwise distances and find nearest neighbors.
@@ -237,10 +248,12 @@ class DBSCAN(BaseEstimator, ClusterMixin):
     """
 
     def __init__(self, eps=0.5, min_samples=5, metric='euclidean',
-                 algorithm='auto', leaf_size=30, p=None, n_jobs=1):
+                 metric_params=None, algorithm='auto', leaf_size=30, p=None,
+                 n_jobs=1):
         self.eps = eps
         self.min_samples = min_samples
         self.metric = metric
+        self.metric_params = metric_params
         self.algorithm = algorithm
         self.leaf_size = leaf_size
         self.p = p
diff --git a/sklearn/cluster/tests/test_dbscan.py b/sklearn/cluster/tests/test_dbscan.py
index afddf52b03ae8..b4b34dcefb822 100644
--- a/sklearn/cluster/tests/test_dbscan.py
+++ b/sklearn/cluster/tests/test_dbscan.py
@@ -133,6 +133,34 @@ def test_dbscan_callable():
     assert_equal(n_clusters_2, n_clusters)
 
 
+def test_dbscan_metric_params():
+    # Tests that DBSCAN works with the metric_params argument.
+    eps = 0.8
+    min_samples = 10
+    p = 1
+
+    # Compute DBSCAN with metric_params arg
+    db = DBSCAN(metric='minkowski', metric_params={'p': p}, eps=eps,
+                min_samples=min_samples, algorithm='ball_tree').fit(X)
+    core_sample_1, labels_1 = db.core_sample_indices_, db.labels_
+
+    # Test that sample labels are the same as passing Minkowski 'p' directly
+    db = DBSCAN(metric='minkowski', eps=eps, min_samples=min_samples,
+                algorithm='ball_tree', p=p).fit(X)
+    core_sample_2, labels_2 = db.core_sample_indices_, db.labels_
+
+    assert_array_equal(core_sample_1, core_sample_2)
+    assert_array_equal(labels_1, labels_2)
+
+    # Minkowski with p=1 should be equivalent to Manhattan distance
+    db = DBSCAN(metric='manhattan', eps=eps, min_samples=min_samples,
+                algorithm='ball_tree').fit(X)
+    core_sample_3, labels_3 = db.core_sample_indices_, db.labels_
+
+    assert_array_equal(core_sample_1, core_sample_3)
+    assert_array_equal(labels_1, labels_3)
+
+
 def test_dbscan_balltree():
     # Tests the DBSCAN algorithm with balltree for neighbor calculation.
     eps = 0.8
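As a standalone illustration of the behavior this patch enables, the sketch below mirrors the test's equivalence check on toy data: passing `metric_params={'p': 1}` to a Minkowski metric should yield the same clustering as `metric='manhattan'`. The data points and hyperparameter values here are my own example, not taken from the PR's test suite.

```python
import numpy as np
from sklearn.cluster import DBSCAN

# Two well-separated blobs of three points each.
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])

# metric_params forwards extra keyword arguments to the metric;
# here p=1 makes the Minkowski metric equal to Manhattan distance.
db_a = DBSCAN(eps=0.5, min_samples=2, metric='minkowski',
              metric_params={'p': 1}, algorithm='ball_tree').fit(X)
db_b = DBSCAN(eps=0.5, min_samples=2, metric='manhattan',
              algorithm='ball_tree').fit(X)

# Both runs should assign identical labels: one cluster per blob.
print(np.array_equal(db_a.labels_, db_b.labels_))
```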