MAINT validate parameter in sklearn.preprocessing._encoders by Diadochokinetic · Pull Request #23579 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT validate parameter in sklearn.preprocessing._encoders #23579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
adc5db9
Add minimal example for parameter 'categories' in OneHotEncoder() and…
Diadochokinetic Jun 10, 2022
7520d3e
Add all parameter constraints and run tests successfully
Diadochokinetic Jun 10, 2022
200ade8
formatted code with black
Diadochokinetic Jun 10, 2022
835cccd
Remove old parameter checks
Diadochokinetic Jun 10, 2022
c20ccdc
remove OneHotEncoder and OrdinalEncoder from PARAM_VALIDATION_ESTIMATO…
Diadochokinetic Jun 10, 2022
e240f8c
remove simple tests and move self._infrequent_enabled to __init__
Diadochokinetic Jun 10, 2022
a3f9633
change parameter dtype to type
Diadochokinetic Jun 10, 2022
8225384
try different approach for dtype
Diadochokinetic Jun 12, 2022
48db2c4
Add type and np.dtype to parameter_constraints of dtype
Diadochokinetic Jun 12, 2022
e7579ec
Merge branch 'ohe_validate_params' of https://github.com/Diadochokine…
Diadochokinetic Jun 12, 2022
732c009
Remove KernelCenterer from PARAM_VALIDATION_TO_IGNORE
Diadochokinetic Jun 12, 2022
19afd9b
Add StrOptions first and if_binary to parameter drop
Diadochokinetic Jun 12, 2022
2f60d6e
Remove simple tests for error messages of parameter constraints
Diadochokinetic Jun 12, 2022
cd14e09
Add StrOptions for parameter dtype of OrdinalEncoder, it allows to p…
Diadochokinetic Jun 12, 2022
cc93146
Format with black
Diadochokinetic Jun 12, 2022
0d9e2f4
Undo changes to _InstancesOf
Diadochokinetic Jun 12, 2022
073bc0f
Undo changes in _InstancesOf docstring
Diadochokinetic Jun 12, 2022
5409276
Format with black
Diadochokinetic Jun 12, 2022
243434f
Change parameter_constraints of dtype to object, the most abstract form
Diadochokinetic Jun 12, 2022
604954f
disable error message tests for object as parameter constraint
Diadochokinetic Jun 12, 2022
ced1030
forgot to run flake8...
Diadochokinetic Jun 12, 2022
030efea
Merge remote-tracking branch 'upstream/main' into ohe_validate_params
Diadochokinetic Jun 17, 2022
b75f62c
undo changes in estimator_checks
Diadochokinetic Jun 17, 2022
e3f230b
Fix inline comments
Diadochokinetic Jun 17, 2022
de91812
Replace constraint [bool] with [boolean].
Diadochokinetic Jun 17, 2022
455140c
no validation for dtype
jeremiedbb Jun 24, 2022
e61c5ce
Merge remote-tracking branch 'upstream/main' into pr/Diadochokinetic/…
jeremiedbb Jun 24, 2022
00c8eae
update
jeremiedbb Jun 24, 2022
04e65ab
address review comment
jeremiedbb Jun 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 37 additions & 74 deletions sklearn/preprocessing/_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# License: BSD 3 clause

import numbers
from numbers import Integral, Real
import warnings

import numpy as np
Expand All @@ -13,6 +14,8 @@
from ..utils.deprecation import deprecated
from ..utils.validation import check_is_fitted
from ..utils.validation import _check_feature_names_in
from ..utils._param_validation import Interval
from ..utils._param_validation import StrOptions
from ..utils._mask import _get_mask

