File tree 2 files changed +9
-3
lines changed
2 files changed +9
-3
lines changed Original file line number Diff line number Diff line change 17
17
from ..base import BaseEstimator , ClusterMixin , _fit_context
18
18
from ..metrics .pairwise import _VALID_METRICS
19
19
from ..neighbors import NearestNeighbors
20
- from ..utils ._param_validation import Interval , StrOptions
20
+ from ..utils ._param_validation import Interval , StrOptions , validate_params
21
21
from ..utils .validation import _check_sample_weight
22
22
from ._dbscan_inner import dbscan_inner
23
23
24
24
25
- # This function is not validated using validate_params because
26
- # it's just a factory for DBSCAN.
25
+ @validate_params (
26
+ {
27
+ "X" : ["array-like" , "sparse matrix" ],
28
+ "sample_weight" : ["array-like" , None ],
29
+ },
30
+ prefer_skip_nested_validation = False ,
31
+ )
27
32
def dbscan (
28
33
X ,
29
34
eps = 0.5 ,
Original file line number Diff line number Diff line change @@ -344,6 +344,7 @@ def test_function_param_validation(func_module):
344
344
345
345
PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
346
346
("sklearn.cluster.affinity_propagation" , "sklearn.cluster.AffinityPropagation" ),
347
+ ("sklearn.cluster.dbscan" , "sklearn.cluster.DBSCAN" ),
347
348
("sklearn.cluster.k_means" , "sklearn.cluster.KMeans" ),
348
349
("sklearn.cluster.mean_shift" , "sklearn.cluster.MeanShift" ),
349
350
("sklearn.cluster.spectral_clustering" , "sklearn.cluster.SpectralClustering" ),
You can’t perform that action at this time.
0 commit comments