12
12
# License: BSD 3 clause
13
13
14
14
from abc import ABC , abstractmethod
15
+ from numbers import Integral , Real
15
16
import warnings
16
17
17
18
import numpy as np
34
35
from ..utils import check_random_state
35
36
from ..utils .validation import check_is_fitted , _check_sample_weight
36
37
from ..utils .validation import _is_arraylike_not_scalar
38
+ from ..utils ._param_validation import Interval
39
+ from ..utils ._param_validation import StrOptions
40
+ from ..utils ._param_validation import validate_params
37
41
from ..utils ._openmp_helpers import _openmp_effective_n_threads
38
42
from ..utils ._readonly_array_wrapper import ReadonlyArrayWrapper
39
43
from ..exceptions import ConvergenceWarning
55
59
# Initialization heuristic
56
60
57
61
62
+ @validate_params (
63
+ {
64
+ "X" : ["array-like" , "sparse matrix" ],
65
+ "n_clusters" : [Interval (Integral , 1 , None , closed = "left" )],
66
+ "x_squared_norms" : ["array-like" , None ],
67
+ "random_state" : ["random_state" ],
68
+ "n_local_trials" : [Interval (Integral , 1 , None , closed = "left" ), None ],
69
+ }
70
+ )
58
71
def kmeans_plusplus (
59
72
X , n_clusters , * , x_squared_norms = None , random_state = None , n_local_trials = None
60
73
):
@@ -114,7 +127,6 @@ def kmeans_plusplus(
114
127
>>> indices
115
128
array([4, 2])
116
129
"""
117
-
118
130
# Check data
119
131
check_array (X , accept_sparse = "csr" , dtype = [np .float64 , np .float32 ])
120
132
@@ -135,12 +147,6 @@ def kmeans_plusplus(
135
147
f"be equal to the length of n_samples { X .shape [0 ]} ."
136
148
)
137
149
138
- if n_local_trials is not None and n_local_trials < 1 :
139
- raise ValueError (
140
- f"n_local_trials is set to { n_local_trials } but should be an "
141
- "integer value greater than zero."
142
- )
143
-
144
150
random_state = check_random_state (random_state )
145
151
146
152
# Call private k-means++
@@ -794,6 +800,16 @@ class _BaseKMeans(
794
800
):
795
801
"""Base class for KMeans and MiniBatchKMeans"""
796
802
803
+ _parameter_constraints = {
804
+ "n_clusters" : [Interval (Integral , 1 , None , closed = "left" )],
805
+ "init" : [StrOptions ({"k-means++" , "random" }), callable , "array-like" ],
806
+ "n_init" : [Interval (Integral , 1 , None , closed = "left" )],
807
+ "max_iter" : [Interval (Integral , 1 , None , closed = "left" )],
808
+ "tol" : [Interval (Real , 0 , None , closed = "left" )],
809
+ "verbose" : [Interval (Integral , 0 , None , closed = "left" ), bool ],
810
+ "random_state" : ["random_state" ],
811
+ }
812
+
797
813
def __init__ (
798
814
self ,
799
815
n_clusters ,
@@ -813,16 +829,7 @@ def __init__(
813
829
self .verbose = verbose
814
830
self .random_state = random_state
815
831
816
- def _check_params (self , X ):
817
- # n_init
818
- if self .n_init <= 0 :
819
- raise ValueError (f"n_init should be > 0, got { self .n_init } instead." )
820
- self ._n_init = self .n_init
821
-
822
- # max_iter
823
- if self .max_iter <= 0 :
824
- raise ValueError (f"max_iter should be > 0, got { self .max_iter } instead." )
825
-
832
+ def _check_params_vs_input (self , X ):
826
833
# n_clusters
827
834
if X .shape [0 ] < self .n_clusters :
828
835
raise ValueError (
@@ -833,16 +840,7 @@ def _check_params(self, X):
833
840
self ._tol = _tolerance (X , self .tol )
834
841
835
842
# init
836
- if not (
837
- _is_arraylike_not_scalar (self .init )
838
- or callable (self .init )
839
- or (isinstance (self .init , str ) and self .init in ["k-means++" , "random" ])
840
- ):
841
- raise ValueError (
842
- "init should be either 'k-means++', 'random', an array-like or a "
843
- f"callable, got '{ self .init } ' instead."
844
- )
845
-
843
+ self ._n_init = self .n_init
846
844
if _is_arraylike_not_scalar (self .init ) and self ._n_init != 1 :
847
845
warnings .warn (
848
846
"Explicit initial center position passed: performing only"
@@ -1275,6 +1273,14 @@ class KMeans(_BaseKMeans):
1275
1273
[ 1., 2.]])
1276
1274
"""
1277
1275
1276
+ _parameter_constraints = {
1277
+ ** _BaseKMeans ._parameter_constraints ,
1278
+ "copy_x" : [bool ],
1279
+ "algorithm" : [
1280
+ StrOptions ({"lloyd" , "elkan" , "auto" , "full" }, deprecated = {"auto" , "full" })
1281
+ ],
1282
+ }
1283
+
1278
1284
def __init__ (
1279
1285
self ,
1280
1286
n_clusters = 8 ,
@@ -1301,15 +1307,8 @@ def __init__(
1301
1307
self .copy_x = copy_x
1302
1308
self .algorithm = algorithm
1303
1309
1304
- def _check_params (self , X ):
1305
- super ()._check_params (X )
1306
-
1307
- # algorithm
1308
- if self .algorithm not in ("lloyd" , "elkan" , "auto" , "full" ):
1309
- raise ValueError (
1310
- "Algorithm must be either 'lloyd' or 'elkan', "
1311
- f"got { self .algorithm } instead."
1312
- )
1310
+ def _check_params_vs_input (self , X ):
1311
+ super ()._check_params_vs_input (X )
1313
1312
1314
1313
self ._algorithm = self .algorithm
1315
1314
if self ._algorithm in ("auto" , "full" ):
@@ -1362,6 +1361,8 @@ def fit(self, X, y=None, sample_weight=None):
1362
1361
self : object
1363
1362
Fitted estimator.
1364
1363
"""
1364
+ self ._validate_params ()
1365
+
1365
1366
X = self ._validate_data (
1366
1367
X ,
1367
1368
accept_sparse = "csr" ,
@@ -1371,7 +1372,8 @@ def fit(self, X, y=None, sample_weight=None):
1371
1372
accept_large_sparse = False ,
1372
1373
)
1373
1374
1374
- self ._check_params (X )
1375
+ self ._check_params_vs_input (X )
1376
+
1375
1377
random_state = check_random_state (self .random_state )
1376
1378
sample_weight = _check_sample_weight (sample_weight , X , dtype = X .dtype )
1377
1379
self ._n_threads = _openmp_effective_n_threads ()
@@ -1755,6 +1757,15 @@ class MiniBatchKMeans(_BaseKMeans):
1755
1757
array([0, 1], dtype=int32)
1756
1758
"""
1757
1759
1760
+ _parameter_constraints = {
1761
+ ** _BaseKMeans ._parameter_constraints ,
1762
+ "batch_size" : [Interval (Integral , 1 , None , closed = "left" )],
1763
+ "compute_labels" : [bool ],
1764
+ "max_no_improvement" : [Interval (Integral , 0 , None , closed = "left" ), None ],
1765
+ "init_size" : [Interval (Integral , 1 , None , closed = "left" ), None ],
1766
+ "reassignment_ratio" : [Interval (Real , 0 , None , closed = "left" )],
1767
+ }
1768
+
1758
1769
def __init__ (
1759
1770
self ,
1760
1771
n_clusters = 8 ,
@@ -1788,26 +1799,12 @@ def __init__(
1788
1799
self .init_size = init_size
1789
1800
self .reassignment_ratio = reassignment_ratio
1790
1801
1791
- def _check_params (self , X ):
1792
- super ()._check_params (X )
1793
-
1794
- # max_no_improvement
1795
- if self .max_no_improvement is not None and self .max_no_improvement < 0 :
1796
- raise ValueError (
1797
- "max_no_improvement should be >= 0, got "
1798
- f"{ self .max_no_improvement } instead."
1799
- )
1802
+ def _check_params_vs_input (self , X ):
1803
+ super ()._check_params_vs_input (X )
1800
1804
1801
- # batch_size
1802
- if self .batch_size <= 0 :
1803
- raise ValueError (
1804
- f"batch_size should be > 0, got { self .batch_size } instead."
1805
- )
1806
1805
self ._batch_size = min (self .batch_size , X .shape [0 ])
1807
1806
1808
1807
# init_size
1809
- if self .init_size is not None and self .init_size <= 0 :
1810
- raise ValueError (f"init_size should be > 0, got { self .init_size } instead." )
1811
1808
self ._init_size = self .init_size
1812
1809
if self ._init_size is None :
1813
1810
self ._init_size = 3 * self ._batch_size
@@ -1949,6 +1946,8 @@ def fit(self, X, y=None, sample_weight=None):
1949
1946
self : object
1950
1947
Fitted estimator.
1951
1948
"""
1949
+ self ._validate_params ()
1950
+
1952
1951
X = self ._validate_data (
1953
1952
X ,
1954
1953
accept_sparse = "csr" ,
@@ -1957,7 +1956,7 @@ def fit(self, X, y=None, sample_weight=None):
1957
1956
accept_large_sparse = False ,
1958
1957
)
1959
1958
1960
- self ._check_params (X )
1959
+ self ._check_params_vs_input (X )
1961
1960
random_state = check_random_state (self .random_state )
1962
1961
sample_weight = _check_sample_weight (sample_weight , X , dtype = X .dtype )
1963
1962
self ._n_threads = _openmp_effective_n_threads ()
@@ -2106,6 +2105,9 @@ def partial_fit(self, X, y=None, sample_weight=None):
2106
2105
"""
2107
2106
has_centers = hasattr (self , "cluster_centers_" )
2108
2107
2108
+ if not has_centers :
2109
+ self ._validate_params ()
2110
+
2109
2111
X = self ._validate_data (
2110
2112
X ,
2111
2113
accept_sparse = "csr" ,
@@ -2126,7 +2128,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
2126
2128
2127
2129
if not has_centers :
2128
2130
# this instance has not been fitted yet (fit or partial_fit)
2129
- self ._check_params (X )
2131
+ self ._check_params_vs_input (X )
2130
2132
self ._n_threads = _openmp_effective_n_threads ()
2131
2133
2132
2134
# Validate init array
0 commit comments