MAINT validate parameter in OneHotEncoder and OrdinalEncoder (#23579) · ogrisel/scikit-learn@edc858b · GitHub
[go: up one dir, main page]

Skip to content

Commit edc858b

Browse files
Diadochokinetic and jeremiedbb
authored and committed
MAINT validate parameter in OneHotEncoder and OrdinalEncoder (scikit-learn#23579)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 590c4cf commit edc858b

File tree

3 files changed

+37
-167
lines changed

3 files changed

+37
-167
lines changed

sklearn/preprocessing/_encoders.py

Lines changed: 37 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
# License: BSD 3 clause
44

55
import numbers
6+
from numbers import Integral, Real
67
import warnings
78

89
import numpy as np
@@ -13,6 +14,8 @@
1314
from ..utils.deprecation import deprecated
1415
from ..utils.validation import check_is_fitted
1516
from ..utils.validation import _check_feature_names_in
17+
from ..utils._param_validation import Interval
18+
from ..utils._param_validation import StrOptions
1619
from ..utils._mask import _get_mask
1720

1821
from ..utils._encode import _encode, _check_unknown, _unique, _get_counts
@@ -430,6 +433,20 @@ class OneHotEncoder(_BaseEncoder):
430433
[1., 0., 0.]])
431434
"""
432435

436+
_parameter_constraints = {
437+
"categories": [StrOptions({"auto"}), list],
438+
"drop": [StrOptions({"first", "if_binary"}), "array-like", None],
439+
"dtype": "no_validation", # validation delegated to numpy
440+
"handle_unknown": [StrOptions({"error", "ignore", "infrequent_if_exist"})],
441+
"max_categories": [Interval(Integral, 1, None, closed="left"), None],
442+
"min_frequency": [
443+
Interval(Integral, 1, None, closed="left"),
444+
Interval(Real, 0, 1, closed="neither"),
445+
None,
446+
],
447+
"sparse": ["boolean"],
448+
}
449+
433450
def __init__(
434451
self,
435452
*,
@@ -459,33 +476,11 @@ def infrequent_categories_(self):
459476
for category, indices in zip(self.categories_, infrequent_indices)
460477
]
461478

462-
def _validate_keywords(self):
463-
464-
if self.handle_unknown not in {"error", "ignore", "infrequent_if_exist"}:
465-
msg = (
466-
"handle_unknown should be one of 'error', 'ignore', "
467-
f"'infrequent_if_exist' got {self.handle_unknown}."
468-
)
469-
raise ValueError(msg)
470-
471-
if self.max_categories is not None and self.max_categories < 1:
472-
raise ValueError("max_categories must be greater than 1")
473-
474-
if isinstance(self.min_frequency, numbers.Integral):
475-
if not self.min_frequency >= 1:
476-
raise ValueError(
477-
"min_frequency must be an integer at least "
478-
"1 or a float in (0.0, 1.0); got the "
479-
f"integer {self.min_frequency}"
480-
)
481-
elif isinstance(self.min_frequency, numbers.Real):
482-
if not (0.0 < self.min_frequency < 1.0):
483-
raise ValueError(
484-
"min_frequency must be an integer at least "
485-
"1 or a float in (0.0, 1.0); got the "
486-
f"float {self.min_frequency}"
487-
)
488-
479+
def _check_infrequent_enabled(self):
480+
"""
481+
This function checks whether _infrequent_enabled is True or False.
482+
This has to be called after parameter validation in the fit function.
483+
"""
489484
self._infrequent_enabled = (
490485
self.max_categories is not None and self.max_categories >= 1
491486
) or self.min_frequency is not None
@@ -547,23 +542,11 @@ def _compute_drop_idx(self):
547542
],
548543
dtype=object,
549544
)
550-
else:
551-
msg = (
552-
"Wrong input for parameter `drop`. Expected "
553-
"'first', 'if_binary', None or array of objects, got {}"
554-
)
555-
raise ValueError(msg.format(type(self.drop)))
556545

557546
else:
558-
try:
559-
drop_array = np.asarray(self.drop, dtype=object)
560-
droplen = len(drop_array)
561-
except (ValueError, TypeError):
562-
msg = (
563-
"Wrong input for parameter `drop`. Expected "
564-
"'first', 'if_binary', None or array of objects, got {}"
565-
)
566-
raise ValueError(msg.format(type(drop_array)))
547+
drop_array = np.asarray(self.drop, dtype=object)
548+
droplen = len(drop_array)
549+
567550
if droplen != len(self.categories_):
568551
msg = (
569552
"`drop` should have length equal to the number "
@@ -814,7 +797,9 @@ def fit(self, X, y=None):
814797
self
815798
Fitted encoder.
816799
"""
817-
self._validate_keywords()
800+
self._validate_params()
801+
self._check_infrequent_enabled()
802+
818803
fit_results = self._fit(
819804
X,
820805
handle_unknown=self.handle_unknown,
@@ -829,31 +814,6 @@ def fit(self, X, y=None):
829814
self._n_features_outs = self._compute_n_features_outs()
830815
return self
831816

832-
def fit_transform(self, X, y=None):
833-
"""
834-
Fit OneHotEncoder to X, then transform X.
835-
836-
Equivalent to fit(X).transform(X) but more convenient.
837-
838-
Parameters
839-
----------
840-
X : array-like of shape (n_samples, n_features)
841-
The data to encode.
842-
843-
y : None
844-
Ignored. This parameter exists only for compatibility with
845-
:class:`~sklearn.pipeline.Pipeline`.
846-
847-
Returns
848-
-------
849-
X_out : {ndarray, sparse matrix} of shape \
850-
(n_samples, n_encoded_features)
851-
Transformed input. If `sparse=True`, a sparse matrix will be
852-
returned.
853-
"""
854-
self._validate_keywords()
855-
return super().fit_transform(X, y)
856-
857817
def transform(self, X):
858818
"""
859819
Transform X using one-hot encoding.
@@ -1228,6 +1188,14 @@ class OrdinalEncoder(_OneToOneFeatureMixin, _BaseEncoder):
12281188
[ 0., -1.]])
12291189
"""
12301190

1191+
_parameter_constraints = {
1192+
"categories": [StrOptions({"auto"}), list],
1193+
"dtype": "no_validation", # validation delegated to numpy
1194+
"encoded_missing_value": [Integral, type(np.nan)],
1195+
"handle_unknown": [StrOptions({"error", "use_encoded_value"})],
1196+
"unknown_value": [Integral, type(np.nan), None],
1197+
}
1198+
12311199
def __init__(
12321200
self,
12331201
*,
@@ -1261,12 +1229,7 @@ def fit(self, X, y=None):
12611229
self : object
12621230
Fitted encoder.
12631231
"""
1264-
handle_unknown_strategies = ("error", "use_encoded_value")
1265-
if self.handle_unknown not in handle_unknown_strategies:
1266-
raise ValueError(
1267-
"handle_unknown should be either 'error' or "
1268-
f"'use_encoded_value', got {self.handle_unknown}."
1269-
)
1232+
self._validate_params()
12701233

12711234
if self.handle_unknown == "use_encoded_value":
12721235
if is_scalar_nan(self.unknown_value):

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 0 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -60,11 +60,6 @@ def test_one_hot_encoder_handle_unknown(handle_unknown):
6060
# ensure transformed data was not modified in place
6161
assert_allclose(X2, X2_passed)
6262

63-
# Raise error if handle_unknown is neither ignore or error.
64-
oh = OneHotEncoder(handle_unknown="42")
65-
with pytest.raises(ValueError, match="handle_unknown should be one of"):
66-
oh.fit(X)
67-
6863

6964
def test_one_hot_encoder_not_fitted():
7065
X = np.array([["a"], ["b"]])
@@ -716,50 +711,6 @@ def test_ordinal_encoder_handle_unknowns_numeric(dtype):
716711
assert_array_equal(X_trans_inv, inv_exp)
717712

718713

719-
@pytest.mark.parametrize(
720-
"params, err_type, err_msg",
721-
[
722-
(
723-
{"handle_unknown": "use_encoded_value"},
724-
TypeError,
725-
"unknown_value should be an integer or np.nan when handle_unknown "
726-
"is 'use_encoded_value', got None.",
727-
),
728-
(
729-
{"unknown_value": -2},
730-
TypeError,
731-
"unknown_value should only be set when handle_unknown is "
732-
"'use_encoded_value', got -2.",
733-
),
734-
(
735-
{"handle_unknown": "use_encoded_value", "unknown_value": "bla"},
736-
TypeError,
737-
"unknown_value should be an integer or np.nan when handle_unknown "
738-
"is 'use_encoded_value', got bla.",
739-
),
740-
(
741-
{"handle_unknown": "use_encoded_value", "unknown_value": 1},
742-
ValueError,
743-
"The used value for unknown_value (1) is one of the values "
744-
"already used for encoding the seen categories.",
745-
),
746-
(
747-
{"handle_unknown": "ignore"},
748-
ValueError,
749-
"handle_unknown should be either 'error' or 'use_encoded_value', "
750-
"got ignore.",
751-
),
752-
],
753-
)
754-
def test_ordinal_encoder_handle_unknowns_raise(params, err_type, err_msg):
755-
# Check error message when validating input parameters
756-
X = np.array([["a", "x"], ["b", "y"]], dtype=object)
757-
758-
encoder = OrdinalEncoder(**params)
759-
with pytest.raises(err_type, match=err_msg):
760-
encoder.fit(X)
761-
762-
763714
def test_ordinal_encoder_handle_unknowns_nan():
764715
# Make sure unknown_value=np.nan properly works
765716

@@ -886,32 +837,6 @@ def test_one_hot_encoder_drop_manual(missing_value):
886837
assert_array_equal(X_array, X_inv_trans)
887838

888839

889-
@pytest.mark.parametrize(
890-
"X_fit, params, err_msg",
891-
[
892-
(
893-
[["Male"], ["Female"]],
894-
{"drop": "second"},
895-
"Wrong input for parameter `drop`",
896-
),
897-
(
898-
[["abc", 2, 55], ["def", 1, 55], ["def", 3, 59]],
899-
{"drop": np.asarray("b", dtype=object)},
900-
"Wrong input for parameter `drop`",
901-
),
902-
(
903-
[["abc", 2, 55], ["def", 1, 55], ["def", 3, 59]],
904-
{"drop": ["ghi", 3, 59]},
905-
"The following categories were supposed",
906-
),
907-
],
908-
)
909-
def test_one_hot_encoder_invalid_params(X_fit, params, err_msg):
910-
enc = OneHotEncoder(**params)
911-
with pytest.raises(ValueError, match=err_msg):
912-
enc.fit(X_fit)
913-
914-
915840
@pytest.mark.parametrize("drop", [["abc", 3], ["abc", 3, 41, "a"]])
916841
def test_invalid_drop_length(drop):
917842
enc = OneHotEncoder(drop=drop)
@@ -1433,22 +1358,6 @@ def test_ohe_infrequent_user_cats_unknown_training_errors(kwargs):
14331358
assert_allclose(X_trans, [[1], [1]])
14341359

14351360

1436-
@pytest.mark.parametrize(
1437-
"kwargs, error_msg",
1438-
[
1439-
({"max_categories": -2}, "max_categories must be greater than 1"),
1440-
({"min_frequency": -1}, "min_frequency must be an integer at least"),
1441-
({"min_frequency": 1.1}, "min_frequency must be an integer at least"),
1442-
],
1443-
)
1444-
def test_ohe_infrequent_invalid_parameters_error(kwargs, error_msg):
1445-
X_train = np.array([["a"] * 5 + ["b"] * 20 + ["c"] * 10 + ["d"] * 2]).T
1446-
1447-
ohe = OneHotEncoder(handle_unknown="infrequent_if_exist", **kwargs)
1448-
with pytest.raises(ValueError, match=error_msg):
1449-
ohe.fit(X_train)
1450-
1451-
14521361
# TODO: Remove in 1.2 when get_feature_names is removed
14531362
def test_one_hot_encoder_get_feature_names_deprecated():
14541363
X = np.array([["cat", "dog"]], dtype=object).T

sklearn/tests/test_common.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -551,10 +551,8 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
551551
"OAS",
552552
"OPTICS",
553553
"OneClassSVM",
554-
"OneHotEncoder",
555554
"OneVsOneClassifier",
556555
"OneVsRestClassifier",
557-
"OrdinalEncoder",
558556
"OrthogonalMatchingPursuit",
559557
"OrthogonalMatchingPursuitCV",
560558
"OutputCodeClassifier",

0 commit comments

Comments
 (0)
0