-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
ENH support multilabel targets in LabelEncoder #1987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
96dbd77
8669605
82073e4
e162bc1
2eac97f
8085fb0
5ddcf5c
507e848
07eebcf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
from sklearn.utils.testing import assert_raises | ||
from sklearn.utils.testing import assert_true | ||
from sklearn.utils.testing import assert_false | ||
from sklearn.utils.testing import assert_sequences_equal | ||
|
||
from sklearn.utils.sparsefuncs import mean_variance_axis0 | ||
from sklearn.preprocessing import Binarizer | ||
|
@@ -510,7 +511,7 @@ def test_label_binarizer_multilabel(): | |
[1, 1, 0]]) | ||
got = lb.fit_transform(inp) | ||
assert_array_equal(indicator_mat, got) | ||
assert_equal(lb.inverse_transform(got), inp) | ||
assert_sequences_equal(lb.inverse_transform(got), inp) | ||
|
||
# test input as label indicator matrix | ||
lb.fit(indicator_mat) | ||
|
@@ -527,8 +528,7 @@ def test_label_binarizer_multilabel(): | |
[1, 1]]) | ||
got = lb.fit_transform(inp) | ||
assert_array_equal(expected, got) | ||
assert_equal([set(x) for x in lb.inverse_transform(got)], | ||
[set(x) for x in inp]) | ||
assert_sequences_equal(lb.inverse_transform(got), inp) | ||
|
||
|
||
def test_label_binarizer_errors(): | ||
|
@@ -612,17 +612,47 @@ def test_label_encoder(): | |
assert_raises(ValueError, le.transform, [0, 6]) | ||
|
||
|
||
def test_label_encoder_multilabel():
    """Check LabelEncoder.transform / inverse_transform on multilabel input.

    The encoder is fitted on a sequence of label sequences; the test then
    verifies the learnt ``classes_``, the encoding, the decoding, and that
    unseen labels and 2-d arrays are rejected with ``ValueError``.
    """
    encoder = LabelEncoder()
    encoder.fit([[1], [1, 4], [5, -1, 0]])
    assert_array_equal(encoder.classes_, [-1, 0, 1, 4, 5])

    # labels -> integer codes
    encoded = encoder.transform([[0, 1, 4], [4, 5, -1], [-1]])
    assert_sequences_equal(encoded, [[1, 2, 3], [3, 4, 0], [0]])

    # integer codes -> labels
    decoded = encoder.inverse_transform([[1, 2, 3], [3, 4, 0], [0]])
    assert_sequences_equal(decoded, [[0, 1, 4], [4, 5, -1], [-1]])

    # 6 was never seen during fit
    assert_raises(ValueError, encoder.transform, [[0, 6]])
    # not handling label encoder matrices presently
    assert_raises(ValueError, encoder.transform, np.array([[0, 1], [1, 0]]))
|
||
|
||
def test_label_encoder_fit_transform():
    """Check that fit_transform fits and encodes flat targets in one call."""
    # integer labels: verify both the returned codes and the learnt classes
    int_encoder = LabelEncoder()
    int_codes = int_encoder.fit_transform([1, 1, 4, 5, -1, 0])
    assert_array_equal(int_codes, [2, 2, 3, 4, 0, 1])
    assert_array_equal(int_encoder.classes_, [-1, 0, 1, 4, 5])

    # string labels: only the returned codes are checked
    str_encoder = LabelEncoder()
    str_codes = str_encoder.fit_transform(
        ["paris", "paris", "tokyo", "amsterdam"])
    assert_array_equal(str_codes, [1, 1, 2, 0])
|
||
|
||
def test_label_encoder_fit_transform_multilabel():
    """Check that fit_transform handles sequence-of-sequences targets."""
    # integer labels: verify both the returned codes and the learnt classes
    int_encoder = LabelEncoder()
    int_codes = int_encoder.fit_transform([[1], [1, 4, 5], [-1, 0]])
    assert_sequences_equal(int_codes, [[2], [2, 3, 4], [0, 1]])
    assert_array_equal(int_encoder.classes_, [-1, 0, 1, 4, 5])

    # string labels
    str_encoder = LabelEncoder()
    str_codes = str_encoder.fit_transform(
        [["paris"], ["paris", "tokyo", "amsterdam"]])
    assert_sequences_equal(str_codes, [[1], [1, 2, 0]])
    # not handling label encoder matrices presently
    assert_raises(ValueError, str_encoder.transform,
                  np.array([[0, 1], [1, 0]]))
