8000 FIX Allow OrdinalEncoder's encoded_missing_value set to the cardinali… · scikit-learn/scikit-learn@ae4a1b1 · GitHub
[go: up one dir, main page]

Skip to content

Commit ae4a1b1

Browse files
authored
FIX Allow OrdinalEncoder's encoded_missing_value set to the cardinality (#25704)
1 parent b4afbee commit ae4a1b1

File tree

3 files changed

+36
-11
lines changed

3 files changed

+36
-11
lines changed

doc/whats_new/v1.2.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ Changelog
7272
when the global configuration sets `transform_output="pandas"`.
7373
:pr:`25500` by :user:`Guillaume Lemaitre <glemaitre>`.
7474

75+
:mod:`sklearn.preprocessing`
76+
............................
77+
78+
- |Fix| :class:`preprocessing.OrdinalEncoder` now correctly supports
79+
`encoded_missing_value` or `unknown_value` set to a categories' cardinality
80+
when there is missing values in the training data. :pr:`25704` by `Thomas Fan`_.
81+
7582
:mod:`sklearn.utils`
7683
....................
7784

sklearn/preprocessing/_encoders.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,24 +1300,30 @@ def fit(self, X, y=None):
13001300
# `_fit` will only raise an error when `self.handle_unknown="error"`
13011301
self._fit(X, handle_unknown=self.handle_unknown, force_all_finite="allow-nan")
13021302

1303-
if self.handle_unknown == "use_encoded_value":
1304-
for feature_cats in self.categories_:
1305-
if 0 <= self.unknown_value < len(feature_cats):
1306-
raise ValueError(
1307-
"The used value for unknown_value "
1308-
f"{self.unknown_value} is one of the "
1309-
"values already used for encoding the "
1310-
"seen categories."
1311-
)
1303+
cardinalities = [len(categories) for categories in self.categories_]
13121304

13131305
# stores the missing indices per category
13141306
self._missing_indices = {}
13151307
for cat_idx, categories_for_idx in enumerate(self.categories_):
13161308
for i, cat in enumerate(categories_for_idx):
13171309
if is_scalar_nan(cat):
13181310
self._missing_indices[cat_idx] = i
1311+
1312+
# missing values are not considered part of the cardinality
1313+
# when considering unknown categories or encoded_missing_value
1314+
cardinalities[cat_idx] -= 1
13191315
continue
13201316

1317+
if self.handle_unknown == "use_encoded_value":
1318+
for cardinality in cardinalities:
1319+
if 0 <= self.unknown_value < cardinality:
1320+
raise ValueError(
1321+
"The used value for unknown_value "
1322+
f"{self.unknown_value} is one of the "
1323+
"values already used for encoding the "
1324+
"seen categories."
1325+
)
1326+
13211327
if self._missing_indices:
13221328
if np.dtype(self.dtype).kind != "f" and is_scalar_nan(
13231329
self.encoded_missing_value
@@ -1336,9 +1342,9 @@ def fit(self, X, y=None):
13361342
# known category
13371343
invalid_features = [
13381344
cat_idx
1339-
for cat_idx, categories_for_idx in enumerate(self.categories_)
1345+
for cat_idx, cardinality in enumerate(cardinalities)
13401346
if cat_idx in self._missing_indices
1341-
and 0 <= self.encoded_missing_value < len(categories_for_idx)
1347+
and 0 <= self.encoded_missing_value < cardinality
13421348
]
13431349

13441350
if invalid_features:

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,3 +2003,15 @@ def test_predefined_categories_dtype():
20032003
for n, cat in enumerate(enc.categories_):
20042004
assert cat.dtype == object
20052005
assert_array_equal(categories[n], cat)
2006+
2007+
2008+
def test_ordinal_encoder_missing_unknown_encoding_max():
2009+
"""Check missing value or unknown encoding can equal the cardinality."""
2010+
X = np.array([["dog"], ["cat"], [np.nan]], dtype=object)
2011+
X_trans = OrdinalEncoder(encoded_missing_value=2).fit_transform(X)
2012+
assert_allclose(X_trans, [[1], [0], [2]])
2013+
2014+
enc = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=2).fit(X)
2015+
X_test = np.array([["snake"]])
2016+
X_trans = enc.transform(X_test)
2017+
assert_allclose(X_trans, [[2]])

0 commit comments

Comments
 (0)
0