8000 MAINT Param validation: constraint for numeric missing values (#26085) · thomasjpfan/scikit-learn@ba46b65 · GitHub
[go: up one dir, main page]

Skip to content

Commit ba46b65

Browse files
authored
MAINT Param validation: constraint for numeric missing values (scikit-learn#26085)
1 parent db83e23 commit ba46b65

File tree

5 files changed

+57
-27
lines changed

5 files changed

+57
-27
lines changed

sklearn/impute/_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from scipy import sparse as sp
1212

1313
from ..base import BaseEstimator, TransformerMixin
14-
from ..utils._param_validation import StrOptions, Hidden
14+
from ..utils._param_validation import StrOptions, Hidden, MissingValues
1515
from ..utils.fixes import _mode
1616
from ..utils.sparsefuncs import _get_median
1717
from ..utils.validation import check_is_fitted
@@ -78,7 +78,7 @@ class _BaseImputer(TransformerMixin, BaseEstimator):
7878
"""
7979

8080
_parameter_constraints: dict = {
81-
"missing_values": ["missing_values"],
81+
"missing_values": [MissingValues()],
8282
"add_indicator": ["boolean"],
8383
"keep_empty_features": ["boolean"],
8484
}
@@ -800,7 +800,7 @@ class MissingIndicator(TransformerMixin, BaseEstimator):
800800
"""
801801

802802
_parameter_constraints: dict = {
803-
"missing_values": [numbers.Real, numbers.Integral, str, None],
803+
"missing_values": [MissingValues()],
804804
"features": [StrOptions({"missing-only", "all"})],
805805
"sparse": ["boolean", StrOptions({"auto"})],
806806
"error_on_new": ["boolean"],

sklearn/metrics/pairwise.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
from ..utils._mask import _get_mask
3030
from ..utils.parallel import delayed, Parallel
3131
from ..utils.fixes import sp_base_version, sp_version, parse_version
32-
from ..utils._param_validation import validate_params, Interval, Real, Hidden
32+
from ..utils._param_validation import (
33+
validate_params,
34+
Interval,
35+
Real,
36+
Hidden,
37+
MissingValues,
38+
)
3339

3440
from ._p 8000 airwise_distances_reduction import ArgKmin
3541
from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
@@ -380,6 +386,15 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
380386
return distances if squared else np.sqrt(distances, out=distances)
381387

382388

389+
@validate_params(
390+
{
391+
"X": ["array-like"],
392+
"Y": ["array-like", None],
393+
"squared": ["boolean"],
394+
"missing_values": [MissingValues(numeric_only=True)],
395+
"copy": ["boolean"],
396+
}
397+
)
383398
def nan_euclidean_distances(
384399
X, Y=None, *, squared=False, missing_values=np.nan, copy=True
385400
):
@@ -420,7 +435,7 @@ def nan_euclidean_distances(
420435
squared : bool, default=False
421436
Return squared Euclidean distances.
422437
423-
missing_values : np.nan or int, default=np.nan
438+
missing_values : np.nan, float or int, default=np.nan
424439
Representation of missing value.
425440
426441
copy : bool, default=True

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def _check_function_param_validation(
212212
"sklearn.metrics.pairwise.haversine_distances",
213213
"sklearn.metrics.pairwise.laplacian_kernel",
214214
"sklearn.metrics.pairwise.linear_kernel",
215+
"sklearn.metrics.pairwise.nan_euclidean_distances",
215216
"sklearn.metrics.pairwise.paired_cosine_distances",
216217
"sklearn.metrics.pairwise.paired_euclidean_distances",
217218
"sklearn.metrics.pairwise.paired_manhattan_distances",

sklearn/utils/_param_validation.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
4848
- the string "boolean"
4949
- the string "verbose"
5050
- the string "cv_object"
51-
- the string "missing_values"
51+
- a MissingValues object representing markers for missing values
5252
- a HasMethods object, representing method(s) an object must have
5353
- a Hidden object, representing a constraint not meant to be exposed to the user
5454
@@ -125,14 +125,14 @@ def make_constraint(constraint):
125125
return _NoneConstraint()
126126
if isinstance(constraint, type):
127127
return _InstancesOf(constraint)
128-
if isinstance(constraint, (Interval, StrOptions, Options, HasMethods)):
128+
if isinstance(
129+
constraint, (Interval, StrOptions, Options, HasMethods, MissingValues)
130+
):
129131
return constraint
130132
if isinstance(constraint, str) and constraint == "boolean":
131133
return _Booleans()
132134
if isinstance(constraint, str) and constraint == "verbose":
133135
return _VerboseHelper()
134-
if isinstance(constraint, str) and constraint == "missing_values":
135-
return _MissingValues()
136136
if isinstance(constraint, str) and constraint == "cv_object":
137137
return _CVObjects()
138138
if isinstance(constraint, Hidden):
@@ -609,31 +609,40 @@ def __str__(self):
609609
)
610610

611611

612-
class _MissingValues(_Constraint):
612+
class MissingValues(_Constraint):
613613
"""Helper constraint for the `missing_values` parameters.
614614
615615
Convenience for
616616
[
617617
Integral,
618618
Interval(Real, None, None, closed="both"),
619-
str,
620-
None,
619+
str, # when numeric_only is False
620+
None, # when numeric_only is False
621621
_NanConstraint(),
622622
_PandasNAConstraint(),
623623
]
624+
625+
Parameters
626+
----------
627+
numeric_only : bool, default=False
628+
Whether to consider only numeric missing value markers.
629+
624630
"""
625631

626-
def __init__(self):
632+
def __init__(self, numeric_only=False):
627633
super().__init__()
634+
635+
self.numeric_only = numeric_only
636+
628637
self._constraints = [
629638
_InstancesOf(Integral),
630639
# we use an interval of Real to ignore np.nan that has its own constraint
631640
Interval(Real, None, None, closed="both"),
632-
_InstancesOf(str),
633-
_NoneConstraint(),
634641
_NanConstraint(),
635642
_PandasNAConstraint(),
636643
]
644+
if not self.numeric_only:
645+
self._constraints.extend([_InstancesOf(str), _NoneConstraint()])
637646

638647
def is_satisfied_by(self, val):
639648
return any(c.is_satisfied_by(val) for c in self._constraints)
@@ -752,7 +761,7 @@ def generate_invalid_param_val(constraint):
752761
if isinstance(constraint, StrOptions):
753762
return f"not {' or '.join(constraint.options)}"
754763

755-
if isinstance(constraint, _MissingValues):
764+
if isinstance(constraint, MissingValues):
756765
return np.array([1, 2, 3])
757766

758767
if isinstance(constraint, _VerboseHelper):
@@ -841,9 +850,12 @@ def generate_valid_param(constraint):
841850
if isinstance(constraint, _VerboseHelper):
842851
return 1
843852

844-
if isinstance(constraint, _MissingValues):
853+
if isinstance(constraint, MissingValues) and constraint.numeric_only:
845854
return np.nan
846855

856+
if isinstance(constraint, MissingValues) and not constraint.numeric_only:
857+
return "missing"
858+
847859
if isinstance(constraint, HasMethods):
848860
return type(
849861
"ValidHasMethods", (), {m: lambda self: None for m in constraint.methods}

sklearn/utils/tests/test_param_validation.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sklearn.utils._param_validation import _Callables
1717
from sklearn.utils._param_validation import _CVObjects
1818
from sklearn.utils._param_validation import _InstancesOf
19-
from sklearn.utils._param_validation import _MissingValues
19+
from sklearn.utils._param_validation import MissingValues
2020
from sklearn.utils._param_validation import _PandasNAConstraint
2121
from sklearn.utils._param_validation import _IterablesNotString
2222
from sklearn.utils._param_validation import _NoneConstraint
@@ -202,7 +202,8 @@ def a(self):
202202
Interval(Real, 0, None, closed="left"),
203203
Interval(Real, None, None, closed="neither"),
204204
StrOptions({"a", "b", "c"}),
205-
_MissingValues(),
205+
MissingValues(),
206+
MissingValues(numeric_only=True),
206207
_VerboseHelper(),
207208
HasMethods("fit"),
208209
_IterablesNotString(),
@@ -337,7 +338,8 @@ def test_generate_invalid_param_val_all_valid(constraint):
337338
_SparseMatrices(),
338339
_Booleans(),
339340
_VerboseHelper(),
340-
_MissingValues(),
341+
MissingValues(),
342+
MissingValues(numeric_only=True),
341343
StrOptions({"a", "b", "c"}),
342344
Options(Integral, {1, 2, 3}),
343345
Interval(Integral, None, None, closed="neither"),
@@ -378,12 +380,12 @@ def test_generate_valid_param(constraint):
378380
(Real, 0.5),
379381
("boolean", False),
380382
("verbose", 1),
381-
("missing_values", -1),
382-
("missing_values", -1.0),
383-
("missing_values", None),
384-
("missing_values", float("nan")),
385-
("missing_values", np.nan),
386-
("missing_values", "missing"),
383+
(MissingValues(), -1),
384+
(MissingValues(), -1.0),
385+
(MissingValues(), None),
386+
(MissingValues(), float("nan")),
387+
(MissingValues(), np.nan),
388+
(MissingValues(), "missing"),
387389
(HasMethods("fit"), _Estimator(a=0)),
388390
("cv_object", 5),
389391
],
@@ -408,7 +410,7 @@ def test_is_satisfied_by(constraint_declaration, value):
408410
(int, _InstancesOf),
409411
("boolean", _Booleans),
410412
("verbose", _VerboseHelper),
411-
("missing_values", _MissingValues),
413+
(MissingValues(numeric_only=True), MissingValues),
412414
(HasMethods("fit"), HasMethods),
413415
("cv_object", _CVObjects),
414416
],

0 commit comments

Comments
 (0)
0