8000 FIX Ensure dtype of categories is `object` for strings in `OneHotEnco… · scikit-learn/scikit-learn@ecb9a70 · GitHub
[go: up one dir, main page]

Skip to content

Commit ecb9a70

Browse files
betatimglemaitrethomasjpfan
authored
FIX Ensure dtype of categories is object for strings in OneHotEncoder (#25174)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent c78a422 commit ecb9a70

File tree

3 files changed

+65
-3
lines changed

3 files changed

+65
-3
lines changed

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ parameters, may produce different models from the previous version. This often
1919
occurs due to changes in the modelling logic (bug fixes or enhancements), or in
2020
random sampling procedures.
2121

22+
- |Fix| The `categories_` attribute of :class:`preprocessing.OneHotEncoder` now
23+
always contains an array of `object`s when using predefined categories that
24+
are strings. Predefined categories encoded as bytes will no longer work
25+
with `X` encoded as strings. :pr:`25174` by :user:`Tim Head <betatim>`.
26+
2227
Changes impacting all modules
2328
-----------------------------
2429

@@ -51,7 +56,6 @@ Changelog
5156

5257
:mod:`sklearn.preprocessing`
5358
............................
54-
5559
- |Enhancement| Added support for `sample_weight` in
5660
:class:`preprocessing.KBinsDiscretizer`. This allows specifying the parameter
5761
`sample_weight` for each sample to be used while fitting. The option is only

sklearn/preprocessing/_encoders.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,27 @@ def _fit(
9797
else:
9898
cats = result
9999
else:
100-
cats = np.array(self.categories[i], dtype=Xi.dtype)
100+
if np.issubdtype(Xi.dtype, np.str_):
101+
# Always convert string categories to objects to avoid
102+
# unexpected string truncation for longer category labels
103+
# passed in the constructor.
104+
Xi_dtype = object
105+
else:
106+
Xi_dtype = Xi.dtype
107+
108+
cats = np.array(self.categories[i], dtype=Xi_dtype)
109+
if (
110+
cats.dtype == object
111+
and isinstance(cats[0], bytes)
112+
and Xi.dtype.kind != "S"
113+
):
114+
msg = (
115+
f"In column {i}, the predefined categories have type 'bytes'"
116+
" which is incompatible with values of type"
117+
f" '{type(Xi[0]).__name__}'."
118+
)
119+
raise ValueError(msg)
120+
101121
if Xi.dtype.kind not in "OUS":
102122
sorted_cats = np.sort(cats)
103123
error_msg = (

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1346,7 +1346,7 @@ def test_one_hot_encoder_sparse_deprecated():
13461346

13471347
# deliberately omit 'OS' as an invalid combo
13481348
@pytest.mark.parametrize(
1349-
"input_dtype, category_dtype", ["OO", "OU", "UO", "UU", "US", "SO", "SU", "SS"]
1349+
"input_dtype, category_dtype", ["OO", "OU", "UO", "UU", "SO", "SU", "SS"]
13501350
)
13511351
@pytest.mark.parametrize("array_type", ["list", "array", "dataframe"])
13521352
def test_encoders_string_categories(input_dtype, category_dtype, array_type):
@@ -1376,6 +1376,27 @@ def test_encoders_string_categories(input_dtype, category_dtype, array_type):
13761376
assert_array_equal(X_trans, expected)
13771377

13781378

1379+
def test_mixed_string_bytes_categoricals():
1380+
"""Check that this mixture of predefined categories and X raises an error.
1381+
1382+
Categories defined as bytes can not easily be compared to data that is
1383+
a string.
1384+
"""
1385+
# data as unicode
1386+
X = np.array([["b"], ["a"]], dtype="U")
1387+
# predefined categories as bytes
1388+
categories = [np.array(["b", "a"], dtype="S")]
1389+
ohe = OneHotEncoder(categories=categories, sparse_output=False)
1390+
1391+
msg = re.escape(
1392+
"In column 0, the predefined categories have type 'bytes' which is incompatible"
1393+
" with values of type 'str_'."
1394+
)
1395+
1396+
with pytest.raises(ValueError, match=msg):
1397+
ohe.fit(X)
1398+
1399+
13791400
@pytest.mark.parametrize("missing_value", [np.nan, None])
13801401
def test_ohe_missing_values_get_feature_names(missing_value):
13811402
# encoder with missing values with object dtypes
@@ -1939,3 +1960,20 @@ def test_ordinal_set_output():
19391960

19401961
assert_allclose(X_pandas.to_numpy(), X_default)
19411962
assert_array_equal(ord_pandas.get_feature_names_out(), X_pandas.columns)
1963+
1964+
1965+
def test_predefined_categories_dtype():
1966+
"""Check that the categories_ dtype is `object` for string categories
1967+
1968+
Regression test for gh-25171.
1969+
"""
1970+
categories = [["as", "mmas", "eas", "ras", "acs"], ["1", "2"]]
1971+
1972+
enc = OneHotEncoder(categories=categories)
1973+
1974+
enc.fit([["as", "1"]])
1975+
1976+
assert len(categories) == len(enc.categories_)
1977+
for n, cat in enumerate(enc.categories_):
1978+
assert cat.dtype == object
1979+
assert_array_equal(categories[n], cat)

0 commit comments

Comments
 (0)
0