8000 Modified dbscan_.py · scikit-learn/scikit-learn@4f209ed · GitHub
[go: up one dir, main page]

Skip to content

Commit 4f209ed

Browse files
Maria AndriopoulouMaria Andriopoulou
Maria Andriopoulou
authored and
Maria Andriopoulou
committed
Modified dbscan_.py
1 parent 22eedc1 commit 4f209ed

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

sklearn/cluster/dbscan_.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from ..base import BaseEstimator, ClusterMixin
1616
from ..utils import check_array, check_consistent_length
17+
from ..utils.validation import FLOAT_DTYPES
1718
from ..neighbors import NearestNeighbors
1819

1920
from ._dbscan_inner import dbscan_inner
@@ -111,7 +112,7 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
111112
if not eps > 0.0:
112113
raise ValueError("eps must be positive.")
113114

114-
X = check_array(X, accept_sparse='csr')
115+
X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES)
115116
if sample_weight is not None:
116117
sample_weight = np.asarray(sample_weight)
117118
check_consistent_length(X, sample_weight)
@@ -279,7 +280,7 @@ def fit(self, X, y=None, sample_weight=None):
279280
y : Ignored
280281
281282
"""
282-
X = check_array(X, accept_sparse='csr')
283+
X = check_array(X, accept_sparse='csr', dtype=FLOAT_DTYPES)
283284
clust = dbscan(X, sample_weight=sample_weight,
284285
**self.get_params())
285286
self.core_sample_indices_, self.labels_ = clust

0 commit comments

Comments
 (0)
0