8000 [WIP] Fixes #8136: Added support for new labels by tzano · Pull Request #8164 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[WIP] Fixes #8136: Added support for new labels #8164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
51 changes: 43 additions & 8 deletions sklearn/preprocessing/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ def fit(self, y):
y = column_or_1d(y, warn=True)
_check_numpy_unicode_bug(y)
self.classes_ = np.unique(y)
self.classes_lookup = defaultdict(int)
for i, v in enumerate(self.classes_):
self.classes_lookup[v] = i
return self

def fit_transform(self, y):
Expand All @@ -128,8 +131,14 @@ def fit_transform(self, y):
"""
y = column_or_1d(y, warn=True)
_check_numpy_unicode_bug(y)
self.classes_, y = np.unique(y, return_inverse=True)
return y

self.classes_ = np.unique(y)

self.classes_lookup = defaultdict(int)
for i, v in enumerate(self.classes_):
self.classes_lookup[v] = i
transformer = np.vectorize(lambda x: self.classes_lookup.get(x))
return transformer(y)

def transform(self, y):
"""Transform labels to normalized encoding.
Expand All @@ -145,13 +154,16 @@ def transform(self, y):
"""
check_is_fitted(self, 'classes_')
y = column_or_1d(y, warn=True)

classes = np.unique(y)
_check_numpy_unicode_bug(classes)

if len(np.intersect1d(classes, self.classes_)) < len(classes):
diff = np.setdiff1d(classes, self.classes_)
# for item in diff:
# self.expand_classes([item])
raise ValueError("y contains new labels: %s" % str(diff))
return np.searchsorted(self.classes_, y)
transformer = np.vectorize(lambda x: self.classes_lookup.get(x))
return transformer(y)

def inverse_transform(self, y):
"""Transform labels back to original encoding.
Expand All @@ -170,8 +182,31 @@ def inverse_transform(self, y):
diff = np.setdiff1d(y, np.arange(len(self.classes_)))
if diff:
raise ValueError("y contains new labels: %s" % str(diff))
y = np.asarray(y)
return self.classes_[y]

transformer = np.vectorize(lambda x: (
k for k, v in self.classes_lookup.items() if v == x).next())
return transformer(y)

def expand_classes(self, y):
"""Index new labels, and return their new normalized encoding

Parameters
----------
y : numpy array of shape [n_samples]
Target values.

Returns
-------
y : numpy array of shape [n_samples]

"""

for item in y:
if item not in self.classes_lookup:
self.classes_ = np.append(self.classes_, [item])
self.classes_lookup[item] = len(self.classes_) - 1

return self.classes_lookup[item]


class LabelBinarizer(BaseEstimator, TransformerMixin):
Expand Down Expand Up @@ -696,6 +731,7 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin):
sklearn.preprocessing.OneHotEncoder : encode categorical integer features
using a one-hot aka one-of-K scheme.
"""

def __init__(self, classes=None, sparse_output=False):
self.classes = classes
self.sparse_output = sparse_output
Expand Down Expand Up @@ -845,5 +881,4 @@ def inverse_transform(self, yt):
if len(unexpected) > 0:
raise ValueError('Expected only 0s and 1s in label indicator. '
'Also got {0}'.format(unexpected))
return [tuple(self.classes_.compress(indicators)) for indicators
in yt]
return [tuple(self.classes_.compress(indicators)) for indicators in yt]
20 changes: 19 additions & 1 deletion sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ def test_label_binarizer_errors():
assert_raises(ValueError, label_binarize, np.array([[1, 3], [2, 1]]),
[1, 2, 3])


def test_label_encoder():
# Test LabelEncoder's transform and inverse_transform methods
le = LabelEncoder()
Expand All @@ -184,6 +183,7 @@ def test_label_encoder():
assert_raise_message(ValueError, msg, le.transform, "apple")



def test_label_encoder_fit_transform():
# Test fit_transform
le = LabelEncoder()
Expand All @@ -195,6 +195,24 @@ def test_label_encoder_fit_transform():
assert_array_equal(ret, [1, 1, 2, 0])


def test_label_encoder_expand_classes():
# Test expand_classes
le = LabelEncoder()
ret = le.fit_transform([1, 1, 4, 5, -1, 0])
assert_array_equal(ret, [2, 2, 3, 4, 0, 1])

le = LabelEncoder()
ret = le.fit_transform(["paris", "paris", "tokyo", "berlin"])

assert_array_equal(ret, [1, 1, 2, 0])
# case where we expand classes
c_ny = le.expand_classes(["new york"])
c_sy = le.expand_classes(["sydney"])

ret = le.inverse_transform([1, 2, c_ny, c_sy])
assert_array_equal(ret, ["paris", "tokyo", "new york", "sydney"] )


def test_label_encoder_errors():
# Check that invalid arguments yield ValueError
le = LabelEncoder()
Expand Down
0