Array API support for k-nearest neighbors models with the brute force method · Issue #26586 · scikit-learn/scikit-learn

Open

ogrisel opened this issue Jun 15, 2023 · 4 comments
@ogrisel (Member) commented Jun 15, 2023

This issue is a sibling of the corresponding issue for k-means (#26585), with a similar purpose but likely different constraints.

In particular, an efficient implementation of k-NN on the GPU would require:

@github-actions github-actions bot added the Needs Triage Issue requires triage label Jun 15, 2023
@ogrisel ogrisel added New Feature Array API and removed Needs Triage Issue requires triage labels Jun 15, 2023
@fcharras (Contributor) commented:

Here is a relevant gist of what could be a PyTorch drop-in replacement for the kneighbors method:

https://gist.github.com/fcharras/82772cf7651e087b3b91b99105a860dd

Quoting myself from the k-means thread:

To my knowledge, the best brute-force implementations require materializing the pairwise distance matrix in memory and are bound by the memory I/O bottleneck, so the achievable speedup is more limited, and the PyTorch implementation should be decently close to the best you can get.
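To make the approach concrete, here is a minimal sketch of a brute-force kneighbors step along the lines described above (materialize the full pairwise squared-distance matrix, then select the k smallest entries per row). This is not the code from the gist; the function name is hypothetical, and NumPy stands in for whatever Array API namespace (PyTorch, CuPy, ...) the inputs come from:

```python
import numpy as np


def kneighbors_brute_force(X_train, X_query, n_neighbors):
    """Hypothetical sketch of brute-force kneighbors (not the gist's code).

    Materializes the full (n_query, n_train) squared-distance matrix,
    which is exactly the memory/IO bottleneck discussed above.
    """
    # Squared Euclidean distances via the expansion
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, with no explicit loops.
    sq_train = (X_train ** 2).sum(axis=1)
    sq_query = (X_query ** 2).sum(axis=1)
    distances = sq_query[:, None] - 2.0 * (X_query @ X_train.T) + sq_train[None, :]

    # argpartition avoids a full sort: only the k smallest are needed.
    idx = np.argpartition(distances, n_neighbors - 1, axis=1)[:, :n_neighbors]
    # Order the k selected neighbors by distance for a stable result.
    order = np.argsort(np.take_along_axis(distances, idx, axis=1), axis=1)
    return np.take_along_axis(idx, order, axis=1)
```

A GPU implementation would follow the same two phases (a large matmul-like distance computation, then a top-k selection), typically chunking the query rows so the distance matrix fits in device memory.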

@ogrisel (Member, Author) commented Jul 13, 2023

It would be interesting to compare with cuML; if cuML is much faster than this PyTorch GPU implementation of brute-force kNN, it might be worth seeing whether we can reach similar performance with a Triton-based implementation.

@betatim (Member) commented Jul 20, 2023

I forked your original gist and added a basic cuML comparison: https://gist.github.com/betatim/68219c95f539df51afad96cd9cd14a1c

On a machine with 8 Tesla V100 GPUs (32 GB each), 80 Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz cores (from /proc/cpuinfo), and 1 TB RAM, I get about 8 s for the torch implementation and about 1 s for the cuML option, with 5M samples in the data (more than the original gist used).

On a second run I got 6 s and 0.33 s respectively; the timings seem to fluctuate a bit.
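One common way to reduce that kind of run-to-run fluctuation is to warm up once and then report the best of several repetitions. A minimal, CPU-only sketch (the helper name is hypothetical; for CUDA code one would additionally call `torch.cuda.synchronize()` before each clock read, since GPU kernels launch asynchronously):

```python
import time

import numpy as np


def best_of(fn, n_repeats=5):
    """Hypothetical helper: time fn several times, return the best wall time.

    The warm-up call excludes one-off costs (allocation, JIT, caches).
    A CUDA version would synchronize the device before reading the clock.
    """
    fn()  # warm-up
    timings = []
    for _ in range(n_repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return min(timings)


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 64))
elapsed = best_of(lambda: X @ X.T)
```

Reporting the minimum rather than the mean is a deliberate choice for this kind of micro-benchmark: the fastest run is closest to the noise-free cost, while slower runs mostly measure interference from the rest of the system.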

@ogrisel (Member, Author) commented Aug 21, 2023

Have you tried setting CUDA_VISIBLE_DEVICES=0 to make sure that neither implementation leverages the fact that the benchmark machine has multiple GPU devices?
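For reference, the variable can be set either in the shell (`CUDA_VISIBLE_DEVICES=0 python bench.py`) or from Python, as long as it happens before the CUDA-using library initializes the driver, i.e. before its first import in the process. A minimal sketch:

```python
import os

# Restrict CUDA frameworks to the first GPU only. This must run before
# torch / cuml / cupy are imported, otherwise the setting is ignored
# because the CUDA context has already enumerated all devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# import torch  # would now see exactly one device
```

With this in place, any speed difference between the two implementations can no longer be attributed to one of them silently using several GPUs.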

@glemaitre glemaitre moved this to Todo in Array API May 17, 2024