from ..utils._encode import _encode, _check_unknown, _unique, _get_counts
Expand Down Expand Up @@ -430,6 +433,20 @@ class OneHotEncoder(_BaseEncoder):
[1., 0., 0.]])
"""

_parameter_constraints = {
"categories": [StrOptions({"auto"}), list],
"drop": [StrOptions({"first", "if_binary"}), "array-like", None],
"dtype": "no_validation", # validation delegated to numpy
"handle_unknown": [StrOptions({"error", "ignore", "infrequent_if_exist"})],
"max_categories": [Interval(Integral, 1, None, closed="left"), None],
"min_frequency": [
Interval(Integral, 1, None, closed="left"),
Interval(Real, 0, 1, closed="neither"),
None,
],
"sparse": ["boolean"],
}

def __init__(
self,
*,
Expand Down Expand Up @@ -459,33 +476,11 @@ def infrequent_categories_(self):
for category, indices in zip(self.categories_, infrequent_indices)
]

def _validate_keywords(self):

if self.handle_unknown not in {"error", "ignore", "infrequent_if_exist"}:
msg = (
"handle_unknown should be one of 'error', 'ignore', "
f"'infrequent_if_exist' got {self.handle_unknown}."
)
raise ValueError(msg)

if self.max_categories is not None and self.max_categories < 1:
raise ValueError("max_categories must be greater than 1")

if isinstance(self.min_frequency, numbers.Integral):
if not self.min_frequency >= 1:
raise ValueError(
"min_frequency must be an integer at least "
"1 or a float in (0.0, 1.0); got the "
f"integer {self.min_frequency}"
)
elif isinstance(self.min_frequency, numbers.Real):
if not (0.0 < self.min_frequency < 1.0):
raise ValueError(
"min_frequency must be an integer at least "
"1 or a float in (0.0, 1.0); got the "
f"float {self.min_frequency}"
)

def _check_infrequent_enabled(self):
"""
This functions checks whether _infrequent_enabled is True or False.
This has to be called after parameter validation in the fit function.
"""
self._infrequent_enabled = (
self.max_categories is not None and self.max_categories >= 1
) or self.min_frequency is not None
Expand Down Expand Up @@ -547,23 +542,11 @@ def _compute_drop_idx(self):
],
dtype=object,
)
else:
msg = (
"Wrong input for parameter `drop`. Expected "
"'first', 'if_binary', None or array of objects, got {}"
)
raise ValueError(msg.format(type(self.drop)))

else:
try:
drop_array = np.asarray(self.drop, dtype=object)
droplen = len(drop_array)
except (ValueError, TypeError):
msg = (
"Wrong input for parameter `drop`. Expected "
"'first', 'if_binary', None or array of objects, got {}"
)
raise ValueError(msg.format(type(drop_array)))
drop_array = np.asarray(self.drop, dtype=object)
droplen = len(drop_array)

if droplen != len(self.categories_):
msg = (
"`drop` should have length equal to the number "
Expand Down Expand Up @@ -814,7 +797,9 @@ def fit(self, X, y=None):
self
Fitted encoder.
"""
self._validate_keywords()
self._validate_params()
self._check_infrequent_enabled()

