[WIP] Fix gradient boosting overflow by chenhe95 · Pull Request #7959 · scikit-learn/scikit-learn

[WIP] Fix gradient boosting overflow #7959


Closed

Conversation

@chenhe95 (Contributor) commented Dec 1, 2016

Reference Issue

Fix #7717

What does this implement/fix? Explain your changes.

Before, the code used == to compare float values and could divide by a denominator that is effectively zero (~1e-309) without being exactly zero, which caused an overflow:

        if denominator == 0:
            tree.value[leaf, 0, 0] = 0.0
        else:
            tree.value[leaf, 0, 0] = numerator / denominator

Now I have changed it to:

        if np.isclose(denominator, 0.0):
            tree.value[leaf, 0, 0] = 0.0
        else:
            tree.value[leaf, 0, 0] = numerator / denominator
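
For illustration, a minimal reproduction of the failure mode the old check misses (the values are illustrative):

    import numpy as np

    # A subnormal denominator is nonzero, so the old `== 0` guard lets it through,
    # but the division still exceeds the float64 range.
    denominator = np.float64(1e-309)
    numerator = np.float64(100.0)
    print(denominator == 0)         # False -- the old guard does not trigger
    print(numerator / denominator)  # inf, with a RuntimeWarning about overflow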

chenhe95 and others added 2 commits December 1, 2016 10:33
…scikit-learn#7838)

* initial commit for return_std

* initial commit for return_std

* adding tests, examples, ARD predict_std

* adding tests, examples, ARD predict_std

* a smidge more documentation

* a smidge more documentation

* Missed a few PEP8 issues

* Changing predict_std to return_std scikit-learn#1

* Changing predict_std to return_std scikit-learn#2

* Changing predict_std to return_std scikit-learn#3

* Changing predict_std to return_std final

* adding better plots via polynomial regression

* trying to fix flake error

* fix to ARD plotting issue

* fixing some flakes

* Two blank lines part 1

* Two blank lines part 2

* More newlines!

* Even more newlines

* adding info to the doc string for the two plot files

* Rephrasing "polynomial" for Bayesian Ridge Regression

* Updating "polynomia" for ARD

* Adding more formal references

* Another asked-for improvement to doc string.

* Fixing flake8 errors

* Cleaning up the tests a smidge.

* A few more flakes

* requested fixes from Andy

* Mini bug fix

* Final pep8 fix

* pep8 fix round 2

* Fix beta_ to alpha_ in the comments
@lesteve (Member) commented Dec 1, 2016

np.isclose was added in numpy 1.7 and we test with numpy 1.6.2 in our build matrix, which is why Travis is failing. Not sure of the best way to tackle this; maybe compare abs(denominator) to some small value.
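
A sketch of that suggestion applied to the diff above (the threshold value is illustrative):

        if abs(denominator) < 1e-150:
            tree.value[leaf, 0, 0] = 0.0
        else:
            tree.value[leaf, 0, 0] = numerator / denominator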

@lesteve (Member) commented Dec 1, 2016

There are other places in the same file that check denominator == 0.; do they need to be changed as well? Maybe check across the whole scikit-learn codebase while you are at it.

It would be great to add tests for this "close to 0" behaviour; not sure how easy that is.

@chenhe95 (Contributor, Author) commented Dec 1, 2016

Thank you for the feedback. I'll take a look at the other places where floats are compared with ==.
Would anyone happen to know whether scikit-learn has a built-in equivalent of np.isclose()?
Maybe I could add one in utils or somewhere; it would be a simple boolean comparison of abs(x - y) < epsilon.
From a code design / readability perspective, I do not like repeating abs(x - y) < epsilon at every place where == is used now.
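
A sketch of what such a helper could look like (the name _isclose and its placement in utils are hypothetical):

    import numpy as np

    def _isclose(x, y, atol=1e-150):
        # Hypothetical helper: element-wise |x - y| < atol,
        # works for scalars and arrays alike.
        return np.abs(np.asarray(x) - np.asarray(y)) < atol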

@chenhe95 chenhe95 changed the title [MRG] Fix gradient boosting overflow [WIP] Fix gradient boosting overflow Dec 1, 2016
@chenhe95 (Contributor, Author) commented Dec 1, 2016

I was thinking of copying isclose() over from numpy and including the numpy license:

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

I am not sure whether that is acceptable or a good idea, but I feel that a built-in, universal, standardized way of comparing float64 scalars and matrices is very good to have.

@chenhe95 (Contributor, Author) commented Dec 2, 2016

I have been reading up on this, and it seems that isclose() used to be in scikit-learn but was removed (#4864, #7353) because it was causing a bug in ROC/AUC, where very small differences below the threshold mattered. Instead of adding warnings to the documentation, it was decided that removing the function was better. I disagree with that decision from a software-engineering modularity / separation-of-concerns point of view, but if the community does not want to support this function, I can't really complain.

I suppose abs(x - y) < epsilon will have to do for every case of float comparison.

Also, speaking of float comparisons, I found an issue with various problems around assert_equals on floats: #4400

@lesteve (Member) commented Dec 2, 2016

> I have been reading up on it and it seems that isclose() used to be there but was removed
> #4864, #7353

Interesting, I did not know that. I'll take a closer look.

abs(x - y) < epsilon seems good enough for now. Another option would be a backport of math.isclose for Python 2.
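
For reference, a minimal backport of math.isclose's Python 3.5 semantics (ignoring the inf/nan special cases it also handles):

    def isclose(a, b, rel_tol=1e-09, abs_tol=0.0):
        # True when a and b are within a relative tolerance of rel_tol
        # or an absolute tolerance of abs_tol, whichever is larger.
        return abs(a - b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol)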

@amueller (Member) commented Dec 2, 2016

@chenhe95 The function was removed because it was no longer needed: there was a bug in the ROC curve code, and that was the only place the function was used. When that use was removed, so was the function. Feel free to add the backport back if there is a new need.

@amueller (Member) commented Dec 2, 2016

Wow, apparently I don't know a lot about floating-point numbers. I was surprised that you can represent x but not 1/x, but I guess the reason is the exponent bias:
https://en.wikipedia.org/wiki/Double-precision_floating-point_format#Exponent_encoding

Shouldn't we use something close to 1/np.finfo("double").max as the epsilon here?
Maybe that's a bit extreme, but the code should run fine with 1e-300.
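
The asymmetry is easy to check: subnormal numbers extend the representable range below the smallest normal positive value, but there is no matching headroom above the maximum:

    import numpy as np

    info = np.finfo(np.float64)
    print(info.max)         # ~1.798e+308
    print(info.tiny)        # ~2.225e-308, the smallest *normal* positive value
    x = np.float64(1e-309)  # subnormal: representable and nonzero
    print(x == 0)           # False
    print(1.0 / x)          # inf -- ~1e309 exceeds the float64 maximum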

@chenhe95 (Contributor, Author) commented Dec 2, 2016

It's kind of tricky.
If used in the context of numerator / denominator where the denominator is 1e-309 and the numerator is 100, the division can still overflow, so I felt we should use something relatively safe (1e-150?) or also check the exponent of the numerator.
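
With a 1e-150 cutoff, overflow would require |numerator| to exceed roughly np.finfo(np.float64).max * 1e-150 ≈ 1.8e158, far beyond anything the boosting update produces. A quick illustrative check:

    import numpy as np

    den = np.float64(1e-150)        # right at the proposed cutoff
    print(np.float64(100.0) / den)  # 1e+152 -- finite, no overflow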

@amueller (Member) commented Dec 2, 2016

yeah 1e-150 is fine :)

@lesteve (Member) commented Dec 2, 2016

> yeah 1e-150 is fine :)

Should this not depend on the data used? Maybe it's not always float64.

@chenhe95 (Contributor, Author) commented Dec 2, 2016

@lesteve I think it is fine because:
Assume that denominator is float32 and equals 1 / np.finfo("float32").max.
Assume we have atol = np.float64(1e-150).
Then clearly np.isclose(denominator, 0., rtol=0., atol=atol) will return False, and numerator / denominator is computed. However, no overflow will occur because the result of numerator / denominator is promoted to float64.

edit: If the inputs are matrices, np.isclose() is applied element-wise, which is another reason I really like the function.

Here is also a link to the numpy code I am going to port in
https://github.com/numpy/numpy/blob/v1.11.0/numpy/core/numeric.py#L2375-L2474
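
A quick check of that argument (using an explicit np.float64 numerator, since promotion rules for plain Python scalars can differ):

    import numpy as np

    den = np.float32(1.0) / np.finfo(np.float32).max  # ~2.94e-39, a subnormal float32
    atol = np.float64(1e-150)
    print(np.isclose(den, 0., rtol=0., atol=atol))    # False -> we go on to divide
    out = np.float64(100.0) / den                     # result is promoted to float64
    print(out, out.dtype)                             # ~3.40e+40 float64 -- no overflow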

@chenhe95 (Contributor, Author) commented Dec 2, 2016

It seems the merge failed when updating my branch to master.
The PR is continued in #7970.
Apologies for any inconvenience.

@chenhe95 chenhe95 closed this Dec 2, 2016
@chenhe95 chenhe95 deleted the GradientBoostingOverflow branch December 2, 2016 22:11