-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG] Fixes tree and forest classification for non-numeric multi-target #11458
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Any update on this PR? |
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS) | ||
@pytest.mark.parametrize('oob_score', (True, False)) | ||
def test_multi_target(name, oob_score): | ||
check_multi_target(name, oob_score) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any reason for not having the body of the check_multi_target
function directly here?
sklearn/tree/tests/test_tree.py
Outdated
|
||
@pytest.mark.parametrize('name', CLF_TREES) | ||
def test_multi_target(name): | ||
check_multi_target(name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here.
sklearn/tree/tree.py
Outdated
|
||
return predictions | ||
return np.array(predictions).T |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would try to figure the dtype of the array, right? how much is it slower than the status quo?
You also need to rebase/merge master, you've got conflicts. Other than that, I'm really not sure if this is a good idea. How many other estimators do we have that support string outputs? I suppose the recommended way is to convert the values before feeding them to estimators. I may be wrong. |
We support string targets where 1d (i.e. single target). I'm not entirely against supporting strong labels in multi output, but it should be by making sure that all estimators with multi output multiclass support, and any metrics, support this case. Let alone the case of mixed numeric and string data. At the moment I can't see that we test multi output multiclass in common tests at all. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does your implementation do with a mix of string and numeric targets?
e215890
to
b26e943
Compare
5c5ecff
to
fcd597a
Compare
I've refactored the tests and the code a bit. Also PR is rebased to master.
Returns:
Doing the same with a regressor would fail as targets need to be numerical. Training target array is upcast to one dtype and we will get an array with the same dtype back from A way to support a real mix of dtypes would be with structured array, but I don't know if we really want to do that? |
# Make multi-target. | ||
ys = np.hstack([y, y]) | ||
|
||
# Try to fix and predict. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix -> fit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed :)
y = np.array(['foo' if v else 'bar' for v in y]).reshape((y.shape[0], 1)) | ||
|
||
# Make multi-target. | ||
ys = np.hstack([y, y]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try with a string and a numerical column just to be on the safe side in the test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
This looks good to me. Since Joel mentioned it, could you please kindly try adding the same test on common tests ( |
Thaks for the quick responese @adrinjalali! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would support doing common tests in a separate PR, ideally remembering to remove these tests as redundant.
Thank @adrinjalali, @mitar and @jnothman. |
…scikit-learn#11458) * Fixes tree and forest classification for non-numeric multi-target. Fixes scikit-learn#11451. * Renaming test functions, adding dtype to predictions array in tree.py. * Fixing flake8 issue. * Adding ignore warning to test_forest.py. * Switching to iris data for tests.
…i-target (scikit-learn#11458)" This reverts commit f95ffe6.
…i-target (scikit-learn#11458)" This reverts commit f95ffe6.
…scikit-learn#11458) * Fixes tree and forest classification for non-numeric multi-target. Fixes scikit-learn#11451. * Renaming test functions, adding dtype to predictions array in tree.py. * Fixing flake8 issue. * Adding ignore warning to test_forest.py. * Switching to iris data for tests.
Fixes #11451.
This fixes the issue that trees and forests cannot classify (but they can fit) non-numeric targets, when there are multiple targets.