[MRG + 1] Fix gradient boosting overflow and various other float comparison on == by chenhe95 · Pull Request #7970 · scikit-learn/scikit-learn · GitHub

[MRG + 1] Fix gradient boosting overflow and various other float comparison on == #7970

Merged: 10 commits into scikit-learn:master on Mar 7, 2017

Conversation

@chenhe95 (Contributor) commented Dec 2, 2016

Reference Issue

Fix #7717

What does this implement/fix? Explain your changes.

Before, the code used == to compare float values and divided by "zero" (actually a value around ~10e-309), which caused an overflow.

        if denominator == 0:
            tree.value[leaf, 0, 0] = 0.0
        else:
            tree.value[leaf, 0, 0] = numerator / denominator

Now I made it so that it's

        if np.isclose(denominator, 0.0):
            tree.value[leaf, 0, 0] = 0.0
        else:
            tree.value[leaf, 0, 0] = numerator / denominator

There are several other instances of this pattern that may cause an error, and I want to address those later on as well.

In addition, this brings back the isclose() function, which is a standardized way of checking whether two float scalars or arrays of arbitrary size are close to each other within a tolerance.
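For illustration, here is a minimal sketch of the failure mode with made-up values (the reported denominator was around 10e-309):

import numpy as np

numerator = np.float64(3.0)
denominator = np.float64(1e-309)   # subnormal, but not exactly zero

print(denominator == 0.0)          # False, so the old guard is skipped
print(numerator / denominator)     # inf, with an overflow RuntimeWarning

# An absolute-tolerance check catches the near-zero case instead:
print(np.isclose(denominator, 0.0, rtol=0.0, atol=1e-150))  # True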

@chenhe95 (Contributor, Author) commented Dec 2, 2016

I am wondering how float comparisons were handled before, when there was no isclose() function.
Was it just something like (x - y) < 1e-15 on the spot?

@amueller (Member) commented Dec 2, 2016

@chenhe95 yeah usually you want "close to zero" so you can always do norm < eps
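For instance, a small sketch of that idiom (the array and tolerance here are arbitrary):

import numpy as np

eps = np.finfo(np.float64).eps       # one common choice of tolerance
residual = np.array([1e-17, -2e-18])
if np.linalg.norm(residual) < eps:
    print("residual treated as zero")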

@@ -511,7 +512,7 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
         numerator = np.sum(sample_weight * residual)
         denominator = np.sum(sample_weight * (y - residual) * (1 - y + residual))

-        if denominator == 0.0:
+        if isclose(denominator, 0., rtol=0., atol=np.float64(1e-150)):
Review comment (Member):
Silly question, but is that a scalar? I feel the code denominator < np.float64(1e-150) is easier to understand.

Review comment (Contributor, Author):
It turned out a lot messier than I had anticipated.
I originally planned to just use isclose(denominator, 0.), which seemed pretty clean, but the default absolute tolerance is only 1e-8.
I suppose it is good to just do abs(denominator) < 1e-150 here.
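To make the trade-off concrete, a quick sketch (the 1e-9 value is arbitrary):

import numpy as np

denominator = 1e-9

# The default tolerances of isclose are far too loose for this use case:
print(np.isclose(denominator, 0.0))                         # True  (atol defaults to 1e-8)
print(np.isclose(denominator, 0.0, rtol=0.0, atol=1e-150))  # False

# The plain comparison suggested above:
print(abs(denominator) < 1e-150)                            # False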

@chenhe95 (Contributor, Author) commented Dec 3, 2016

I also went and did some pep8 housekeeping on the file gradient_boosting.py.
The only two pep8 violations that were not fixed were:

sklearn/ensemble/gradient_boosting.py:1035:20: E712 comparison to True should be 'if cond is True:' or 'if cond:'
sklearn/ensemble/gradient_boosting.py:1446:80: E501 line too long (87 > 79 characters)

I wasn't quite sure what to do about these two.

(from the docstring, the E501 case)
    estimators_ : ndarray of DecisionTreeRegressor, shape = [n_estimators, ``loss_.K``]
        The collection of fitted sub-estimators. ``loss_.K`` is 1 for binary
        classification, otherwise n_classes.

(from the code, the E712 case)
        if presort == True:
            if issparse(X):
                raise ValueError(
                    "Presorting is not supported for sparse matrices.")
            else:
                X_idx_sorted = np.asfortranarray(np.argsort(X, axis=0),
                                                 dtype=np.int32)

And here flake8 suggested doing just if presort: but what if presort was a non-empty list and we only wanted the if statement to pass if presort is literally True?

Let me know what you guys think
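For reference, a tiny sketch of the difference flake8 is glossing over (the "auto" value is just an example of a truthy non-boolean):

presort = "auto"        # truthy, but not the boolean True

print(bool(presort))    # True  -> `if presort:` would take the branch
print(presort is True)  # False -> `if presort is True:` would not
print(presort == True)  # False here, but flake8 flags this spelling as E712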

@raghavrv (Member) left a review comment:
Apart from reverting all PEP8 changes unrelated to the PR, this LGTM... Thx!

         diff = (y.take(terminal_region, axis=0) -
                 pred.take(terminal_region, axis=0))
-        tree.value[leaf, 0, 0] = _weighted_percentile(diff, sample_weight, percentile=50)
+        tree.value[leaf, 0, 0] = \
Review comment (Member):
Could you avoid the backslash and rather do

_weighted_percentile(diff,
                     sample_weight...)

@@ -375,7 +380,8 @@ def negative_gradient(self, y, pred, sample_weight=None, **kargs):
         if sample_weight is None:
             gamma = stats.scoreatpercentile(np.abs(diff), self.alpha * 100)
         else:
-            gamma = _weighted_percentile(np.abs(diff), sample_weight, self.alpha * 100)
+            gamma = _weighted_percentile(
Review comment (Member):
Why are you cleaning up flake8 issues for code that is not modified in this PR... It creates merge conflicts with other PRs. In general we try to enforce flake8 only for the code that is being modified in the PR... +1 for reverting these changes...

@@ -634,7 +646,8 @@ def _update_terminal_region(self, tree, terminal_regions, leaf, X, y,
         numerator = np.sum(y_ * sample_weight * np.exp(-y_ * pred))
         denominator = np.sum(sample_weight * np.exp(-y_ * pred))

-        if denominator == 0.0:
+        # prevents overflow and division by zero
+        if abs(denominator) < 1e-150:
Review comment (Member):
Instead of 1e-150 you could use np.finfo(np.float32).eps... There is precedent for it here

Review comment (Contributor, Author):
Hmm.. I am not sure. @raghavrv do you have a strong preference for np.finfo(np.float32).eps, or do you think 1e-150 is still fine? I personally prefer 1e-150 because the case where this algorithm was failing was when the denominator was around 1e-309, so I felt that 1e-150 was appropriate since its exponent is about half of 300.

>>> np.finfo(np.float).eps
2.2204460492503131e-16
>>> np.finfo(np.double).eps
2.2204460492503131e-16

It's just that I am kind of worried that those values are too large compared to 1e-150 and I'm not sure if it will cause any rounding errors.

Review comment (Member):
How about .tiny then?

Review comment (Contributor, Author):
>>> np.finfo(np.float32).tiny
1.1754944e-38

I suppose this seems okay!

@raghavrv (Member) commented Dec 8, 2016:
I was suggesting np.finfo(np.double).tiny, which is close to e-300...

Review comment (Member):
>>> np.finfo(np.double).tiny
2.2250738585072014e-308
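For reference, a sketch printing the candidate tolerances discussed so far side by side:

import numpy as np

print(np.finfo(np.float64).eps)    # ~2.22e-16   (machine epsilon)
print(np.finfo(np.float32).tiny)   # ~1.18e-38   (smallest normal float32)
print(np.finfo(np.float64).tiny)   # ~2.23e-308  (smallest normal float64)
print(1e-150)                      # the hard-coded threshold proposed in this PR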

@raghavrv (Member) commented Dec 8, 2016:
The thing is, if your system is 32 bit, denominator (which is the result of np.sum) can only be as low as np.finfo(np.float<arch>).tiny, IIUC...

Review comment (Contributor, Author):
I actually wasn't sure about this, because 10 to the power of -308 seemed a bit too small as well, and the division can easily overflow even for numerators that are not very large:

>>> 4/np.finfo(np.double).tiny
__main__:1: RuntimeWarning: overflow encountered in double_scalars
inf

Which was originally why I had 1e-150
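Restated as a small sketch (assuming 64-bit floats):

import numpy as np

tiny = np.finfo(np.double).tiny                 # ~2.2e-308

print(np.float64(4.0) / tiny)                   # inf, plus an overflow RuntimeWarning
print(np.float64(4.0) / np.float64(1e-150))     # 4e+150, well within float64 range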

@@ -970,7 +985,8 @@ def fit(self, X, y, sample_weight=None, monitor=None):
         self._clear_state()

         # Check input
-        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'], dtype=DTYPE)
+        X, y = check_X_y(
+            X, y, accept_sparse=['csr', 'csc', 'coo'], dtype=DTYPE)
Review comment (Member):
+1 for reverting all pep8 changes unrelated to the PR... :)

Review comment (Member):
I'm ambivalent

@jnothman (Member) commented Dec 7, 2016

Is this meant to be MRG, not WIP?

@chenhe95 (Contributor, Author) commented Dec 7, 2016

@jnothman I think I am going to revert the flake8 fixes and then set the title to MRG.

@chenhe95 changed the title from "[WIP] Fix gradient boosting overflow and various other float comparison on ==" to "[MRG] Fix gradient boosting overflow and various other float comparison on ==" on Dec 8, 2016
@chenhe95 (Contributor, Author) commented Dec 8, 2016

Thanks for the feedback everyone! I have reverted the flake8 things. Let me know how it looks!

@raghavrv (Member) left a review comment:
Thanks for the PR!

@raghavrv (Member) commented Dec 8, 2016

Could you make it np.finfo(np.double).tiny instead?

@raghavrv (Member) commented Dec 8, 2016

Sorry I missed your comment here...

@raghavrv (Member) commented Dec 8, 2016

You make a good point about overflow, maybe you could just use 1e-150 then. Ping @jnothman or @amueller for advice!

@amueller (Member) commented Dec 9, 2016

1e-150 is fine. np.finfo(np.double).tiny is too small.

@chenhe95 (Contributor, Author) commented Dec 9, 2016

Okay, it's reverted back to 1e-150

@chenhe95 (Contributor, Author) commented:

AppVeyor is claiming that the log is empty and failing.

@jnothman (Member) commented:

LGTM, thanks

@jnothman (Member) commented:

Could you please add a bug fix entry to whats_new.rst? Thanks

@amueller changed the title from "[MRG] Fix gradient boosting overflow and various other float comparison on ==" to "[MRG + 1] Fix gradient boosting overflow and various other float comparison on ==" on Dec 14, 2016
@amueller (Member) commented:

Hm can we add tests that no warning is raised? Or is that too tricky? Otherwise lgtm.

@chenhe95 (Contributor, Author) commented:

@jnothman I have added an entry to whats_new.rst; here's how it looks.

@amueller I don't really understand what testing that no warning is raised means, but if it's not too complicated I can certainly add it

@amueller (Member) commented:

Oh, there's actually a ValueError in the issue. You should add a test to ensure this ValueError doesn't happen anymore after your fix.

@chenhe95 (Contributor, Author) commented Dec 14, 2016

I am unsure if the ValueError is easily reproducible since the original reporter of the error said

I am unable to reproduce the behaviour since it happens in a heavily parallelized randomized search

But I am fairly confident that this will fix the ValueError because the float will not be compared to 0.0 using == anymore

@amueller (Member) commented:

The point of adding a test is to make sure we don't accidentally reintroduce the same bug down the road.

@chenhe95 (Contributor, Author) commented:

Okay, I'll see what I can do to come up with some test cases.

@jnothman (Member) commented Dec 27, 2016

Any luck on this, @chenhe95?

@chenhe95 (Contributor, Author) commented:

Hmm.. not really. The last few days of finals have been rough and I have been working on my other CountFeaturizer pull request.
Tomorrow I will have some time and I will see if I can come up with any good test cases, but I think that I have a lot of documentation to go through.

@jnothman (Member) commented:

@amueller, this is the sort of thing that I suspect we can only reasonably test by separating out a smaller private helper as a unit and testing that. I am inclined to merge the patch even if we can't build a test with ease.
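A rough sketch of what such a unit test could look like if the guard were factored out into a helper; the names here are hypothetical and not part of the PR:

import warnings

def _safe_leaf_value(numerator, denominator, eps=1e-150):
    # Hypothetical helper mirroring the guarded division introduced in this PR.
    if abs(denominator) < eps:
        return 0.0
    return numerator / denominator

def test_no_warning_for_near_zero_denominator():
    # Turn any warning (e.g. an overflow RuntimeWarning) into a test failure.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        value = _safe_leaf_value(3.0, 1e-309)
    assert value == 0.0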


@raghavrv (Member) commented Jan 5, 2017

(Travis is failing because of a flake8 issue....) Also can we merge as is? @jnothman @amueller

@jnothman (Member) commented Jan 6, 2017

Waiting for @amueller to voice his opinion on to what extent a test is necessary.

@amueller (Member) commented Mar 7, 2017

LGTM.

@amueller merged commit 919b4a8 into scikit-learn:master on Mar 7, 2017
@raghavrv (Member) commented Mar 7, 2017

Thanks @chenhe95!!

@Przemo10 mentioned this pull request Mar 17, 2017
herilalaina pushed a commit to herilalaina/scikit-learn that referenced this pull request Mar 26, 2017
…arison on == (scikit-learn#7970)

* reintroduced isclose() and flake8 fixes to fixes.py

* changed == 0.0 to isclose(...)

* example changes

* changed back to abs() < epsilon

* flake8 convention on file

* reverted flake8 fixes

* reverted flake8 fixes (2)

* np.finfo(np.float32).tiny instead of hard coded epsilon 1e-150

* reverted to 1e-150

* whats new modified
massich pushed a commit to massich/scikit-learn that referenced this pull request Apr 26, 2017

Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017

NelleV pushed a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017

paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017

maskani-moh pushed a commit to maskani-moh/scikit-learn that referenced this pull request Nov 15, 2017

jwjohnson314 pushed a commit to jwjohnson314/scikit-learn that referenced this pull request Dec 18, 2017
Successfully merging this pull request may close these issues.

possible overflow bug in gradient boosting