-
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
ENH: array types: add JAX support #20085
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
(The reason I decided to comment over on the DLPack issue is that I recall a conversation about how portability could be increased if we replace occurrences of |
Thanks for working on this Lucas. JAX support will be very nice. And a third library with CPU support (after NumPy and PyTorch) will also be good for testing how generic our array API standard support actually is. Okay, related to the read-only question, it looks like this is the problem you were seeing:
The problem is that Cython doesn't accept read-only arrays when the signature is a regular memoryview. There's a long discussion about this topic in scikit-learn/scikit-learn#10624. Now that we have Cython 3 though, the fix is simple: diff --git a/scipy/cluster/_hierarchy.pyx b/scipy/cluster/_hierarchy.pyx
index 814051df2..c59b3de6a 100644
--- a/scipy/cluster/_hierarchy.pyx
+++ b/scipy/cluster/_hierarchy.pyx
@@ -1012,7 +1012,7 @@ def nn_chain(double[:] dists, int n, int method):
return Z_arr
-def mst_single_linkage(double[:] dists, int n):
+def mst_single_linkage(const double[:] dists, int n):
"""Perform hierarchy clustering using MST algorithm for single linkage.
Parameters This makes the tests pass (at least for this issue, I tried with the |
thanks! I've removed the copies and added some |
00e27e8
to
77aaebd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under jax.jit
, because it requires converting array objects to host-side buffers, and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or pure_callback) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?
Do you mean for testing purposes, or for library code? For the latter: we should never do device transfers like GPU->host memory under the hood. The array API standard design was careful to not include that. It wasn't even possible at all until very recently, when a way was added to do it with DLPack (for testing purposes). If you mean "convert to
JIT compilers were explicitly considered, and nothing in the standard should be JIT-unfriendly, except for the few clearly marked as data-dependent output shapes and the few dunder methods that are also problematic for lazy arrays. |
If this is what you meant, x-ref the 'Dispatching Mechanism' section of gh-18286 |
I mean for actual user-level code: most of the work here will be more-or-less useless for JAX users because array conversions via dlpack cannot be done under JIT without some sort of callback mechanism. |
Okay, I had a look at https://jax.readthedocs.io/en/latest/tutorials/external-callbacks.html and understand what you mean now. It looks fairly straightforward to support (disclaimer: I haven't tried it yet). It'd be taking this current code pattern: # inside some Python-level scipy function with array API standard support:
x = np.asarray(x)
result = call_some_compiled_code(x)
result = xp.asarray(result) # back to original array type and replacing it with something like (untested): def call_compiled_code_helper(x, xp): # needs *args, *kwargs too
if is_jax(x):
result_shape_dtypes = ... # TODO: figure out how to construct the needed PyTree here
result = jax.pure_callback(call_some_compiled_code, result_shape_dtypes, x)
else:
x = np.asarray(x)
result = call_some_compiled_code(x)
result = xp.asarray(result) Use of a utility function like It's interesting that |
Yeah, something like that is what I had in mind, though |
It is (depending on your defintion of "issue") because there's no magic bullet that will do something like take some native function implemented in C/Fortran/Cython inside SciPy and make that run on GPU. The basic state of things is:
In a generic library like SciPy it's almost impossible to support custom kernels on device. Our choices for arrays that don't live on host memory are:
|
I gave adding Dask another shot just now, but unfortunately things are missing from |
I'd suggest keeping this PR focused on JAX and getting that merged first. That makes it easier to see (also in the future) what had to be done only for JAX. And if we're going to experiment a bit with |
I think it would be worth adding something that works for now, even if it's not great. It would avoid all the test skips and make it more obvious what capabilities we need. Once something better comes along, it will be easy to replace. It's probably better than using |
Yeah maybe - I don't want to go too fast though, and add a bunch of code we may regret. Looks like the new version (I edited my comment and pushed a new commit) works though, and is still very fast with JAX.
Let's make sure not to do things like that. Using |
@rgommers FYI I managed to implement the scalar boolean scatter in JAX, and it will be available in the next release. Turns out we had all the necessary logic there already – I just needed to put it together! jax-ml/jax#21305 |
Great! Thanks @jakevdp. Looks like a small patch that I can try out pretty easily on top of JAX 0.4.28 - will give it a go later this week. (note to self, since comments are hard to find in this PR: the relevant comment here is #20085 (comment)) |
Resurrecting the conversation about #20085 (comment) based on gh-20935. @rgommers can we add that function that mutates an array at boolean indices where possible and copies when necessary (e.g. JAX)? If we regret it later, we can just change the definition of the function. The only downside I see would be the overhead of an extra function call for non-JAX arrays. The potential upside is JAX support in many functions. If the experiment fails completely or a new array-API standard functionality is made available, we can revert or change wherever the new function is used. Adding/removing all these skips has some cost, too, and I would prefer to with or at least note any other JAX incompatibilities while we're still working on converting a function rather than having to come back later. @lucascolley @jakevdp anything to add / change? |
Late here but responding to #20085 (comment)
I suspect what you have in mind is the special case when the size of |
@jakevdp would you be willing to open a PR so we have something concrete to discuss? I'd be happy to contribute to the branch by using the function in stats functions and showing that it allows us to remove test skips without performance costs to non-JAX arrays. |
Thanks for the answer @jakevdp!
No, not really. To me there's a fundamental difference between the semantics of a function, and implementation details of it. For I would not call the requirement that
It'd be moving the "if boolean indices, then use
Note also that the error message for the boolean case is misleading, since it recommends a
It would be useful to have an improved error message here, and mention the boolean indexing case in https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html. I'm actually curious what your preference here is. It seems like the |
Agreed that the output shapes are the same no matter what. But still, there are dynamic value-based semantics at play here in general, because if Alternatively, you could relax the requirement that |
I like your ideas for improved error messages, but I'm not sure how to discover in |
Another compilcation is that |
I'm not sure what exactly you're referring to (what repository would this PR be against, and what would this PR do?). |
In #20085 (comment) above, Ralf experimented with a function that mutates an array at boolean indices where possible and copies when necessary (depending on the backend). In #20085 (comment), I asked again for it to be added to SciPy so we can use it in places that are currently blocking JAX support. In #20085 (comment), you seemed to have some reservations about the particular approach, so in #20085 (comment) I asked if you would be willing to help with it by opening a PR that implements We would use it only in places where mutation is beneficial for performance but copies can be tolerated. (When copies are desired, we can use e.g. This would be against SciPy; specificially, it would be an addition to |
I think there were two conversations going on at the same time, so things got muddled. I don't have any reservations with the approach of a helper that dispatches array updates differently in different contexts. That seems like a great idea, and I'd been under the impression that scipy was already doing that! If you'd like me to put together a prototype, I can. Where would be the appropriate place to define it? |
Great! I think it would go with our other helpers in https://github.com/scipy/scipy/blob/main/scipy/_lib/_array_api.py. Thank you! It will hopefully be simple enough that I can test it just by using it (i.e. to enable JAX support in a function that currently skips JAX because of the immutability). If you think it needs other tests, they can go in https://github.com/scipy/scipy/blob/main/scipy/_lib/tests/test_array_api.py. |
Update: I'm having trouble building scipy locally due to some gfortran issues; I'm sure it's something I could resolve with a few hours work but I have more pressing priorities for blocks of time that large at the moment. So a PR is not on the table in the next few days. Just to get something concrete down, what I have in mind is something like this: if is_jax(x):
x = xp.where(xp.isnan(x), 0, x)
else:
x[xp.isnan(x)] = 0 It would certainly not cover all instances of See also the related discussion I started at data-apis/array-api#845. |
OK, sounds good. We can include it in a PR that enables |
It's getting dangerously close to the start of the university term for me so I think I'll start saying no to writing new PRs :) happy to review though! |
Reference issue
Towards gh-18867
What does this implement/fix?
First steps on JAX support. To-do:
Additional information
Can do the same for
dask.array
once the problems are fixed over at data-apis/array-api-compat#89.