8000 MAINT Param validation: constraint for numeric missing values by jeremiedbb · Pull Request #26085 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Param validation: constraint for numeric missing values #26085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions sklearn/impute/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from scipy import sparse as sp

from ..base import BaseEstimator, TransformerMixin
from ..utils._param_validation import StrOptions, Hidden
from ..utils._param_validation import StrOptions, Hidden, MissingValues
from ..utils.fixes import _mode
from ..utils.sparsefuncs import _get_median
from ..utils.validation import check_is_fitted
Expand Down Expand Up @@ -78,7 +78,7 @@ class _BaseImputer(TransformerMixin, BaseEstimator):
"""

_parameter_constraints: dict = {
"missing_values": ["missing_values"],
"missing_values": [MissingValues()],
"add_indicator": ["boolean"],
"keep_empty_features": ["boolean"],
}
Expand Down Expand Up @@ -800,7 +800,7 @@ class MissingIndicator(TransformerMixin, BaseEstimator):
"""

_parameter_constraints: dict = {
"missing_values": [numbers.Real, numbers.Integral, str, None],
"missing_values": [MissingValues()],
"features": [StrOptions({"missing-only", "all"})],
"sparse": ["boolean", StrOptions({"auto"})],
"error_on_new": ["boolean"],
Expand Down
19 changes: 17 additions & 2 deletions sklearn/metrics/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@
from ..utils._mask import _get_mask
from ..utils.parallel import delayed, Parallel
from ..utils.fixes import sp_base_version, sp_version, parse_version
from ..utils._param_validation import validate_params, Interval, Real, Hidden
from ..utils._param_validation import (
validate_params,
Interval,
Real,
Hidden,
MissingValues,
)

from ._pairwise_distances_reduction import ArgKmin
from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
Expand Down Expand Up @@ -380,6 +386,15 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
return distances if squared else np.sqrt(distances, out=distances)


@validate_params(
{
"X": ["array-like"],
"Y": ["array-like", None],
"squared": ["boolean"],
"missing_values": [MissingValues(numeric_only=True)],
"copy": ["boolean"],
}
)
def nan_euclidean_distances(
X, Y=None, *, squared=False, missing_values=np.nan, copy=True
):
Expand Down Expand Up @@ -420,7 +435,7 @@ def nan_euclidean_distances(
squared : bool, default=False
Return squared Euclidean distances.

missing_values : np.nan or int, default=np.nan
missing_values : np.nan, float or int, default=np.nan
Representation of missing value.

copy : bool, default=True
Expand Down
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def _check_function_param_validation(
"sklearn.metrics.pairwise.haversine_distances",
"sklearn.metrics.pairwise.laplacian_kernel",
"sklearn.metrics.pairwise.linear_kernel",
"sklearn.metrics.pairwise.nan_euclidean_distances",
"sklearn.metrics.pairwise.paired_cosine_distances",
"sklearn.metrics.pairwise.paired_euclidean_distances",
"sklearn.metrics.pairwise.paired_manhattan_distances",
Expand Down
36 changes: 24 additions & 12 deletions sklearn/utils/_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
- the string "boolean"
- the string "verbose"
- the string "cv_object"
- the string "missing_values"
- a MissingValues object representing markers for missing values
- a HasMethods object, representing method(s) an object must have
- a Hidden object, representing a constraint not meant to be exposed to the user

Expand Down Expand Up @@ -125,14 +125,14 @@ def make_constraint(constraint):
return _NoneConstraint()
if isinstance(constraint, type):
return _InstancesOf(constraint)
if isinstance(constraint, (Interval, StrOptions, Options, HasMethods)):
if isinstance(
constraint, (Interval, StrOptions, Options, HasMethods, MissingValues)
):
return constraint
if isinstance(constraint, str) and constraint == "boolean":
return _Booleans()
if isinstance(constraint, str) and constraint == "verbose":
return _VerboseHelper()
if isinstance(constraint, str) and constraint == "missing_values":
return _MissingValues()
if isinstance(constraint, str) and constraint == "cv_object":
return _CVObjects()
if isinstance(constraint, Hidden):
Expand Down Expand Up @@ -609,31 +609,40 @@ def __str__(self):
)


class _MissingValues(_Constraint):
class MissingValues(_Constraint):
"""Helper constraint for the `missing_values` parameters.

Convenience for
[
Integral,
Interval(Real, None, None, closed="both"),
str,
None,
str, # when numeric_only is False
None, # when numeric_only is False
_NanConstraint(),
_PandasNAConstraint(),
]

Parameters
----------
numeric_only : bool, default=False
Whether to consider only numeric missing value markers.

"""

def __init__(self):
def __init__(self, numeric_only=False):
super().__init__()

self.numeric_only = numeric_only

self._constraints = [
_InstancesOf(Integral),
# we use an interval of Real to ignore np.nan that has its own constraint
Interval(Real, None, None, closed="both"),
_InstancesOf(str),
_NoneConstraint(),
_NanConstraint(),
_PandasNAConstraint(),
]
if not self.numeric_only:
self._constraints.extend([_InstancesOf(str), _NoneConstraint()])

def is_satisfied_by(self, val):
return any(c.is_satisfied_by(val) for c in self._constraints)
Expand Down Expand Up @@ -752,7 +761,7 @@ def generate_invalid_param_val(constraint):
if isinstance(constraint, StrOptions):
return f"not {' or '.join(constraint.options)}"

if isinstance(constraint, _MissingValues):
if isinstance(constraint, MissingValues):
return np.array([1, 2, 3])

if isinstance(constraint, _VerboseHelper):
Expand Down Expand Up @@ -841,9 +850,12 @@ def generate_valid_param(constraint):
if isinstance(constraint, _VerboseHelper):
return 1

if isinstance(constraint, _MissingValues):
if isinstance(constraint, MissingValues) and constraint.numeric_only:
return np.nan

if isinstance(constraint, MissingValues) and not constraint.numeric_only:
return "missing"

if isinstance(constraint, HasMethods):
return type(
"ValidHasMethods", (), {m: lambda self: None for m in constraint.methods}
Expand Down
22 changes: 12 additions & 10 deletions sklearn/utils/tests/test_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from sklearn.utils._param_validation import _Callables
from sklearn.utils._param_validation import _CVObjects
from sklearn.utils._param_validation import _InstancesOf
from sklearn.utils._param_validation import _MissingValues
from sklearn.utils._param_validation import MissingValues
from sklearn.utils._param_validation import _PandasNAConstraint
from sklearn.utils._param_validation import _IterablesNotString
from sklearn.utils._param_validation import _NoneConstraint
Expand Down Expand Up @@ -202,7 +202,8 @@ def a(self):
Interval(Real, 0, None, closed="left"),
Interval(Real, None, None, closed="neither"),
StrOptions({"a", "b", "c"}),
_MissingValues(),
MissingValues(),
MissingValues(numeric_only=True),
_VerboseHelper(),
HasMethods("fit"),
_IterablesNotString(),
Expand Down Expand Up @@ -337,7 +338,8 @@ def test_generate_invalid_param_val_all_valid(constraint):
_SparseMatrices(),
_Booleans(),
_VerboseHelper(),
_MissingValues(),
MissingValues(),
MissingValues(numeric_only=True),
StrOptions({"a", "b", "c"}),
Options(Integral, {1, 2, 3}),
Interval(Integral, None, None, closed="neither"),
Expand Down Expand Up @@ -378,12 +380,12 @@ def test_generate_valid_param(constraint):
(Real, 0.5),
("boolean", False),
("verbose", 1),
("missing_values", -1),
("missing_values", -1.0),
("missing_values", None),
("missing_values", float("nan")),
("missing_values", np.nan),
("missing_values", "missing"),
(MissingValues(), -1),
(MissingValues(), -1.0),
(MissingValues(), None),
(MissingValues(), float("nan")),
(MissingValues(), np.nan),
(MissingValues(), "missing"),
(HasMethods("fit"), _Estimator(a=0)),
("cv_object", 5),
],
Expand All @@ -408,7 +410,7 @@ def test_is_satisfied_by(constraint_declaration, value):
(int, _InstancesOf),
("boolean", _Booleans),
("verbose", _VerboseHelper),
("missing_values", _MissingValues),
(MissingValues(numeric_only=True), MissingValues),
(HasMethods("fit"), HasMethods),
("cv_object", _CVObjects),
],
Expand Down
0