3
3
# License: BSD 3 clause
4
4
5
5
import numbers
6
+ from numbers import Integral , Real
6
7
import warnings
7
8
8
9
import numpy as np
13
14
from ..utils .deprecation import deprecated
14
15
from ..utils .validation import check_is_fitted
15
16
from ..utils .validation import _check_feature_names_in
17
+ from ..utils ._param_validation import Interval
18
+ from ..utils ._param_validation import StrOptions
16
19
from ..utils ._mask import _get_mask
17
20
18
21
from ..utils ._encode import _encode , _check_unknown , _unique , _get_counts
@@ -430,6 +433,20 @@ class OneHotEncoder(_BaseEncoder):
430
433
[1., 0., 0.]])
431
434
"""
432
435
436
+ _parameter_constraints = {
437
+ "categories" : [StrOptions ({"auto" }), list ],
438
+ "drop" : [StrOptions ({"first" , "if_binary" }), "array-like" , None ],
439
+ "dtype" : "no_validation" , # validation delegated to numpy
440
+ "handle_unknown" : [StrOptions ({"error" , "ignore" , "infrequent_if_exist" })],
441
+ "max_categories" : [Interval (Integral , 1 , None , closed = "left" ), None ],
442
+ "min_frequency" : [
443
+ Interval (Integral , 1 , None , closed = "left" ),
444
+ Interval (Real , 0 , 1 , closed = "neither" ),
445
+ None ,
446
+ ],
447
+ "sparse" : ["boolean" ],
448
+ }
449
+
433
450
def __init__ (
434
451
self ,
435
452
* ,
@@ -459,33 +476,11 @@ def infrequent_categories_(self):
459
476
for category , indices in zip (self .categories_ , infrequent_indices )
460
477
]
461
478
462
- def _validate_keywords (self ):
463
-
464
- if self .handle_unknown not in {"error" , "ignore" , "infrequent_if_exist" }:
465
- msg = (
466
- "handle_unknown should be one of 'error', 'ignore', "
467
- f"'infrequent_if_exist' got { self .handle_unknown } ."
468
- )
469
- raise ValueError (msg )
470
-
471
- if self .max_categories is not None and self .max_categories < 1 :
472
- raise ValueError ("max_categories must be greater than 1" )
473
-
474
- if isinstance (self .min_frequency , numbers .Integral ):
475
- if not self .min_frequency >= 1 :
476
- raise ValueError (
477
- "min_frequency must be an integer at least "
478
- "1 or a float in (0.0, 1.0); got the "
479
- f"integer { self .min_frequency } "
480
- )
481
- elif isinstance (self .min_frequency , numbers .Real ):
482
- if not (0.0 < self .min_frequency < 1.0 ):
483
- raise ValueError (
484
- "min_frequency must be an integer at least "
485
- "1 or a float in (0.0, 1.0); got the "
486
- f"float { self .min_frequency } "
487
- )
488
-
479
+ def _check_infrequent_enabled (self ):
480
+ """
481
+ This function sets _infrequent_enabled based on max_categories and min_frequency.
482
+ This has to be called after parameter validation in the fit function.
483
+ """
489
484
self ._infrequent_enabled = (
490
485
self .max_categories is not None and self .max_categories >= 1
491
486
) or self .min_frequency is not None
@@ -547,23 +542,11 @@ def _compute_drop_idx(self):
547
542
],
548
543
dtype = object ,
549
544
)
550
- else :
551
- msg = (
552
- "Wrong input for parameter `drop`. Expected "
553
- "'first', 'if_binary', None or array of objects, got {}"
554
- )
555
- raise ValueError (msg .format (type (self .drop )))
556
545
557
546
else :
558
- try :
559
- drop_array = np .asarray (self .drop , dtype = object )
560
- droplen = len (drop_array )
561
- except (ValueError , TypeError ):
562
- msg = (
563
- "Wrong input for parameter `drop`. Expected "
564
- "'first', 'if_binary', None or array of objects, got {}"
565
- )
566
- raise ValueError (msg .format (type (drop_array )))
547
+ drop_array = np .asarray (self .drop , dtype = object )
548
+ droplen = len (drop_array )
549
+
567
550
if droplen != len (self .categories_ ):
568
551
msg = (
569
552
"`drop` should have length equal to the number "
@@ -814,7 +797,9 @@ def fit(self, X, y=None):
814
797
self
815
798
Fitted encoder.
816
799
"""
817
- self ._validate_keywords ()
800
+ self ._validate_params ()
801
+ self ._check_infrequent_enabled ()
802
+
818
803
fit_results = self ._fit (
819
804
X ,
820
805
handle_unknown = self .handle_unknown ,
@@ -829,31 +814,6 @@ def fit(self, X, y=None):
829
814
self ._n_features_outs = self ._compute_n_features_outs ()
830
815
return self
831
816
832
- def fit_transform (self , X , y = None ):
833
- """
834
- Fit OneHotEncoder to X, then transform X.
835
-
836
- Equivalent to fit(X).transform(X) but more convenient.
837
-
838
- Parameters
839
- ----------
840
- X : array-like of shape (n_samples, n_features)
841
- The data to encode.
842
-
843
- y : None
844
- Ignored. This parameter exists only for compatibility with
845
- :class:`~sklearn.pipeline.Pipeline`.
846
-
847
- Returns
848
- -------
849
- X_out : {ndarray, sparse matrix} of shape \
850
- (n_samples, n_encoded_features)
851
- Transformed input. If `sparse=True`, a sparse matrix will be
852
- returned.
853
- """
854
- self ._validate_keywords ()
855
- return super ().fit_transform (X , y )
856
-
857
817
def transform (self , X ):
858
818
"""
859
819
Transform X using one-hot encoding.
@@ -1228,6 +1188,14 @@ class OrdinalEncoder(_OneToOneFeatureMixin, _BaseEncoder):
1228
1188
[ 0., -1.]])
1229
1189
"""
1230
1190
1191
+ _parameter_constraints = {
1192
+ "categories" : [StrOptions ({"auto" }), list ],
1193
+ "dtype" : "no_validation" , # validation delegated to numpy
1194
+ "encoded_missing_value" : [Integral , type (np .nan )],
1195
+ "handle_unknown" : [StrOptions ({"error" , "use_encoded_value" })],
1196
+ "unknown_value" : [Integral , type (np .nan ), None ],
1197
+ }
1198
+
1231
1199
def __init__ (
1232
1200
self ,
1233
1201
* ,
@@ -1261,12 +1229,7 @@ def fit(self, X, y=None):
1261
1229
self : object
1262
1230
Fitted encoder.
1263
1231
"""
1264
- handle_unknown_strategies = ("error" , "use_encoded_value" )
1265
- if self .handle_unknown not in handle_unknown_strategies :
1266
- raise ValueError (
1267
- "handle_unknown should be either 'error' or "
1268
- f"'use_encoded_value', got { self .handle_unknown } ."
1269
- )
1232
+ self ._validate_params ()
1270
1233
1271
1234
if self .handle_unknown == "use_encoded_value" :
1272
1235
if is_scalar_nan (self .unknown_value ):
0 commit comments