8000 MAINT Common parameter validation (#22722) · lesteve/scikit-learn@e652d45 · GitHub
[go: up one dir, main page]

Skip to content

Commit e652d45

Browse files
jeremiedbbadrinjalali
authored andcommitted
MAINT Common parameter validation (scikit-learn#22722)
* common parameter validation * black * cln * wip * wip * rework * renaming and cleaning * lint * re lint * cln * add tests * lint * make random_state constraint * lint * closed positional * increase coverage + validate constraints * exp typing * trigger ci ? * lint * cln * rev type hints * cln * interval closed kwarg only * address comments * address comments + more tests + cln + improve err msg * lint * cln * cln * address comments * address comments * lint * adapt or skip new estimators * lint Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent 5116c07 commit e652d45

File tree

9 files changed

+1068
-143
lines changed

9 files changed

+1068
-143
lines changed

sklearn/base.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525
from .utils.validation import _check_feature_names_in
2626
from .utils.validation import _generate_get_feature_names_out
2727
from .utils.validation import check_is_fitted
28-
from .utils._estimator_html_repr import estimator_html_repr
2928
from .utils.validation import _get_feature_names
29+
from .utils._estimator_html_repr import estimator_html_repr
30+
from .utils._param_validation import validate_parameter_constraints
3031

3132

3233
def clone(estimator, *, safe=True):
@@ -601,6 +602,20 @@ def _validate_data(
601602

602603
return out
603604

605+
def _validate_params(self):
606+
"""Validate types and values of constructor parameters
607+
608+
The expected type and values must be defined in the `_parameter_constraints`
609+
class attribute, which is a dictionary `param_name: list of constraints`. See
610+
the docstring of `validate_parameter_constraints` for a description of the
611+
accepted constraints.
612+
"""
613+
validate_parameter_constraints(
614+
self._parameter_constraints,
615+
self.get_params(deep=False),
616+
caller_name=self.__class__.__name__,
617+
)
618+
604619
@property
605620
def _repr_html_(self):
606621
"""HTML representation of estimator.

sklearn/cluster/_bisect_k_means.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ..utils.validation import check_is_fitted
1818
from ..utils.validation import _check_sample_weight
1919
from ..utils.validation import check_random_state
20-
from ..utils.validation import _is_arraylike_not_scalar
20+
from ..utils._param_validation import StrOptions
2121

2222

2323
class _BisectingTree:
@@ -204,6 +204,14 @@ class BisectingKMeans(_BaseKMeans):
204204
[ 1., 2.]])
205205
"""
206206

207+
_parameter_constraints = {
208+
**_BaseKMeans._parameter_constraints,
209+
"init": [StrOptions({"k-means++", "random"}), callable],
210+
"copy_x": [bool],
211+
"algorithm": [StrOptions({"lloyd", "elkan"})],
212+
"bisecting_strategy": [StrOptions({"biggest_inertia", "largest_cluster"})],
213+
}
214+
207215
D966 def __init__(
208216
self,
209217
n_clusters=8,
@@ -233,27 +241,6 @@ def __init__(
233241
self.algorithm = algorithm
234242
self.bisecting_strategy = bisecting_strategy
235243

236-
def _check_params(self, X):
237-
super()._check_params(X)
238-
239-
# algorithm
240-
if self.algorithm not in ("lloyd", "elkan"):
241-
raise ValueError(
242-
"Algorithm must be either 'lloyd' or 'elkan', "
243-
f"got {self.algorithm} instead."
244-
)
245-
246-
# bisecting_strategy
247-
if self.bisecting_strategy not in ["biggest_inertia", "largest_cluster"]:
248-
raise ValueError(
249-
"Bisect Strategy must be 'biggest_inertia' or 'largest_cluster'. "
250-
f"Got {self.bisecting_strategy} instead."
251-
)
252-
253-
# init
254-
if _is_arraylike_not_scalar(self.init):
255-
raise ValueError("BisectingKMeans does not support init as array.")
256-
257244
def _warn_mkl_vcomp(self, n_active_threads):
258245
"""Warn when vcomp and mkl are both present"""
259246
warnings.warn(
@@ -380,6 +367,8 @@ def fit(self, X, y=None, sample_weight=None):
380367
self
381368
Fitted estimator.
382369
"""
370+
self._validate_params()
371+
383372
X = self._validate F438 _data(
384373
X,
385374
accept_sparse="csr",
@@ -389,7 +378,8 @@ def fit(self, X, y=None, sample_weight=None):
389378
accept_large_sparse=False,
390379
)
391380

392-
self._check_params(X)
381+
self._check_params_vs_input(X)
382+
393383
self._random_state = check_random_state(self.random_state)
394384
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
395385
self._n_threads = _openmp_effective_n_threads()

sklearn/cluster/_kmeans.py

Lines changed: 57 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# License: BSD 3 clause
1313

1414
from abc import ABC, abstractmethod
15+
from numbers import Integral, Real
1516
import warnings
1617

1718
import numpy as np
@@ -34,6 +35,9 @@
3435
from ..utils import check_random_state
3536
from ..utils.validation import check_is_fitted, _check_sample_weight
3637
from ..utils.validation import _is_arraylike_not_scalar
38+
from ..utils._param_validation import Interval
39+
from ..utils._param_validation import StrOptions
40+
from ..utils._param_validation import validate_params
3741
from ..utils._openmp_helpers import _openmp_effective_n_threads
3842
from ..utils._readonly_array_wrapper import ReadonlyArrayWrapper
3943
from ..exceptions import ConvergenceWarning
@@ -55,6 +59,15 @@
5559
# Initialization heuristic
5660

5761

62+
@validate_params(
63+
{
64+
"X": ["array-like", "sparse matrix"],
65+
"n_clusters": [Interval(Integral, 1, None, closed="left")],
66+
"x_squared_norms": ["array-like", None],
67+
"random_state": ["random_state"],
68+
"n_local_trials": [Interval(Integral, 1, None, closed="left"), None],
69+
}
70+
)
5871
def kmeans_plusplus(
5972
X, n_clusters, *, x_squared_norms=None, random_state=None, n_local_trials=None
6073
):
@@ -114,7 +127,6 @@ def kmeans_plusplus(
114127
>>> indices
115128
array([4, 2])
116129
"""
117-
118130
# Check data
119131
check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
120132

@@ -135,12 +147,6 @@ def kmeans_plusplus(
135147
f"be equal to the length of n_samples {X.shape[0]}."
136148
)
137149

138-
if n_local_trials is not None and n_local_trials < 1:
139-
raise ValueError(
140-
f"n_local_trials is set to {n_local_trials} but should be an "
141-
"integer value greater than zero."
142-
)
143-
144150
random_state = check_random_state(random_state)
145151

146152
# Call private k-means++
@@ -794,6 +800,16 @@ class _BaseKMeans(
794800
):
795801
"""Base class for KMeans and MiniBatchKMeans"""
796802

803+
_parameter_constraints = {
804+
"n_clusters": [Interval(Integral, 1, None, closed="left")],
805+
"init": [StrOptions({"k-means++", "random"}), callable, "array-like"],
806+
"n_init": [Interval(Integral, 1, None, closed="left")],
807+
"max_iter": [Interval(Integral, 1, None, closed="left")],
808+
"tol": [Interval(Real, 0, None, closed="left")],
809+
"verbose": [Interval(Integral, 0, None, closed="left"), bool],
810+
"random_state": ["random_state"],
811+
}
812+
797813
def __init__(
798814
self,
799815
n_clusters,
@@ -813,16 +829,7 @@ def __init__(
813829
self.verbose = verbose
814830
self.random_state = random_state
815831

816-
def _check_params(self, X):
817-
# n_init
818-
if self.n_init <= 0:
819-
raise ValueError(f"n_init should be > 0, got {self.n_init} instead.")
820-
self._n_init = self.n_init
821-
822-
# max_iter
823-
if self.max_iter <= 0:
824-
raise ValueError(f"max_iter should be > 0, got {self.max_iter} instead.")
825-
832+
def _check_params_vs_input(self, X):
826833
# n_clusters
827834
if X.shape[0] < self.n_clusters:
828835
raise ValueError(
@@ -833,16 +840,7 @@ def _check_params(self, X):
833840
self._tol = _tolerance(X, self.tol)
834841

835842
# init
836-
if not (
837-
_is_arraylike_not_scalar(self.init)
838-
or callable(self.init)
839-
or (isinstance(self.init, str) and self.init in ["k-means++", "random"])
840-
):
841-
raise ValueError(
842-
"init should be either 'k-means++', 'random', an array-like or a "
843-
f"callable, got '{self.init}' instead."
844-
)
845-
843+
self._n_init = self.n_init
846844
if _is_arraylike_not_scalar(self.init) and self._n_init != 1:
847845
warnings.warn(
848846
"Explicit initial center position passed: performing only"
@@ -1275,6 +1273,14 @@ class KMeans(_BaseKMeans):
12751273
[ 1., 2.]])
12761274
"""
12771275

1276+
_parameter_constraints = {
1277+
**_BaseKMeans._parameter_constraints,
1278+
"copy_x": [bool],
1279+
"algorithm": [
1280+
StrOptions({"lloyd", "elkan", "auto", "full"}, deprecated={"auto", "full"})
1281+
],
1282+
}
1283+
12781284
def __init__(
12791285
self,
12801286
n_clusters=8,
@@ -1301,15 +1307,8 @@ def __init__(
13011307
self.copy_x = copy_x
13021308
self.algorithm = algorithm
13031309

1304-
def _check_params(self, X):
1305-
super()._check_params(X)
1306-
1307-
# algorithm
1308-
if self.algorithm not in ("lloyd", "elkan", "auto", "full"):
1309-
raise ValueError(
1310-
"Algorithm must be either 'lloyd' or 'elkan', "
1311-
f"got {self.algorithm} instead."
1312-
)
1310+
def _check_params_vs_input(self, X):
1311+
super()._check_params_vs_input(X)
13131312

13141313
self._algorithm = self.algorithm
13151314
if self._algorithm in ("auto", "full"):
@@ -1362,6 +1361,8 @@ def fit(self, X, y=None, sample_weight=None):
13621361
self : object
13631362
Fitted estimator.
13641363
"""
1364+
self._validate_params()
1365+
13651366
X = self._validate_data(
13661367
X,
13671368
accept_sparse="csr",
@@ -1371,7 +1372,8 @@ def fit(self, X, y=None, sample_weight=None):
13711372
accept_large_sparse=False,
13721373
)
13731374

1374-
self._check_params(X)
1375+
self._check_params_vs_input(X)
1376+
13751377
random_state = check_random_state(self.random_state)
13761378
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
13771379
self._n_threads = _openmp_effective_n_threads()
@@ -1755,6 +1757,15 @@ class MiniBatchKMeans(_BaseKMeans):
17551757
array([0, 1], dtype=int32)
17561758
"""
17571759

1760+
_parameter_constraints = {
1761+
**_BaseKMeans._parameter_constraints,
1762+
"batch_size": [Interval(Integral, 1, None, closed="left")],
1763+
"compute_labels": [bool],
1764+
"max_no_improvement": [Interval(Integral, 0, None, closed="left"), None],
1765+
"init_size": [Interval(Integral, 1, None, closed="left"), None],
1766+
"reassignment_ratio": [Interval(Real, 0, None, closed="left")],
1767+
}
1768+
17581769
def __init__(
17591770
self,
17601771
n_clusters=8,
@@ -1788,26 +1799,12 @@ def __init__(
17881799
self.init_size = init_size
17891800
self.reassignment_ratio = reassignment_ratio
17901801

1791-
def _check_params(self, X):
1792-
super()._check_params(X)
1793-
1794-
# max_no_improvement
1795-
if self.max_no_improvement is not None and self.max_no_improvement < 0:
1796-
raise ValueError(
1797-
"max_no_improvement should be >= 0, got "
1798-
f"{self.max_no_improvement} instead."
1799-
)
1802+
def _check_params_vs_input(self, X):
1803+
super()._check_params_vs_input(X)
18001804

1801-
# batch_size
1802-
if self.batch_size <= 0:
1803-
raise ValueError(
1804-
f"batch_size should be > 0, got {self.batch_size} instead."
1805-
)
18061805
self._batch_size = min(self.batch_size, X.shape[0])
18071806

18081807
# init_size
1809-
if self.init_size is not None and self.init_size <= 0:
1810-
raise ValueError(f"init_size should be > 0, got {self.init_size} instead.")
18111808
self._init_size = self.init_size
18121809
if self._init_size is None:
18131810
self._init_size = 3 * self._batch_size
@@ -1949,6 +1946,8 @@ def fit(self, X, y=None, sample_weight=None):
19491946
self : object
19501947
Fitted estimator.
19511948
"""
1949+
self._validate_params()
1950+
19521951
X = self._validate_data(
19531952
X,
19541953
accept_sparse="csr",
@@ -1957,7 +1956,7 @@ def fit(self, X, y=None, sample_weight=None):
19571956
accept_large_sparse=False,
19581957
)
19591958

1960-
self._check_params(X)
1959+
self._check_params_vs_input(X)
19611960
random_state = check_random_state(self.random_state)
19621961
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
19631962
self._n_threads = _openmp_effective_n_threads()
@@ -2106,6 +2105,9 @@ def partial_fit(self, X, y=None, sample_weight=None):
21062105
"""
21072106
has_centers = hasattr(self, "cluster_centers_")
21082107

2108+
if not has_centers:
2109+
self._validate_params()
2110+
21092111
X = self._validate_data(
21102112
X,
21112113
accept_sparse="csr",
@@ -2126,7 +2128,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
21262128

21272129
if not has_centers:
21282130
# this instance has not been fitted yet (fit or partial_fit)
2129-
self._check_params(X)
2131+
self._check_params_vs_input(X)
21302132
self._n_threads = _openmp_effective_n_threads()
21312133

21322134
# Validate init array

sklearn/cluster/tests/test_bisect_k_means.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,31 +85,6 @@ def test_one_cluster():
8585
assert_allclose(bisect_means.cluster_centers_, X.mean(axis=0).reshape(1, -1))
8686

8787

88-
@pytest.mark.parametrize(
89-
"param, match",
90-
[
91-
# Test bisecting_strategy param
92-
(
93-
{"bisecting_strategy": "None"},
94-
"Bisect Strategy must be 'biggest_inertia' or 'largest_cluster'",
95-
),
96-
# Test init array
97-
(
98-
{"init": np.ones((5, 2))},
99-
"BisectingKMeans does not support init as array.",
100-
),
101-
],
102-
)
103-
def test_wrong_params(param, match):
104-
"""Test Exceptions at check_params function."""
105-
rng = np.random.RandomState(0)
106-
X = rng.rand(5, 2)
107-
108-
with pytest.raises(ValueError, match=match):
109-
bisect_means = BisectingKMeans(n_clusters=3, **param)
110-
bisect_means.fit(X)
111-
112-
11388
@pytest.mark.parametrize("is_sparse", [True, False])
11489
def test_fit_predict(is_sparse):
11590
"""Check if labels from fit(X) method are same as from fit(X).predict(X)."""

0 commit comments

Comments
 (0)
0