34
34
]
35
35
36
36
37
- def _encode_numpy (values , uniques = None , encode = False , check_unknown = True ):
38
- # only used in _encode below, see docstring there for details
39
- if uniques is None :
40
- if encode :
41
- uniques , encoded = np .unique (values , return_inverse = True )
42
- return uniques , encoded
43
- else :
44
- # unique sorts
45
- return np .unique (values )
46
- if encode :
47
- if check_unknown :
48
- diff = _encode_check_unknown (values , uniques )
49
- if diff :
50
- raise ValueError ("y contains previously unseen labels: %s"
51
- % str (diff ))
52
- encoded = np .searchsorted (uniques , values )
53
- return uniques , encoded
54
- else :
55
- return uniques
56
-
57
-
58
- def _encode_python (values , uniques = None , encode = False ):
59
- # only used in _encode below, see docstring there for details
60
- if uniques is None :
61
- uniques = sorted (set (val
B41A
ues ))
62
- uniques = np .array (uniques , dtype = values .dtype )
63
- if encode :
64
- table = {val : i for i , val in enumerate (uniques )}
65
- try :
66
- encoded = np .array ([table [v ] for v in values ])
67
- except KeyError as e :
68
- raise ValueError ("y contains previously unseen labels: %s"
69
- % str (e ))
70
- return uniques , encoded
71
- else :
72
- return uniques
73
-
74
-
75
- def _encode (values , uniques = None , encode = False , check_unknown = True ):
76
- """Helper function to factorize (find uniques) and encode values.
37
+ def _encode (values , * , uniques , check_unknown = True ):
38
+ """Helper function encode values.
77
39
78
40
Uses pure python method for object dtype, and numpy method for
79
41
all other dtypes.
@@ -86,12 +48,10 @@ def _encode(values, uniques=None, encode=False, check_unknown=True):
86
48
----------
87
49
values : array
88
50
Values to factorize or encode.
89
- uniques : array, optional
90
- If passed, uniques are not determined from passed values (this
51
+ uniques : array
52
+ Uniques are not determined from passed values (this
91
53
can be because the user specified categories, or because they
92
54
already have been determined in fit).
93
- encode : bool, default False
94
- If True, also encode the values into integer codes based on `uniques`.
95
55
check_unknown : bool, default True
96
56
If True, check for values in ``values`` that are not in ``unique``
97
57
and raise an error. This is ignored for object dtype, and treated as
@@ -101,25 +61,67 @@ def _encode(values, uniques=None, encode=False, check_unknown=True):
101
61
102
62
Returns
103
63
-------
104
- uniques
105
- If ``encode=False``. The unique values are sorted if the `uniques`
106
- parameter was None (and thus inferred from the data).
107
- (uniques, encoded)
108
- If ``encode=True``.
109
-
64
+ encoded : ndarray
65
+ Encoded values
110
66
"""
111
67
if values .dtype == object :
68
+ table = {val : i for i , val in enumerate (uniques )}
112
69
try :
113
- res = _encode_python (values , uniques , encode )
114
- except TypeError :
115
- types = sorted (t .__qualname__
116
- for t in set (type (v ) for v in values ))
117
- raise TypeError ("Encoders require their input to be uniformly "
118
- f"strings or numbers. Got { types } " )
119
- return res
70
+ return np .array ([table [v ] for v in values ])
71
+ except KeyError as e :
72
+ raise ValueError (f"y contains previously unseen labels: { str (e )} " )
120
73
else :
121
- return _encode_numpy (values , uniques , encode ,
122
- check_unknown = check_unknown )
74
+ if check_unknown :
75
+ diff = _encode_check_unknown (values , uniques )
76
+ if diff :
77
+ raise ValueError (f"y contains previously unseen labels: "
78
+ f"{ str (diff )} " )
79
+ return np .searchsorted (uniques , values )
80
+
81
+
82
+ def _unique_python (values , * , return_inverse ):
83
+ # Only used in _u
10000
niques below, see docstring there for details
84
+ try :
85
+ uniques = sorted (set (values ))
86
+ uniques = np .array (uniques , dtype = values .dtype )
87
+ except TypeError :
88
+ types = sorted (t .__qualname__
89
+ for t in set (type (v ) for v in values ))
90
+ raise TypeError ("Encoders require their input to be uniformly "
91
+ f"strings or numbers. Got { types } " )
92
+
93
+ ret = (uniques , )
94
+
95
+ if return_inverse :
96
+ table = {val : i for i , val in enumerate (uniques )}
97
+ inverse = np .array ([table [v ] for v in values ])
98
+ ret += (inverse , )
99
+
100
+ if len (ret ) == 1 :
101
+ ret = ret [0 ]
102
+
103
+ return ret
104
+
105
+
106
+ def _unique (values , * , return_inverse = False ):
107
+ """Helper function to find uniques with support for python objects.
108
+
109
+ Uses pure python method for object dtype, and numpy method for
110
+ all other dtypes.
111
+
112
+ Parameters
113
+ ----------
114
+ unique : ndarray
115
+ The sorted uniique values
116
+
117
+ unique_inverse : ndarray
118
+ The indicies to reconstruct the original array from the unique array.
119
+ Only provided if `return_inverse` is True.
120
+ """
121
+ if values .dtype == object :
122
+ return _unique_python (values , return_inverse = return_inverse )
123
+ # numerical
124
+ return np .unique (values , return_inverse = return_inverse )
123
125
124
126
125
127
def _encode_check_unknown (values , uniques , return_mask = False ):
@@ -237,7 +239,7 @@ def fit(self, y):
237
239
self : returns an instance of self.
238
240
"""
239
241
y = column_or_1d (y , warn = True )
240
- self .classes_ = _encode (y )
242
+ self .classes_ = _unique (y )
241
243
return self
242
244
243
245
def fit_transform (self , y ):
@@ -253,7 +255,7 @@ def fit_transform(self, y):
253
255
y : array-like of shape [n_samples]
254
256
"""
255
257
y = column_or_1d (y , warn = True )
256
- self .classes_ , y = _encode (y , encode = True )
258
+ self .classes_ , y = _unique (y , return_inverse = True )
257
259
return y
258
260
259
261
def transform (self , y ):
@@ -274,8 +276,7 @@ def transform(self, y):
274
276
if _num_samples (y ) == 0 :
275
277
return np .array ([])
276
278
277
- _ , y = _encode (y , uniques = self .classes_ , encode = True )
278
- return y
279
+ return _encode (y , uniques = self .classes_ )
279
280
280
281
def inverse_transform (self , y ):
281
282
"""Transform labels back to original encoding.
0 commit comments