fix + test specifying of categories · scikit-learn/scikit-learn@fda6d27 · GitHub
[go: up one dir, main page]

Skip to content

Commit fda6d27

Browse files
fix + test specifying of categories
1 parent bea23a5 commit fda6d27

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

sklearn/preprocessing/data.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2462,11 +2462,12 @@ class CategoricalEncoder(BaseEstimator, TransformerMixin):
24622462
24632463
Parameters
24642464
----------
2465-
categories : 'auto', 2D array of ints or strings or both.
2465+
categories : 'auto' or a list of lists/arrays of values.
24662466
Values per feature.
24672467
24682468
- 'auto' : Determine classes automatically from the training data.
2469-
- array: ``classes[i]`` holds the classes expected in the ith column.
2469+
- list : ``categories[i]`` holds the categories expected in the ith
2470+
column.
24702471
24712472
dtype : number type, default=np.float
24722473
Desired dtype of output.
@@ -2544,7 +2545,18 @@ def _fit(self, X):
25442545
if self.categories == 'auto':
25452546
le.fit(X[:, i])
25462547
else:
2547-
le.classes_ = np.array(self.categories[i])
2548+
if not np.all(np.in1d(X[:, i], self.categories[i])):
2549+
if self.handle_unknown == 'error':
2550+
diff = np.setdiff1d(X[:, i], self.categories[i])
2551+
msg = 'Unknown feature(s) %s in column %d' % (diff, i)
2552+
raise ValueError(msg)
2553+
le.classes_ = np.array(np.sort(self.categories[i]))
2554+
2555+
@staticmethod
2556+
def _check_unknown_categories(values, categories):
2557+
"""Returns False if not all categories in the values are known"""
2558+
valid_mask = np.in1d(values, categories)
2559+
return np.all(valid_mask)
25482560

25492561
def transform(self, X, y=None):
25502562
"""Encode the selected categorical features using the one-hot scheme.

sklearn/preprocessing/tests/test_data.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,6 +2003,36 @@ def test_categorical_encoder_errors():
20032003
assert_allclose(Xtr.toarray(), [[0, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1]])
20042004

20052005

2006+
def test_categorical_encoder_specified_categories():
2007+
X = np.array([['a', 'b']], dtype=object).T
2008+
2009+
enc = CategoricalEncoder(categories=[['a', 'b', 'c']])
2010+
exp = np.array([[1., 0., 0.],
2011+
[0., 1., 0.]])
2012+
assert_array_equal(enc.fit_transform(X).toarray(), exp)
2013+
2014+
# don't follow order of passed categories, but sort them
2015+
enc = CategoricalEncoder(categories=[['c', 'b', 'a']])
2016+
exp = np.array([[1., 0., 0.],
2017+
[0., 1., 0.]])
2018+
assert_array_equal(enc.fit_transform(X).toarray(), exp)
2019+
2020+
# multiple columns
2021+
X = np.array([['a', 'b'], ['A', 'C']], dtype=object).T
2022+
enc = CategoricalEncoder(categories=[['a', 'b', 'c'], ['A', 'B', 'C']])
2023+
exp = np.array([[1., 0., 0., 1., 0., 0.],
2024+
[0., 1., 0., 0., 0., 1.]])
2025+
assert_array_equal(enc.fit_transform(X).toarray(), exp)
2026+
2027+
# when specifying categories manually, unknown categories should already
2028+
# raise when fitting
2029+
X = np.array([['a', 'b', 'c']]).T
2030+
enc = CategoricalEncoder(categories=[['a', 'b']])
2031+
assert_raises(ValueError, enc.fit, X)
2032+
enc = CategoricalEncoder(categories=[['a', 'b']], handle_unknown='ignore')
2033+
enc.fit(X)
2034+
2035+
20062036
def test_fit_cold_start():
20072037
X = iris.data
20082038
X_2d = X[:, :2]

0 commit comments

Comments
 (0)
0