|
||
|
||
def test_label_encoder_string_labels(): | ||
"""Test LabelEncoder's transform and inverse_transform methods with | ||
non-numeric labels""" | ||
|
@@ -636,6 +666,19 @@ def test_label_encoder_string_labels(): | |
assert_raises(ValueError, le.transform, ["london"]) | ||
|
||
|
||
def test_label_encoder_strings_multilabel():
    """Test LabelEncoder's transform and inverse_transform methods with
    non-numeric multilabel data

    The encoder is fitted on a sequence of string-label sequences; the test
    checks the learnt ``classes_``, the encoding, the decoding, and that
    labels unseen during ``fit`` raise ``ValueError``.
    """
    le = LabelEncoder()
    le.fit([["paris"], ["paris", "tokyo", "amsterdam"]])
    assert_array_equal(le.classes_, ["amsterdam", "paris", "tokyo"])
    assert_sequences_equal(le.transform([["tokyo"], ["tokyo", "paris"]]),
                           [[2], [2, 1]])
    assert_sequences_equal(le.inverse_transform([[2], [2, 1]]),
                           [["tokyo"], ["tokyo", "paris"]])
    # Unseen label given as a flat sequence (kept from the original test).
    assert_raises(ValueError, le.transform, ["london"])
    # Unseen label given in multilabel form, mirroring the ``[[0, 6]]`` check
    # of the numeric multilabel test; the flat form alone never exercises the
    # sequence-of-sequences input this test is about.
    assert_raises(ValueError, le.transform, [["london"]])
|
||
|
||
def test_label_encoder_errors(): | ||
"""Check that invalid arguments yield ValueError""" | ||
le = LabelEncoder() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,5 +130,7 @@ def is_multilabel(y): | |
""" | ||
# the explicit check for ndarray is for forward compatibility; future | ||
# versions of Numpy might want to register ndarray as a Sequence | ||
return (not isinstance(y[0], np.ndarray) and isinstance(y[0], Sequence) and | ||
not isinstance(y[0], string_types) or is_label_indicator_matrix(y)) | ||
if getattr(y, 'ndim', 1) != 1: | ||
return is_label_indicator_matrix(y) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do you have a test case in mind for this modification? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The point is to allow the outer sequence to be an array, but it can't be an array of size 1. Did I not test it? Ah well. |
||
return ((isinstance(y[0], Sequence) and not isinstance(y[0], string_types)) | ||
or isinstance(y[0], np.ndarray)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
import inspect | ||
import pkgutil | ||
|
||
import numpy as np | ||
import scipy as sp | ||
from functools import wraps | ||
try: | ||
|
@@ -97,6 +98,25 @@ def assert_raise_message(exception, message, function, *args, **kwargs): | |
assert_in(message, error_message) | ||
|
||
|
||
def assert_sequences_equal(first, second, err_msg=''):
    """Asserts equality of two sequences of sequences

    This compares multilabel targets irrespective of the sequence types.
    It is necessary because sequence types vary, `assert_array_equal` may
    misinterpret some formats as 2-dimensional.

    Each pair of inner sequences is compared through ``np.unique``, i.e. as
    a set: the order of labels within a sample and repeated labels are
    ignored. The order of the samples themselves is significant.

    Parameters
    ----------
    first, second : iterables of sequences
        The multilabel targets to compare. One-shot iterables such as
        generators are accepted; they are consumed by this call.

    err_msg : string, optional (default='')
        Extra text appended, on its own line, to the failure message.

    Raises
    ------
    AssertionError
        If the number of samples differs, or if any pair of samples does
        not contain the same set of labels.
    """
    # TODO: first assert args are valid sequences of sequences
    # Materialise the outer containers so that iterators and generators can
    # be measured with ``len`` before being walked pairwise.
    first = list(first)
    second = list(second)
    if err_msg:
        err_msg = '\n' + err_msg
    # Explicit field indices keep ``str.format`` usable on interpreters that
    # lack auto-numbered ``{}`` fields; the rendered messages are unchanged.
    assert_equal(len(first), len(second),
                 'Sequence of sequence lengths do not match.'
                 '{0}'.format(err_msg))
    for i, (first_el, second_el) in enumerate(zip(first, second)):
        # np.unique sorts and de-duplicates, giving a canonical form that is
        # independent of the container type of each sample.
        assert_array_equal(np.unique(first_el), np.unique(second_el),
                           'In sequence of sequence element {0}'
                           '{1}'.format(i, err_msg))
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that you mean
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're quite right. Fixed. |
||
|
||
|
||
def fake_mldata(columns_dict, dataname, matfile, ordering=None): | ||
"""Create a fake mldata data set. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I prefer (syntactically) a list comprehension:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, I think that it would be better if we returned an array, with dtype=object, instead of a list. A lot of code expects to find arrays, for instance testing the shape of the object, or applying fancy indexing. Besides, having the return type of a method vary can be a good recipe for later bugs.