CLN Refactors encoding logic · thomasjpfan/scikit-learn@805daf9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 805daf9

Browse files
committed
CLN Refactors encoding logic
1 parent c71a1c2 commit 805daf9

File tree

4 files changed

+77
-79
lines changed

4 files changed

+77
-79
lines changed

sklearn/metrics/_ranking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from ..utils.validation import _deprecate_positional_args
3535
from ..exceptions import UndefinedMetricWarning
3636
from ..preprocessing import label_binarize
37-
from ..preprocessing._label import _encode
37+
from ..preprocessing._label import _encode, _unique
3838

3939
from ._base import _average_binary_score, _average_multiclass_ovo_score
4040

@@ -460,7 +460,7 @@ def _multiclass_roc_auc_score(y_true, y_score, labels,
460460

461461
if labels is not None:
462462
labels = column_or_1d(labels)
463-
classes = _encode(labels)
463+
classes = _unique(labels)
464464
if len(classes) != len(labels):
465465
raise ValueError("Parameter 'labels' must be unique")
466466
if not np.array_equal(classes, labels):
@@ -474,7 +474,7 @@ def _multiclass_roc_auc_score(y_true, y_score, labels,
474474
raise ValueError(
475475
"'y_true' contains labels not in parameter 'labels'")
476476
else:
477-
classes = _encode(y_true)
477+
classes = _unique(y_true)
478478
if len(classes) != y_score.shape[1]:
479479
raise ValueError(
480480
"Number of classes in y_true not equal to the number of "

sklearn/preprocessing/_encoders.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..utils.validation import check_is_fitted
1111
from ..utils.validation import _deprecate_positional_args
1212

13-
from ._label import _encode, _encode_check_unknown
13+
from ._label import _encode, _encode_check_unknown, _unique
1414

1515

1616
__all__ = [
@@ -83,7 +83,7 @@ def _fit(self, X, handle_unknown='error'):
8383
for i in range(n_features):
8484
Xi = X_list[i]
8585
if self.categories == 'auto':
86-
cats = _encode(Xi)
86+
cats = _unique(Xi)
8787
else:
8888
cats = np.array(self.categories[i], dtype=Xi.dtype)
8989
if Xi.dtype != object:
@@ -138,9 +138,8 @@ def _transform(self, X, handle_unknown='error'):
138138
Xi[~valid_mask] = self.categories_[i][0]
139139
# We use check_unknown=False, since _encode_check_unknown was
140140
# already called above.
141-
_, encoded = _encode(Xi, self.categories_[i], encode=True,
142-
check_unknown=False)
143-
X_int[:, i] = encoded
141+
X_int[:, i] = _encode(Xi, uniques=self.categories_[i],
142+
check_unknown=False)
144143

145144
return X_int, X_mask
146145

sklearn/preprocessing/_label.py

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -34,46 +34,8 @@
3434
]
3535

3636

37-
def _encode_numpy(values, uniques=None, encode=False, check_unknown=True):
38-
# only used in _encode below, see docstring there for details
39-
if uniques is None:
40-
if encode:
41-
uniques, encoded = np.unique(values, return_inverse=True)
42-
return uniques, encoded
43-
else:
44-
# unique sorts
45-
return np.unique(values)
46-
if encode:
47-
if check_unknown:
48-
diff = _encode_check_unknown(values, uniques)
49-
if diff:
50-
raise ValueError("y contains previously unseen labels: %s"
51-
% str(diff))
52-
encoded = np.searchsorted(uniques, values)
53-
return uniques, encoded
54-
else:
55-
return uniques
56-
57-
58-
def _encode_python(values, uniques=None, encode=False):
59-
# only used in _encode below, see docstring there for details
60-
if uniques is None:
61-
uniques = sorted(set(values))
62-
uniques = np.array(uniques, dtype=values.dtype)
63-
if encode:
64-
table = {val: i for i, val in enumerate(uniques)}
65-
try:
66-
encoded = np.array([table[v] for v in values])
67-
except KeyError as e:
68-
raise ValueError("y contains previously unseen labels: %s"
69-
% str(e))
70-
return uniques, encoded
71-
else:
72-
return uniques
73-
74-
75-
def _encode(values, uniques=None, encode=False, check_unknown=True):
76-
"""Helper function to factorize (find uniques) and encode values.
37+
def _encode(values, *, uniques, check_unknown=True):
38+
"""Helper function to encode values.
7739
7840
Uses pure python method for object dtype, and numpy method for
7941
all other dtypes.
@@ -86,12 +48,10 @@ def _encode(values, uniques=None, encode=False, check_unknown=True):
8648
----------
8749
values : array
8850
Values to factorize or encode.
89-
uniques : array, optional
90-
If passed, uniques are not determined from passed values (this
51+
uniques : array
52+
Uniques are not determined from passed values (this
9153
can be because the user specified categories, or because they
9254
already have been determined in fit).
93-
encode : bool, default False
94-
If True, also encode the values into integer codes based on `uniques`.
9555
check_unknown : bool, default True
9656
If True, check for values in ``values`` that are not in ``uniques``
9757
and raise an error. This is ignored for object dtype, and treated as
@@ -101,25 +61,67 @@ def _encode(values, uniques=None, encode=False, check_unknown=True):
10161
10262
Returns
10363
-------
104-
uniques
105-
If ``encode=False``. The unique values are sorted if the `uniques`
106-
parameter was None (and thus inferred from the data).
107-
(uniques, encoded)
108-
If ``encode=True``.
109-
64+
encoded : ndarray
65+
Encoded values
11066
"""
11167
if values.dtype == object:
68+
table = {val: i for i, val in enumerate(uniques)}
11269
try:
113-
res = _encode_python(values, uniques, encode)
114-
except TypeError:
115-
types = sorted(t.__qualname__
116-
for t in set(type(v) for v in values))
117-
raise TypeError("Encoders require their input to be uniformly "
118-
f"strings or numbers. Got {types}")
119-
return res
70+
return np.array([table[v] for v in values])
71+
except KeyError as e:
72+
raise ValueError(f"y contains previously unseen labels: {str(e)}")
12073
else:
121-
return _encode_numpy(values, uniques, encode,
122-
check_unknown=check_unknown)
74+
if check_unknown:
75+
diff = _encode_check_unknown(values, uniques)
76+
if diff:
77+
raise ValueError(f"y contains previously unseen labels: "
78+
f"{str(diff)}")
79+
return np.searchsorted(uniques, values)
80+
81+
82+
def _unique_python(values, *, return_inverse):
83+
# Only used in _unique below, see docstring there for details
84+
try:
85+
uniques = sorted(set(values))
86+
uniques = np.array(uniques, dtype=values.dtype)
87+
except TypeError:
88+
types = sorted(t.__qualname__
89+
for t in set(type(v) for v in values))
90+
raise TypeError("Encoders require their input to be uniformly "
91+
f"strings or numbers. Got {types}")
92+
93+
ret = (uniques, )
94+
95+
if return_inverse:
96+
table = {val: i for i, val in enumerate(uniques)}
97+
inverse = np.array([table[v] for v in values])
98+
ret += (inverse, )
99+
100+
if len(ret) == 1:
101+
ret = ret[0]
102+
103+
return ret
104+
105+
106+
def _unique(values, *, return_inverse=False):
107+
"""Helper function to find uniques with support for python objects.
108+
109+
Uses pure python method for object dtype, and numpy method for
110+
all other dtypes.
111+
112+
Returns
113+
----------
114+
unique : ndarray
115+
The sorted unique values
116+
117+
unique_inverse : ndarray
118+
The indices to reconstruct the original array from the unique array.
119+
Only provided if `return_inverse` is True.
120+
"""
121+
if values.dtype == object:
122+
return _unique_python(values, return_inverse=return_inverse)
123+
# numerical
124+
return np.unique(values, return_inverse=return_inverse)
123125

124126

125127
def _encode_check_unknown(values, uniques, return_mask=False):
@@ -237,7 +239,7 @@ def fit(self, y):
237239
self : returns an instance of self.
238240
"""
239241
y = column_or_1d(y, warn=True)
240-
self.classes_ = _encode(y)
242+
self.classes_ = _unique(y)
241243
return self
242244

243245
def fit_transform(self, y):
@@ -253,7 +255,7 @@ def fit_transform(self, y):
253255
y : array-like of shape [n_samples]
254256
"""
255257
y = column_or_1d(y, warn=True)
256-
self.classes_, y = _encode(y, encode=True)
258+
self.classes_, y = _unique(y, return_inverse=True)
257259
return y
258260

259261
def transform(self, y):
@@ -274,8 +276,7 @@ def transform(self, y):
274276
if _num_samples(y) == 0:
275277
return np.array([])
276278

277-
_, y = _encode(y, uniques=self.classes_, encode=True)
278-
return y
279+
return _encode(y, uniques=self.classes_)
279280

280281
def inverse_transform(self, y):
281282
"""Transform labels back to original encoding.

sklearn/preprocessing/tests/test_label.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sklearn.preprocessing._label import _inverse_binarize_thresholding
2525
from sklearn.preprocessing._label import _inverse_binarize_multiclass
2626
from sklearn.preprocessing._label import _encode
27+
from sklearn.preprocessing._label import _unique
2728

2829
from sklearn import datasets
2930

@@ -626,12 +627,9 @@ def test_inverse_binarize_multiclass():
626627
np.array(['a', 'b', 'c']))],
627628
ids=['int64', 'object', 'str'])
628629
def test_encode_util(values, expected):
629-
uniques = _encode(values)
630+
uniques = _unique(values)
630631
assert_array_equal(uniques, expected)
631-
uniques, encoded = _encode(values, encode=True)
632-
assert_array_equal(uniques, expected)
633-
assert_array_equal(encoded, np.array([1, 0, 2, 0, 2]))
634-
_, encoded = _encode(values, uniques, encode=True)
632+
encoded = _encode(values, uniques=uniques)
635633
assert_array_equal(encoded, np.array([1, 0, 2, 0, 2]))
636634

637635

@@ -643,14 +641,14 @@ def test_encode_check_unknown():
643641
# Default is True, raise error
644642
with pytest.raises(ValueError,
645643
match='y contains previously unseen labels'):
646-
_encode(values, uniques, encode=True, check_unknown=True)
644+
_encode(values, uniques=uniques, check_unknown=True)
647645

648646
# dont raise error if False
649-
_encode(values, uniques, encode=True, check_unknown=False)
647+
_encode(values, uniques=uniques, check_unknown=False)
650648

651649
# parameter is ignored for object dtype
652650
uniques = np.array(['a', 'b', 'c'], dtype=object)
653651
values = np.array(['a', 'b', 'c', 'd'], dtype=object)
654652
with pytest.raises(ValueError,
655653
match='y contains previously unseen labels'):
656-
_encode(values, uniques, encode=True, check_unknown=False)
654+
_encode(values, uniques=uniques, check_unknown=False)

0 commit comments

Comments
 (0)
0