fit_results = self._fit(
X,
handle_unknown=self.handle_unknown,
Expand All @@ -829,31 +814,6 @@ def fit(self, X, y=None):
self._n_features_outs = self._compute_n_features_outs()
return self

def fit_transform(self, X, y=None):
"""
Fit OneHotEncoder to X, then transform X.

Equivalent to fit(X).transform(X) but more convenient.

Parameters
----------
X : array-like of shape (n_samples, n_features)
The data to encode.

y : None
Ignored. This parameter exists only for compatibility with
:class:`~sklearn.pipeline.Pipeline`.

Returns
-------
X_out : {ndarray, sparse matrix} of shape \
(n_samples, n_encoded_features)
Transformed input. If `sparse=True`, a sparse matrix will be
returned.
"""
self._validate_keywords()
return super().fit_transform(X, y)

def transform(self, X):
"""
Transform X using one-hot encoding.
Expand Down Expand Up @@ -1228,6 +1188,14 @@ class OrdinalEncoder(_OneToOneFeatureMixin, _BaseEncoder):
[ 0., -1.]])
"""

_parameter_constraints = {
"categories": [StrOptions({"auto"}), list],
"dtype": "no_validation", # validation delegated to numpy
"encoded_missing_value": [Integral, type(np.nan)],
"handle_unknown": [StrOptions({"error", "use_encoded_value"})],
"unknown_value": [Integral, type(np.nan), None],
}

def __init__(
self,
*,
Expand Down Expand Up @@ -1261,12 +1229,7 @@ def fit(self, X, y=None):
self : object
Fitted encoder.
"""
handle_unknown_strategies = ("error", "use_encoded_value")
if self.handle_unknown not in handle_unknown_strategies:
raise ValueError(
"handle_unknown should be either 'error' or "
f"'use_encoded_value', got {self.handle_unknown}."
)
self._validate_params()

if self.handle_unknown == "use_encoded_value":
if is_scalar_nan(self.unknown_value):
Expand Down
91 changes: 0 additions & 91 deletions sklearn/preprocessing/tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ def test_one_hot_encoder_handle_unknown(handle_unknown):
# ensure transformed data was not modified in place
assert_allclose(X2, X2_passed)

# Raise error if handle_unknown is neither ignore or error.
oh = OneHotEncoder(handle_unknown="42")
with pytest.raises(ValueError, match="handle_unknown should be one of"):
oh.fit(X)


def test_one_hot_encoder_not_fitted():
X = np.array([["a"], ["b"]])
Expand Down Expand Up @@ -716,50 +711,6 @@ def test_ordinal_encoder_handle_unknowns_numeric(dtype):
assert_array_equal(X_trans_inv, inv_exp)


@pytest.mark.parametrize(
"params, err_type, err_msg",
[
(
{"handle_unknown": "use_encoded_value"},
TypeError,
"unknown_value should be an integer or np.nan when handle_unknown "
"is 'use_encoded_value', got None.",
),
(
{"unknown_value": -2},
TypeError,
"unknown_value should only be set when handle_unknown is "
"'use_encoded_value', got -2.",
),
(
{"handle_unknown": "use_encoded_value", "unknown_value": "bla"},
TypeError,
"unknown_value should be an integer or np.nan when handle_unknown "
"is 'use_encoded_value', got bla.",
),
(
{"handle_unknown": "use_encoded_value", "unknown_value": 1},
ValueError,
"The used value for unknown_value (1) is one of the values "
"already used for encoding the seen categories.",
),
(
{"handle_unknown": "ignore"},
ValueError,
"handle_unknown should be either 'error' or 'use_encoded_value', "
"got ignore.",
),
],
)
def test_ordinal_encoder_handle_unknowns_raise(params, err_type, err_msg):
# Check error message when validating input parameters
X = np.array([["a", "x"], ["b", "y"]], dtype=object)

encoder = OrdinalEncoder(**params)
with pytest.raises(err_type, match=err_msg):
encoder.fit(X)


def test_ordinal_encoder_handle_unknowns_nan():
# Make sure unknown_value=np.nan properly works

Expand Down Expand Up @@ -886,32 +837,6 @@ def test_one_hot_encoder_drop_manual(missing_value):
assert_array_equal(X_array, X_inv_trans)


@pytest.mark.parametrize(
"X_fit, params, err_msg",
[
(
[["Male"], ["Female"]],
{"drop": "second"},
"Wrong input for parameter `drop`",
),
(
[["abc", 2, 55], ["def", 1, 55], ["def", 3, 59]],
{"drop": np.asarray("b", dtype=object)},
"Wrong input for parameter `drop`",
),
(
[["abc", 2, 55], ["def", 1, 55], ["def", 3, 59]],
{"drop": ["ghi", 3, 59]},
"The following categories were supposed",
),
],
)
def test_one_hot_encoder_invalid_params(X_fit, params, err_msg):
enc = OneHotEncoder(**params)
with pytest.raises(ValueError, match=err_msg):
enc.fit(X_fit)


@pytest.mark.parametrize("drop", [["abc", 3], ["abc", 3, 41, "a"]])
def test_invalid_drop_length(drop):
enc = OneHotEncoder(drop=drop)
Expand Down Expand Up @@ -1433,22 +1358,6 @@ def test_ohe_infrequent_user_cats_unknown_training_errors(kwargs):
assert_allclose(X_trans, [[1], [1]])


@pytest.mark.parametrize(
"kwargs, error_msg",
[
({"max_categories": -2}, "max_categories must be greater than 1"),
({"min_frequency": -1}, "min_frequency must be an integer at least"),
({"min_frequency": 1.1}, "min_frequency must be an integer at least"),
],
)
def test_ohe_infrequent_invalid_parameters_error(kwargs, error_msg):
X_train = np.array([["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 2]).T

ohe = OneHotEncoder(handle_unknown="infrequent_if_exist", **kwargs)
with pytest.raises(ValueError, match=error_msg):
ohe.fit(X_train)


# TODO: Remove in 1.2 when get_feature_names is removed
def test_one_hot_encoder_get_feature_names_deprecated():
X = np.array([["cat", "dog"]], dtype=object).T
Expand Down
2 changes: 0 additions & 2 deletions sklearn/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,10 +554,8 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
"OAS",
"OPTICS",
"OneClassSVM",
"OneHotEncoder",
"OneVsOneClassifier",
"OneVsRestClassifier",
"OrdinalEncoder",
"OrthogonalMatchingPursuit",
"OrthogonalMatchingPursuitCV",
"OutputCodeClassifier",
Expand Down
0