Pprett/gradient boosting by glouppe · Pull Request #6 · pprett/scikit-learn

Merged: 7 commits merged into pprett:gradient_boosting on Mar 20, 2012

Conversation

@glouppe commented Mar 19, 2012

This is my first bunch of commits regarding your PR.

I really like how you managed to remove the "terminal" mechanisms from the Tree code :)

My changes are the following:

  • Moved _compute_feature_importances into Tree
  • Moved _build_tree into Tree
  • Used DTYPE instead of float64
  • Cosmetic fixes (cosmits) and PEP8

Most of those do not actually concern the boosting module. I still have to review the gradient_boosting.py file in more depth (later today or tomorrow).

@pprett (Owner) commented Mar 19, 2012

@glouppe some of the tests fail due to numerical issues (an aftermath of the dtype change). I fixed those, but I noticed a performance regression for the following benchmark::

import numpy as np
from sklearn import datasets
from sklearn.ensemble import gradient_boosting

X, y = datasets.make_hastie_10_2(n_samples=12000, random_state=1)
X = X.astype(np.float32)

gbrt = gradient_boosting.GradientBoostingClassifier(
    n_estimators=250, min_samples_split=5, max_depth=1,
    learn_rate=1.0, random_state=0)
%timeit gbrt.fit(X, y)

it goes from::

1 loops, best of 3: 1.32 s per loop

to::

1 loops, best of 3: 1.97 s per loop

@pprett (Owner) commented Mar 19, 2012

hmm... I think I hunted it down::

379       250       768998   3076.0     42.2              residual = loss.negative_gradient(y, y_pred, k=k)

This is 4 times the usual timing, due to y and y_pred having different dtypes.
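
(For reference, a sketch of how per-line timings like the one above can be collected with the line_profiler package; the class and method names are assumptions, and the boosting loop may live in a different method on this branch)::

import numpy as np
from line_profiler import LineProfiler
from sklearn import datasets
from sklearn.ensemble import gradient_boosting

X, y = datasets.make_hastie_10_2(n_samples=12000, random_state=1)
X = X.astype(np.float32)

gbrt = gradient_boosting.GradientBoostingClassifier(
    n_estimators=250, min_samples_split=5, max_depth=1,
    learn_rate=1.0, random_state=0)

profiler = LineProfiler()
# profile the method that contains the boosting loop
profiler.add_function(gradient_boosting.BaseGradientBoosting.fit)
profiler.runcall(gbrt.fit, X, y)
profiler.print_stats()  # prints Hits / Time / Per Hit / % Time for each line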

Review comment on the diff::

The error of the (best) split.
For leaves `init_error == best_error`.

-init_error : np.ndarray of float64
+init_error : np.ndarray of DTYPE

@pprett (Owner) commented:
Why should init_error or best_error have type DTYPE, which is the dtype of the data array? Either use np.float32 or np.float64. I tend to use np.float64 whenever possible (i.e. when memory consumption is not an issue).
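
(A minimal sketch of the suggestion, with illustrative names rather than the actual Tree internals: the error arrays are allocated as float64 even though the data array X stays in DTYPE/float32)::

import numpy as np

DTYPE = np.float32                       # dtype of the data array X
X = np.zeros((1000, 10), dtype=DTYPE)    # memory-sensitive: keep as DTYPE

n_nodes = 31
init_error = np.zeros(n_nodes, dtype=np.float64)  # precision-sensitive: float64
best_error = np.zeros(n_nodes, dtype=np.float64)  # independent of X's dtype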

@pprett (Owner) commented Mar 19, 2012

Wow... it seems that 32bit floating point arithmetic in numpy is substantially slower than 64bit arithmetic::

%timeit bd.negative_gradient(y, y_pred)
1000 loops, best of 3: 546 us per loop

vs 32bit::

%timeit bd.negative_gradient(y_float32, y_pred_float32)
100 loops, best of 3: 3.01 ms per loop

It seems that np.exp is the one to blame.
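
(A quick way to reproduce the observation; timings are machine-dependent, and the 12000-sample size simply mirrors the benchmark above)::

import numpy as np
from timeit import timeit

rng = np.random.RandomState(0)
pred64 = rng.randn(12000)             # float64
pred32 = pred64.astype(np.float32)    # same values, float32

t64 = timeit(lambda: np.exp(pred64), number=1000)
t32 = timeit(lambda: np.exp(pred32), number=1000)
print("float64: %.4fs  float32: %.4fs  ratio: %.1fx" % (t64, t32, t32 / t64))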

@glouppe (Author) commented Mar 19, 2012

Wow, that's huge. I was not aware of this. Actually, my machine is 32-bit, which is why I like having the possibility of not using float64. I will have a deeper look at it tomorrow. I'll revert my changes if I come to no good solution.

@pprett (Owner) commented Mar 19, 2012

It might be slower on 64bit machines, but a 6-fold increase is too large. numpy has an npy_expf function that operates on float32, but I don't know whether it is exposed through the numpy API... I'll keep you posted.


@pprett (Owner) commented Mar 19, 2012

Gilles, I just checked the other (regression) models in sklearn; it seems that only tree and ensemble use 32bit floating point for the target values. SVM and Lasso/ElasticNet/SGDRegressor explicitly convert to 64bit. I'd rather use 64bit for tree and ensemble too - this has the advantage that results are more stable (I remember we use np.mean(y) somewhere in our code, which might pose an underflow problem). AFAIK we chose 32bit because of memory consumption, which is only an issue for X, not for y.
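
(A sketch of the convention being described; _prepare_targets is a hypothetical helper, not sklearn API: cast y to float64 at fit time while X keeps whatever dtype the tree code expects for memory reasons)::

import numpy as np

def _prepare_targets(y):
    # y is only n_samples long, so the extra memory of float64 is negligible,
    # and statistics such as np.mean(y) are computed with full precision.
    return np.ascontiguousarray(y, dtype=np.float64)

y32 = np.asarray([0.0, 1.0, 1.0, 0.0], dtype=np.float32)
y = _prepare_targets(y32)
print(y.dtype)  # float64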

@glouppe (Author) commented Mar 19, 2012

Okay, I agree. I'll revert my changes tomorrow.


@glouppe pushed a revert commit ("This reverts commit 3509e16.") with conflicts in:

	sklearn/ensemble/gradient_boosting.py
	sklearn/tree/tree.py

@glouppe (Author) commented Mar 20, 2012

I just pushed the revert commit.

@pprett merged commit cc2bab9 into pprett:gradient_boosting on Mar 20, 2012
@pprett (Owner) commented Mar 20, 2012

@glouppe thanks - I updated whats_new.rst and merged

Commits referencing this pull request:

  • Jul 25, 2013: nitpick fixes, pep8 and fix math equations
  • Mar 18, 2014: Revised text classification chapter