diff --git a/sklearn/preprocessing/label.py b/sklearn/preprocessing/label.py index f2f7d9afad347..35bdcdaf7aba5 100644 --- a/sklearn/preprocessing/label.py +++ b/sklearn/preprocessing/label.py @@ -112,6 +112,9 @@ def fit(self, y): y = column_or_1d(y, warn=True) _check_numpy_unicode_bug(y) self.classes_ = np.unique(y) + self.classes_lookup = defaultdict(int) + for i, v in enumerate(self.classes_): + self.classes_lookup[v] = i return self def fit_transform(self, y): @@ -128,8 +131,14 @@ def fit_transform(self, y): """ y = column_or_1d(y, warn=True) _check_numpy_unicode_bug(y) - self.classes_, y = np.unique(y, return_inverse=True) - return y + + self.classes_ = np.unique(y) + + self.classes_lookup = defaultdict(int) + for i, v in enumerate(self.classes_): + self.classes_lookup[v] = i + transformer = np.vectorize(lambda x: self.classes_lookup.get(x)) + return transformer(y) def transform(self, y): """Transform labels to normalized encoding. @@ -145,13 +154,16 @@ def transform(self, y): """ check_is_fitted(self, 'classes_') y = column_or_1d(y, warn=True) - classes = np.unique(y) _check_numpy_unicode_bug(classes) + if len(np.intersect1d(classes, self.classes_)) < len(classes): diff = np.setdiff1d(classes, self.classes_) + # for item in diff: + # self.expand_classes([item]) raise ValueError("y contains new labels: %s" % str(diff)) - return np.searchsorted(self.classes_, y) + transformer = np.vectorize(lambda x: self.classes_lookup.get(x)) + return transformer(y) def inverse_transform(self, y): """Transform labels back to original encoding. @@ -170,8 +182,31 @@ def inverse_transform(self, y): diff = np.setdiff1d(y, np.arange(len(self.classes_))) if diff: raise ValueError("y contains new labels: %s" % str(diff)) - y = np.asarray(y) - return self.classes_[y] + + transformer = np.vectorize(lambda x: ( + k for k, v in self.classes_lookup.items() if v == x).next()) + return transformer(y) + + def expand_classes(self, y): + """Index new labels, and return their new normalized encoding + + Parameters + ---------- + y : numpy array of shape [n_samples] + Target values. + + Returns + ------- + y : numpy array of shape [n_samples] + + """ + + for item in y: + if item not in self.classes_lookup: + self.classes_ = np.append(self.classes_, [item]) + self.classes_lookup[item] = len(self.classes_) - 1 + + return self.classes_lookup[item] class LabelBinarizer(BaseEstimator, TransformerMixin): @@ -696,6 +731,7 @@ class MultiLabelBinarizer(BaseEstimator, TransformerMixin): sklearn.preprocessing.OneHotEncoder : encode categorical integer features using a one-hot aka one-of-K scheme. """ + def __init__(self, classes=None, sparse_output=False): self.classes = classes self.sparse_output = sparse_output @@ -845,5 +881,4 @@ def inverse_transform(self, yt): if len(unexpected) > 0: raise ValueError('Expected only 0s and 1s in label indicator. ' 'Also got {0}'.format(unexpected)) - return [tuple(self.classes_.compress(indicators)) for indicators - in yt] + return [tuple(self.classes_.compress(indicators)) for indicators in yt] \ No newline at end of file diff --git a/sklearn/preprocessing/tests/test_label.py b/sklearn/preprocessing/tests/test_label.py index f48ad29bd29b5..1b3514883a7ec 100644 --- a/sklearn/preprocessing/tests/test_label.py +++ b/sklearn/preprocessing/tests/test_label.py @@ -167,7 +167,6 @@ def test_label_binarizer_errors(): assert_raises(ValueError, label_binarize, np.array([[1, 3], [2, 1]]), [1, 2, 3]) - def test_label_encoder(): # Test LabelEncoder's transform and inverse_transform methods le = LabelEncoder() @@ -184,6 +183,7 @@ def test_label_encoder(): assert_raise_message(ValueError, msg, le.transform, "apple") + def test_label_encoder_fit_transform(): # Test fit_transform le = LabelEncoder() @@ -195,6 +195,24 @@ def test_label_encoder_fit_transform(): assert_array_equal(ret, [1, 1, 2, 0]) +def test_label_encoder_expand_classes(): + # Test expand_classes + le = LabelEncoder() + ret = le.fit_transform([1, 1, 4, 5, -1, 0]) + assert_array_equal(ret, [2, 2, 3, 4, 0, 1]) + + le = LabelEncoder() + ret = le.fit_transform(["paris", "paris", "tokyo", "berlin"]) + + assert_array_equal(ret, [1, 1, 2, 0]) + # case where we expand classes + c_ny = le.expand_classes(["new york"]) + c_sy = le.expand_classes(["sydney"]) + + ret = le.inverse_transform([1, 2, c_ny, c_sy]) + assert_array_equal(ret, ["paris", "tokyo", "new york", "sydney"] ) + + def test_label_encoder_errors(): # Check that invalid arguments yield ValueError le = LabelEncoder()