-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG + 1] Issue#8062: JoblibException thrown when passing "fit_params={'sample_… #8068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…weights': weights}" to RandomizedSearchCV with RandomForestClassifier
please add a test |
I added a test to test_forest.py:check_class_weights(), I considered creating a separate function for it, but seemed excessive, at the same time that small test seemed out of place in there, so I placed it at the end, with a longer comment. The test itself just recreates the conditions mentioned in the bug report, and tries to trigger the exception raised when Python 2.x lists are used that have no copy() method. Doing a type check on sample_weights seemed unidiomatic. And doing deeper testing to make sure even Python 3.x lists weren't used seemed pointless, since when those are used the feature_importances_ satisfied the almost_equal assertion. The issue I mentioned earlier with the full make rebuild just happens on Python 2.x with no code changes, the rebuild on Python 3.x works fine. |
# method. On success the list should be automatically converted. | ||
clf = ForestClassifier() | ||
sample_weight = [1.] * len(iris.data) | ||
clf.fit(iris.data, iris.target, sample_weight=sample_weight) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to add a test that fails on master to show that your PR is actually the right fix. AFAICT this test does not fail on master.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmmm this seems a Python 2 only problem, my bad, this test is fine.
The AppVeyor failure is unrelated (we have an issue for it already IIRC), so LGTM. |
@@ -947,6 +947,13 @@ def check_class_weights(name): | |||
clf2.fit(iris.data, iris.target, sample_weight) | |||
assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_) | |||
|
|||
# When sample_weights is an unsupported array type, it checks if it raises | |||
# an exception. With Python 2.x lists it complains there is no copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment is a bit odd as not it doesn't fail. Maybe say "used to fail on python2 for lists" or something. This checks if it passes, not if it raises.
apart from somewhat cryptic comment LGTM |
LTGM. Thx. Merging. |
…={'sample_… (scikit-learn#8068) * Issue#8062: JoblibException thrown when passing "fit_params={'sample_weights': weights}" to RandomizedSearchCV with RandomForestClassifier * Added test for issues scikit-learn#8068 and scikit-learn#8064. * Clean up with pyflakes. * Changed cryptic comment.
…={'sample_… (scikit-learn#8068) * Issue#8062: JoblibException thrown when passing "fit_params={'sample_weights': weights}" to RandomizedSearchCV with RandomForestClassifier * Added test for issues scikit-learn#8068 and scikit-learn#8064. * Clean up with pyflakes. * Changed cryptic comment.
…={'sample_… (scikit-learn#8068) * Issue#8062: JoblibException thrown when passing "fit_params={'sample_weights': weights}" to RandomizedSearchCV with RandomForestClassifier * Added test for issues scikit-learn#8068 and scikit-learn#8064. * Clean up with pyflakes. * Changed cryptic comment.
…={'sample_… (scikit-learn#8068) * Issue#8062: JoblibException thrown when passing "fit_params={'sample_weights': weights}" to RandomizedSearchCV with RandomForestClassifier * Added test for issues scikit-learn#8068 and scikit-learn#8064. * Clean up with pyflakes. * Changed cryptic comment.
…={'sample_… (scikit-learn#8068) * Issue#8062: JoblibException thrown when passing "fit_params={'sample_weights': weights}" to RandomizedSearchCV with RandomForestClassifier * Added test for issues scikit-learn#8068 and scikit-learn#8064. * Clean up with pyflakes. * Changed cryptic comment.
…weights': weights}" to RandomizedSearchCV with RandomForestClassifier
Reference Issue
Fixes #8062
What does this implement/fix? Explain your changes.
When the sample_weight parameter of the fit() method of a RandomForestClassifier is a python 2.x list, it raises an exception, because on a later call to _parallel_build_trees(), the copy() method is used, which it exists in python 3.x lists, panda series and nparrays, it doesn't exist however on python 2.x lists. To be clear, a python 3.x list here would also be inappropriate because even though operations like multiplication on it would succeed, the results would be different from the ones obtained from an nparray. It does explain however how the exception only happened with python 2.x lists. Another detail from the bug report was that the exception didn't happen when using a ExtraTreesClassifier. This was so because the bootstrap parameter on the ExtraTreeClassifier is False by default, which later in the _parallel_build_trees() call will divert the logic flow away from the copy() method call. If we set the bootstrap parameter on the ExtraTreesClassigier constructor, the behavior becomes the same, and the same exception is raised.
As suggested by @jnothman, I used the check_array utility to transform the original sample_weight array into a nparray, on the BaseForest class. With a little pre-check to see if the sample_weight parameter is None, since check_array raises an exception otherwise. And an additional parameter to make sure the array created isn't 2d.
I added no tests to check for this issue.A test was added.
I tried the full make rebuild with no changes to the source, and it failed several tests. So I ran the full make again with my changes and just corrected what was different. So I'm not sure if all relevant tests were done.
pyflakes and pep8 showed no irregularities.
coverage for the module changed was at 65%, I'm not sure how relevant this is for such a small change, but there you have it.