8000 [WIP] "other"/min_freq in OneHot and OrdinalEncoder by datajanko · Pull Request #12264 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[WIP] "other"/min_freq in OneHot and OrdinalEncoder #12264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions sklearn/preprocessing/_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import numpy as np
from scipy import sparse
from collections import Counter

from .. import get_config as _get_config
from ..base import BaseEstimator, TransformerMixin
Expand Down Expand Up @@ -73,13 +74,15 @@ def _fit(self, X, handle_unknown='error'):
if len(self._categories) != n_features:
raise ValueError("Shape mismatch: if n_values is an array,"
" it has to be of shape (n_features,).")

self.groups_ = []
self.categories_ = []

for i in range(n_features):
Xi = X[:, i]
if self._categories == 'auto':
Xi, group = _group_values(Xi.copy(), min_freq=self.min_freq)
cats = _encode(Xi)
self.groups_.append(group)
else:
cats = np.array(self._categories[i], dtype=X.dtype)
if self.handle_unknown == 'error':
Expand All @@ -99,6 +102,10 @@ def _transform(self, X, handle_unknown='error'):

for i in range(n_features):
Xi = X[:, i]
try:
Xi, _ = _group_values(Xi, group=self.groups_[i])
except IndexError:
pass
diff, valid_mask = _encode_check_unknown(Xi, self.categories_[i],
return_mask=True)

Expand Down Expand Up @@ -198,6 +205,9 @@ class OneHotEncoder(_BaseEncoder):
0.20 and will be removed in 0.22.
You can use the ``ColumnTransformer`` instead.

min_freq: float, default=0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space before colon, please

group low frequent categories together
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

be more specific, please. This should describe what the parameter is.


Attributes
----------
categories_ : list of arrays
Expand Down Expand Up @@ -272,13 +282,14 @@ class OneHotEncoder(_BaseEncoder):

def __init__(self, n_values=None, categorical_features=None,
categories=None, sparse=True, dtype=np.float64,
handle_unknown='error'):
min_freq=0, handle_unknown='error'):
self.categories = categories
self.sparse = sparse
self.dtype = dtype
self.handle_unknown = handle_unknown
self.n_values = n_values
self.categorical_features = categorical_features
self.min_freq = min_freq

# Deprecated attributes

Expand Down Expand Up @@ -759,9 +770,10 @@ class OrdinalEncoder(_BaseEncoder):
between 0 and n_classes-1.
"""

def __init__(self, categories='auto', dtype=np.float64):
def __init__(self, categories='auto', dtype=np.float64, min_freq=0):
self.categories = categories
self.dtype = dtype
self.min_freq = min_freq

def fit(self, X, y=None):
"""Fit the OrdinalEncoder to X.
Expand Down Expand Up @@ -835,3 +847,46 @@ def inverse_transform(self, X):
X_tr[:, i] = self.categories_[i][labels]

return X_tr


def _group_values_python(values, min_freq=0, group=None):
if min_freq and group:
raise ValueError
if min_freq:
freqs = {key: counts/len(values)
for key, counts in Counter(values).items()}
low_freq_keys = (key for key, freq in freqs.items() if freq < min_freq)
# sorting ensures first element in group is always the same
group = np.array(sorted(set(low_freq_keys)), dtype=values.dtype)
if group is not None:
try:
values[np.isin(values, group)] = group[0]
except IndexError:
pass
return values, group
else:
return values, group


def _group_values_numpy(values, min_freq=0, group=None):
if min_freq and group:
raise ValueError
if min_freq:
uniques, counts = np.unique(values, return_counts=True)
mask = (counts/len(values) < min_freq)
group = uniques[mask]
if group is not None:
try:
values[np.isin(values, group)] = group[0]
except IndexError:
pass
return values, group
else:
return values, None


def _group_values(values, min_freq=0, group=None):
if values.dtype == object:
return _group_values_python(values, min_freq, group)
else:
return _group_values_numpy(values, min_freq, group)
73 changes: 73 additions & 0 deletions sklearn/preprocessing/tests/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing._encoders import _group_values


def toarray(a):
Expand Down Expand Up @@ -515,6 +516,13 @@ def test_one_hot_encoder_raise_missing(X, handle_unknown):
with pytest.raises(ValueError, match="Input contains NaN"):
ohe.transform(X)

def test_one_hot_encoder_min_freq_fit_not_inplace():
arr = np.array([[1, 1], [2, 3], [1, 2], [3, 2], [1, 4]])
expected = arr.copy()
enc = OneHotEncoder(min_freq=0.4, categories='auto')
enc.fit(arr)
assert_array_equal(arr, expected)


@pytest.mark.parametrize("X", [
[['abc', 2, 55], ['def', 1, 55]],
Expand Down Expand Up @@ -613,3 +621,68 @@ def test_one_hot_encoder_warning():
def test_categorical_encoder_stub():
from sklearn.preprocessing import CategoricalEncoder
assert_raises(RuntimeError, CategoricalEncoder, encoding='ordinal')


def test_ordinal_encoder_min_freq_fit_not_inplace():
arr = np.array([[1, 1], [2, 3], [1, 2], [3, 2], [1, 4]])
expected = arr.copy()
enc = OrdinalEncoder(min_freq=0.4)
enc.fit(arr)
assert_array_equal(arr, expected)


@pytest.mark.parametrize(
"values, min_freq, exp_values, exp_group",
[(np.array([1, 2, 3, 3, 3], dtype='int64'),
0.3,
np.array([1, 1, 3, 3, 3], dtype='int64'),
np.array([1, 2])),
(np.array([1, 2, 3, 3, 3], dtype='int64'),
0.7,
np.array([1, 1, 1, 1, 1], dtype='int64'),
np.array([1, 2, 3])),
(np.array([1, 2, 3, 3, 3], dtype='int64'),
0.2,
np.array([1, 2, 3, 3, 3], dtype='int64'),
np.array([], dtype='int64')),
(np.array([1, 2, 3, 3, 3], dtype='int64'),
0,
np.array([1, 2, 3, 3, 3], dtype='int64'),
None),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=object),
0.3,
np.array(['a', 'a', 'c', 'c', 'c'], dtype=object),
np.array(['a', 'b'])),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=object),
0.7,
np.array(['a', 'a', 'a', 'a', 'a'], dtype=object),
np.array(['a', 'b', 'c'])),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=object),
0.2,
np.array(['a', 'b', 'c', 'c', 'c'], dtype=object),
np.array([], dtype=object)),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=object),
0,
np.array(['a', 'b', 'c', 'c', 'c'], dtype=object),
None),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=str),
0.3,
np.array(['a', 'a', 'c', 'c', 'c'], dtype=str),
np.array(['a', 'b'])),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=str),
0.7,
np.array(['a', 'a', 'a', 'a', 'a'], dtype=str),
np.array(['a', 'b', 'c'])),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=str),
0.2,
np.array(['a', 'b', 'c', 'c', 'c'], dtype=str),
np.array([], dtype=str)),
(np.array(['a', 'b', 'c', 'c', 'c'], dtype=str),
0,
np.array(['a', 'b', 'c', 'c', 'c'], dtype=str),
None)],
ids=(['int64']*4 + ['object']*4 + ['str']*4))
def test_group_values_freq(values, min_freq, exp_values, exp_group):
values, group = _group_values(values, min_freq)
assert_array_equal(values, exp_values)
assert_array_equal(group, exp_group)
0