Array API support for k-nearest neighbors models with the brute force method #26586
Comments
Here is a relevant gist of what could be a PyTorch drop-in replacement for the https://gist.github.com/fcharras/82772cf7651e087b3b91b99105a860dd Quoting myself from the k-means thread:
It would be interesting to compare with cuML. If cuML is much faster than this PyTorch GPU implementation of brute-force kNN, it might be worth checking whether a Triton-based implementation can reach similar performance.
I forked your original gist and added a basic cuML comparison: https://gist.github.com/betatim/68219c95f539df51afad96cd9cd14a1c On a machine with 8 Tesla V100 (32GB RAM), 80 Intel(R) Xeon(R) CPU E5-2698 v4 @ 2.20GHz (from looking at On a second run I got 6s and 0.33s respectively. Seems to fluctuate a bit.
Have you tried to set
This issue is a sibling of the k-means issue #26585: it has a similar purpose but likely different constraints.
In particular, an efficient implementation of k-NN on the GPU would require:
torch.cdist
torch.topk
being discussed at:
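For illustration, the two primitives named above are enough to express a brute-force k-NN query. The following is a minimal sketch, assuming PyTorch is installed; the helper name `brute_force_knn` and the toy data are hypothetical and are not taken from scikit-learn or from the linked gists.

```python
import torch


def brute_force_knn(X_train, X_query, k):
    """Return (distances, indices) of the k nearest training points per query.

    Hypothetical helper for illustration only; not a scikit-learn API.
    """
    # Pairwise Euclidean distances, shape (n_query, n_train).
    dists = torch.cdist(X_query, X_train)
    # largest=False selects the k smallest distances in each row;
    # results are sorted in ascending order by default.
    knn_dists, knn_idx = torch.topk(dists, k, dim=1, largest=False)
    return knn_dists, knn_idx


# Toy data: 4 training points and 2 query points in 2 dimensions.
X_train = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [5.0, 5.0]])
X_query = torch.tensor([[0.1, 0.0], [4.0, 4.0]])
knn_dists, knn_idx = brute_force_knn(X_train, X_query, k=2)
```

This version materializes the full `(n_query, n_train)` distance matrix, so for large inputs the queries would likely need to be processed in chunks to bound memory use. Passing tensors that live on a CUDA device runs the same code on the GPU.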