HistGradientBoostingClassifier slow in prediction mode · Issue #16429 · scikit-learn/scikit-learn

Open · SebastianBr opened this issue Feb 11, 2020 · 15 comments

@SebastianBr

While HistGradientBoostingClassifier is about 100 times faster than GradientBoostingClassifier when fitting the model, I found it to be very slow when predicting class probabilities, in my case about 100 times slower :-(

For example:
GradientBoostingClassifier: 3.2 min training on 1 million examples; 32 ms for 1000 predictions.
HistGradientBoostingClassifier: 7 s training; 1 s for 1000 predictions.
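
For reference, a minimal benchmark sketch of this kind of comparison (synthetic data and default hyperparameters, so the absolute numbers will not match the ones above):

```python
# Hypothetical reproduction sketch of the reported comparison (synthetic data,
# default hyperparameters, so absolute numbers will differ from the report).
import time

from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier, HistGradientBoostingClassifier

X, y = make_classification(n_samples=1_000_000, n_features=20, random_state=0)
X_small = X[:1000]  # small prediction batch, as in the report

for Estimator in (GradientBoostingClassifier, HistGradientBoostingClassifier):
    clf = Estimator(random_state=0)
    t0 = time.perf_counter()
    clf.fit(X, y)
    fit_time = time.perf_counter() - t0

    t0 = time.perf_counter()
    clf.predict_proba(X_small)
    predict_time = time.perf_counter() - t0
    print(f"{Estimator.__name__}: fit {fit_time:.1f} s, "
          f"predict_proba(1000) {predict_time * 1e3:.1f} ms")
```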

@ogrisel
Member
ogrisel commented Feb 12, 2020

Indeed the prediction latency was optimized for GradientBoostingClassifier and we haven't really done similar work for the newer HistGradientBoostingClassifier.

Just to clarify, do you predict on a batch of 1000 samples in a single numpy array, or do you call the predict method 1000 times with 1 sample at a time?
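
For context, a sketch of the two usage patterns being distinguished here (toy data, illustrative sizes):

```python
# Sketch contrasting batch prediction with per-sample prediction (toy data).
import numpy as np

from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = make_classification(n_samples=10_000, n_features=20, random_state=0)
clf = HistGradientBoostingClassifier(random_state=0).fit(X, y)
X_batch = X[:1000]

# Batch mode: a single call on a (1000, n_features) array.
proba_batch = clf.predict_proba(X_batch)

# One sample at a time: 1000 calls, each paying the full per-call overhead
# (input validation, thread spawning, ...).
proba_loop = np.vstack([clf.predict_proba(row.reshape(1, -1)) for row in X_batch])
```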

@ogrisel
Member
ogrisel commented Feb 12, 2020

We need to do a profiling session with py-spy with the --native flag to spot the performance bottleneck:

https://github.com/benfred/py-spy
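
For example, a profiling run could look like this (the benchmark script name is hypothetical):

```bash
pip install py-spy
# --native includes frames from compiled extensions (Cython/OpenMP),
# which is where the HistGradientBoosting prediction code spends its time.
py-spy record --native -o hgbdt_predict_profile.svg -- python predict_benchmark.py
```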

@ogrisel
Member
ogrisel commented Feb 12, 2020

ping @jeremiedbb

@NicolasHug
Member

Thanks for the report @SebastianBr . Do you observe a similar difference when predicting for 1 million samples instead?

The prediction of the HistGB is multi-threaded, but the regular GB isn't. With 1000 samples, the thread-spawning overhead might be too high. If you have time, I'd be interested to see what happens if you set OMP_NUM_THREADS to 1.

@ogrisel , just a side note, we parallelize the _raw_predict (i.e. decision_function), but not the decision_function_to_proba method which is sequential. There's room for improvement here, though I doubt this is the issue

Indeed the prediction latency was optimized for GradientBoostingClassifier

Curious, what are you referring to?

@SebastianBr
Author

Just to clarify, do you predict on a batch of 1000 samples in a single numpy array, or do you call the predict method 1000 times with 1 sample at a time?
I predict on a batch.

I tested it again on larger datasets. Starting from 100k samples, GB and HGB are indeed about equally fast. It seems there is some overhead that is probably not relevant in most cases. But since I often have smaller batches of 10 to 1000 examples and need predictions in real time, HGB isn't a good choice for me currently.

Do you see any chance for an improvement?

@ogrisel
Member
ogrisel commented Feb 12, 2020

@ogrisel , just a side note, we parallelize the _raw_predict (i.e. decision_function), but not the decision_function_to_proba method which is sequential. There's room for improvement here, though I doubt this is the issue

I agree.

Indeed the prediction latency was optimized for GradientBoostingClassifier
Curious, what are you referring to?

I remember that @pprett spent a lot of time profiling the predict method of GradientBoostingClassifier to make sure that prediction latency on small batches would be as low as possible.

@SebastianBr
Author

Oh I see, I haven't tried OMP_NUM_THREADS. It isn't a parameter, so where can I set it?

@ogrisel
Member
ogrisel commented Feb 12, 2020

OMP_NUM_THREADS is an environment variable. Try to set it to 1 to run in sequential mode on small batches.

@NicolasHug
Member

Oh I see, I haven't tried OMP_NUM_THREADS. It isn't a parameter, so where can I set it?

It's an environment variable; you can do OMP_NUM_THREADS=1 python the_script.py
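
If it is more convenient, the same limit can also be set from inside Python, as long as it happens before scikit-learn is imported (a minimal sketch):

```python
# Sketch: set the variable before importing scikit-learn, otherwise the OpenMP
# runtime may already have been initialised with the default thread count.
import os

os.environ["OMP_NUM_THREADS"] = "1"

from sklearn.ensemble import HistGradientBoostingClassifier  # imported after setting the variable
```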

Do you see any chance for an improvement?

Yes. The first obvious thing is to check whether this is indeed a thread-overhead issue and, if so, only parallelize the code when the number of samples is high enough. The second is to pack the tree structure so that it makes better use of the CPU cache.
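
As a user-side stopgap (not the internal fix described above), the thread pool can also be capped per call with threadpoolctl; a hedged sketch with a made-up threshold:

```python
# Workaround sketch, not scikit-learn's internal logic: use a single OpenMP
# thread when the batch is small enough that thread-spawning overhead would
# likely dominate. The threshold below is a guess, not a measured value.
from threadpoolctl import threadpool_limits


def predict_proba_adaptive(clf, X, single_thread_below=10_000):
    if X.shape[0] < single_thread_below:
        with threadpool_limits(limits=1, user_api="openmp"):
            return clf.predict_proba(X)
    return clf.predict_proba(X)
```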

@SebastianBr
Author

Thanks for the info. I benchmarked the prediction with OMP_NUM_THREADS=1 and it was actually slower. For example, the smallest batch I tested (size 5) took 25 ms with one thread and 16 ms with multithreading.

So it seems the overhead lies in the method itself.

@NicolasHug
Member

Interesting, thanks.

I'll run a few benchmarks on my side. From what I remember, we're consistently faster than LightGBM and XGBoost on prediction (but I haven't tested against the regular GB estimators).

@stonebig
stonebig commented Jan 1, 2024

There is also a study published on arXiv in May 2023: https://arxiv.org/pdf/2305.17094.pdf

Apparently XGBoost can use the GPU, but that doesn't mean it's better than sklearn.

rapidsai/cuml#5374 (comment)

@ogrisel
Member
ogrisel commented Feb 28, 2025

Let's keep this particular issue focused on prediction speed and open separate issues to track the various aspects of training speed instead. Those have different performance bottlenecks that can be investigated and fixed quite independently of one another.

For instance, for training there are already the following open issues:

@ogrisel
Member
ogrisel commented Feb 28, 2025

To start working on improving the prediction throughput and/or latency of HGBDT, we first need to collect detailed benchmarks and profiling results.
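
For instance, the kind of latency measurement that could feed into this (synthetic data, illustrative batch sizes):

```python
# Sketch of a prediction-latency benchmark across batch sizes (synthetic data).
import time

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = make_classification(n_samples=100_000, n_features=20, random_state=0)
clf = HistGradientBoostingClassifier(random_state=0).fit(X, y)

for batch_size in (1, 10, 100, 1_000, 10_000, 100_000):
    X_batch = X[:batch_size]
    timings = []
    for _ in range(20):
        t0 = time.perf_counter()
        clf.predict_proba(X_batch)
        timings.append(time.perf_counter() - t0)
    print(f"batch={batch_size:>7}: median latency {np.median(timings) * 1e3:.2f} ms")
```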
