ENH: array types: add JAX support by lucascolley · Pull Request #20085 · scipy/scipy · GitHub

ENH: array types: add JAX support #20085


Merged · 73 commits · May 18, 2024

Conversation

@lucascolley (Member) commented Feb 13, 2024

Reference issue

Towards gh-18867

What does this implement/fix?

First steps on JAX support.

Additional information

Can do the same for dask.array once the problems are fixed over at data-apis/array-api-compat#89.

@github-actions bot added the scipy.cluster, scipy._lib, Meson, array types, and enhancement labels on Feb 13, 2024
@lucascolley removed the Meson label on Feb 13, 2024
@lucascolley (Member, Author) commented Feb 14, 2024

(The reason I decided to comment over on the DLPack issue is that I recall a conversation about how portability could be increased if we replaced occurrences of np.asarray with from_dlpack (from array_api_compat.numpy, or NumPy itself for np >= 2.0). Clearly, portability beyond libraries which are coercible by np.asarray is very low priority at the minute, but it is something to consider long-term. Also, DLPack is the idiomatic way to do library interchange, rather than relying on the array-creation function asarray.)
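For illustration, a minimal sketch of that DLPack interchange (PyTorch here is just an assumed example of a DLPack-capable library, not something this PR uses):

import numpy as np
import torch  # any library with __dlpack__ support would do

x_torch = torch.arange(4)
x_np = np.from_dlpack(x_torch)  # zero-copy library interchange via DLPack,
                                # instead of coercion through np.asarray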

@rgommers (Member)

Thanks for working on this Lucas. JAX support will be very nice. And a third library with CPU support (after NumPy and PyTorch) will also be good for testing how generic our array API standard support actually is.

Okay, related to the read-only question, it looks like this is the problem you were seeing:

scipy/cluster/hierarchy.py:1038: in linkage
    result = _hierarchy.mst_single_linkage(y, n)
        method     = 'single'
        method_code = 0
        metric     = 'euclidean'
        n          = 6
        optimal_ordering = False
        xp         = <module 'jax.experimental.array_api' from '/home/rgommers/mambaforge/envs/scipy-dev-jax/lib/python3.11/site-packages/jax/experimental/array_api/__init__.py'>
        y          = array([1.48660687, 2.23606798, 1.41421356, 1.41421356, 1.41421356,
       2.28254244, 0.1       , 1.48660687, 1.48660687, 2.23606798,
       1.        , 1.        , 1.41421356, 1.41421356, 0.        ])
_hierarchy.pyx:1015: in scipy.cluster._hierarchy.mst_single_linkage
    ???
<stringsource>:663: in View.MemoryView.memoryview_cwrapper
    ???
<stringsource>:353: in View.MemoryView.memoryview.__cinit__
    ???
E   ValueError: buffer source array is read-only

The problem is that Cython doesn't accept read-only arrays when the signature is a regular memoryview. There's a long discussion about this topic in scikit-learn/scikit-learn#10624. Now that we have Cython 3 though, the fix is simple:

diff --git a/scipy/cluster/_hierarchy.pyx b/scipy/cluster/_hierarchy.pyx
index 814051df2..c59b3de6a 100644
--- a/scipy/cluster/_hierarchy.pyx
+++ b/scipy/cluster/_hierarchy.pyx
@@ -1012,7 +1012,7 @@ def nn_chain(double[:] dists, int n, int method):
     return Z_arr
 
 
-def mst_single_linkage(double[:] dists, int n):
+def mst_single_linkage(const double[:] dists, int n):
     """Perform hierarchy clustering using MST algorithm for single linkage.
 
     Parameters

This makes the tests pass (at least for this issue, I tried with the dendrogram tests only). The dists input to mst_single_linkage isn't modified in-place, so once we tell Cython that by adding const, things are happy.
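For illustration (a minimal sketch, not from the PR): np.asarray on a JAX array yields a read-only buffer, which a non-const memoryview rejects:

import numpy as np

x = np.arange(5.0)
x.setflags(write=False)  # mimic the read-only buffer that np.asarray(jax_array) yields

mv = memoryview(x)
print(mv.readonly)  # True: a plain `double[:]` signature raises
                    # "buffer source array is read-only"; `const double[:]` accepts it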

@lucascolley added the Cython label on Feb 14, 2024
@lucascolley (Member, Author)

thanks! I've removed the copies and added some consts to the Cython file to get the tests to pass. Still some failures for in-place assignments with indexing but we can circle back to those once we get integration with the test skip infra.

@lucascolley force-pushed the jax branch 2 times, most recently from 00e27e8 to 77aaebd on February 17, 2024 16:58
@jakevdp (Member) left a comment

One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under jax.jit, because it requires converting array objects to host-side buffers, and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or pure_callback) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?

@rgommers (Member)

One question I have here, which is probably a question more broadly for the array API: as written, much of the JAX support added here will not work under jax.jit, because it requires converting array objects to host-side buffers,

Do you mean for testing purposes, or for library code? For the latter: we should never do device transfers like GPU->host memory under the hood. The array API standard design was careful to not include that. It wasn't even possible at all until very recently, when a way was added to do it with DLPack (for testing purposes).

If you mean "convert to numpy.ndarray before going into Cython/C/C++/Fortran code inside SciPy", then yes, that is happening. That's kinda not an array API standard issue, because it's leaving Python - and that's a very different problem. To avoid compiled code inside SciPy - which indeed won't work with any JIT compiler unless that JIT is specifically aware of the SciPy functionality being called - it'd be necessary to have either a pure Python path (slow) or a matching API inside JAX that can be called (jax.scipy has some that we should be deferring to here).

and this is not possible during tracing when the array objects are abstract. JAX has mechanisms for this (namely custom calls and/or pure_callback) but the array API doesn't seem to have much consideration for this kind of library structure. Unfortunately, I think this will severely limit the usefulness of these kinds of implementations. I wonder if the array API could consider this kind of limitation?

JIT compilers were explicitly considered, and nothing in the standard should be JIT-unfriendly, except for the few functions clearly marked as having data-dependent output shapes and the few dunder methods that are also problematic for lazy arrays.

@lucascolley (Member, Author) commented Feb 27, 2024

Do you mean for testing purposes, or for library code? For the latter

If this is what you meant, x-ref the 'Dispatching Mechanism' section of gh-18286

@jakevdp (Member) commented Feb 27, 2024

I mean for actual user-level code: most of the work here will be more-or-less useless for JAX users because array conversions via dlpack cannot be done under JIT without some sort of callback mechanism.

@rgommers (Member)

Okay, I had a look at https://jax.readthedocs.io/en/latest/tutorials/external-callbacks.html and understand what you mean now. jax.pure_callback looks quite interesting indeed. I wasn't familiar with it, but it looks like that may actually solve an important puzzle in dealing with compiled code. It doesn't support GPU execution or auto-differentiation, but getting jax.jit and jax.vmap to work would be a significant step forward.

It looks fairly straightforward to support (disclaimer: I haven't tried it yet). It'd be taking this current code pattern:

# inside some Python-level scipy function with array API standard support:

x = np.asarray(x)
result = call_some_compiled_code(x)
result = xp.asarray(result)  # back to original array type

and replacing it with something like (untested):

def call_compiled_code_helper(x, xp):  # would need *args/**kwargs too
    if is_jax(x):
        result_shape_dtypes = ...  # TODO: figure out how to construct the needed PyTree here
        result = jax.pure_callback(call_some_compiled_code, result_shape_dtypes, x)
    else:
        x = np.asarray(x)
        result = call_some_compiled_code(x)
        result = xp.asarray(result)  # back to the original array type
    return result

Use of a utility function like call_compiled_code_helper may even make the code shorter and easier to understand. It seems feasible at first sight.

It's interesting that jax.pure_callback transforms JAX arrays to NumPy arrays under the hood already.
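If it helps, a minimal, untested sketch of one way to fill in that TODO: jax.ShapeDtypeStruct describes the callback's output PyTree, with scipy.special.gammaln standing in for "some compiled code" (my choice of example, not from this PR):

import jax
import jax.numpy as jnp
import numpy as np
from scipy.special import gammaln  # stand-in for call_some_compiled_code

def jax_gammaln(x):
    # The result PyTree: a single array with the same shape/dtype as the input.
    result_shape_dtypes = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(lambda arr: gammaln(np.asarray(arr)),
                             result_shape_dtypes, x)

# Works under jit: the callback runs host-side when the traced code executes.
y = jax.jit(jax_gammaln)(jnp.linspace(1.0, 5.0, 4))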

@jakevdp (Member) commented Feb 27, 2024

Yeah, something like that is what I had in mind, though pure_callback is probably not the right mechanism. JAX doesn't currently have an easy pure-callback-like mechanism for executing custom kernels on device, without the round-trip to host implied by pure_callback. I wonder if this kind of thing will be an issue for other array API libraries?

@rgommers (Member)

I wonder if this kind of thing will be an issue for other array API libraries?

It is (depending on your definition of "issue"), because there's no magic bullet that will take some native function implemented in C/Fortran/Cython inside SciPy and make it run on GPU.

The basic state of things is:

  • functions implemented in pure Python are unproblematic, and with array API support get to run on GPU/TPU, gain autograd support, etc.
    • with a few exceptions: functions using unique and other data-dependent shapes, iterative algorithms with a stopping/branching criterion that requires eager evaluation, functions using in-place operations.
  • As soon as you hit compiled code, things get harder: everything that worked before with NumPy only will still work, but autograd and GPU execution won't.

JAX doesn't currently have an easy pure-callback-like mechanism for executing custom kernels on device, without the round-trip to host implied by pure_callback.

In a generic library like SciPy it's almost impossible to support custom kernels on device. Our choices for arrays that don't live on host memory are:

  • find a matching function in the other library (see the sketch after this list): e.g., we can explicitly defer to everything in jax.scipy, cupyx.scipy and torch.fft/linalg/special,
  • raise an exception, or
  • do an automatic to/from host roundtrip (we haven't considered this a good idea before, since data transfers can be very expensive - but apparently that's what pure_callback prefers over raising).
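To make the first option concrete, a rough sketch (an illustration only, not SciPy code; the detection helper is a crude stand-in for what scipy._lib._array_api provides, and gammaln is just an example):

import numpy as np

def is_jax(x):
    # crude stand-in for a real library-detection helper
    return "jax" in type(x).__module__

def gammaln_dispatch(x, xp):
    # Option 1: defer to a matching native function when one exists.
    if is_jax(x):
        import jax.scipy.special
        return jax.scipy.special.gammaln(x)  # stays on device, works under jit
    # Fallback: host roundtrip through SciPy's compiled implementation.
    from scipy.special import gammaln
    return xp.asarray(gammaln(np.asarray(x)))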

@lucascolley (Member, Author) commented Feb 28, 2024

I gave adding Dask another shot just now, but unfortunately things are missing from dask.array, like float64, which makes most of our test code fail. Perhaps we will have to change to using the wrapped namespaces throughout the tests (this is awkward because we still need to imitate an array from the unwrapped namespace being input).

x-ref dask/dask#10387 (comment)

@rgommers (Member)

I'd suggest keeping this PR focused on JAX and getting that merged first. That makes it easier to see (also in the future) what had to be done only for JAX. And if we're going to experiment a bit with jax.jit, this PR may grow already.

@mdhaber (Contributor) commented May 19, 2024

I think it would be worth adding something that works for now, even if it's not great. It would avoid all the test skips and make it more obvious what capabilities we need. Once something better comes along, it will be easy to replace. It's probably better than using where, which we're tempted to use otherwise.

@rgommers (Member)

I think it would be worth adding something that works for now, even if it's not great.

Yeah maybe - I don't want to go too fast though, and add a bunch of code we may regret. Looks like the new version (I edited my comment and pushed a new commit) works though, and is still very fast with JAX.

It's probably better than using where, which we're tempted to use otherwise.

Let's make sure not to do things like that. Using where could potentially be bad for performance with numpy, which would not be helpful. Skips are better for now.

@jakevdp (Member) commented May 20, 2024

@rgommers FYI I managed to implement the scalar boolean scatter in JAX, and it will be available in the next release. Turns out we had all the necessary logic there already – I just needed to put it together! jax-ml/jax#21305

@rgommers (Member)

Great! Thanks @jakevdp. Looks like a small patch that I can try out pretty easily on top of JAX 0.4.28 - will give it a go later this week.

(note to self, since comments are hard to find in this PR: the relevant comment here is #20085 (comment))


@mdhaber (Contributor) commented Jun 11, 2024

Resurrecting the conversation about #20085 (comment) based on gh-20935.

@rgommers can we add that function that mutates an array at boolean indices where possible and copies when necessary (e.g. JAX)? If we regret it later, we can just change the definition of the function. The only downside I see would be the overhead of an extra function call for non-JAX arrays. The potential upside is JAX support in many functions. If the experiment fails completely or new array API standard functionality becomes available, we can revert or change wherever the new function is used. Adding/removing all these skips has some cost, too, and I would prefer to fix, or at least note, any other JAX incompatibilities while we're still working on converting a function rather than having to come back later.

@lucascolley @jakevdp anything to add / change?

@jakevdp (Member) commented Jun 11, 2024

Late here but responding to #20085 (comment)

There are no dynamic shapes here, so it could work just fine.

a.at[mask].set(arr) in general does require dynamic shapes: it is only valid if arr is broadcast-compatible with the number of True values in mask, and the number of True values is dynamic.

I suspect what you have in mind is the special case when the size of arr is 1, and so we know a priori that it broadcasts with an array of any size. Still, the semantics of lax.scatter require actually instantiating that array. I've explored the idea of overloading JAX's arr.at[].set() to lower to lax.select rather than lax.scatter in this particular special case, but overall it seems like it adds undue complexity to the implementation and to the user's mental model of what this function does.
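A small illustration of where this bites (my example; behaviour as of the JAX versions discussed in this thread):

import jax
import jax.numpy as jnp

def set_masked(x, mask):
    return x.at[mask].set(0.0)

x = jnp.arange(4.0)
mask = x > 1.0

print(set_masked(x, mask))     # eager: fine, the mask's values are concrete
jax.jit(set_masked)(x, mask)   # raises NonConcreteBooleanIndexError: the scatter
                               # needs nonzero(mask), whose size is value-dependent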

@mdhaber (Contributor) commented Jun 11, 2024

@jakevdp would you be willing to open a PR so we have something concrete to discuss? I'd be happy to contribute to the branch by using the function in stats functions and showing that it allows us to remove test skips without performance costs to non-JAX arrays.

@rgommers (Member)

Thanks for the answer @jakevdp!

I suspect what you have in mind is the special case when the size of arr is 1

No, not really. To me there's a fundamental difference between the semantics of a function, and implementation details of it. For b = a.at[mask].set(vals), it doesn't matter whether vals is scalar or an array with a compatible shape, or how many True values there are in mask. The shape of b is the same as that of a in all cases, so the semantics of the function do not include any dynamic shapes. This is very different from y = x[mask], where y.shape depends on the values in mask.

I would not call the requirement that mask and vals are broadcast-compatible a dynamic shape. It's more like input validation that may raise an error (or propagate NaNs if you can't raise an error), just like there are functions that don't deal with inf/NaN, negative values, singular matrices, or any other case of input values not meeting what's required for some function. The output shape itself cannot change.

I've explored the idea of overloading JAX's arr.at[].set() to lower to lax.select rather than lax.scatter in this particular special case, but overall it seems like it adds undue complexity to the implementation and to the user's mental model of what this function does.

It'd be moving the "if boolean indices, then use where or select under the hood" if-else logic from user-land to a single place inside JAX, I think. So my expectation is that it would reduce the complexity of the mental model, at least when coming from NumPy or PyTorch. Because now it seems like we have to do the following translation for in-place ops (where val is either a Python scalar or an array):

  • x += val: works
  • x[int_indices] += val: TypeError -> use x.at[int_indices].add(val)
  • x[bool_mask] += val: TypeError -> use xp.where(bool_mask, x + val, x) if val is scalar, and ?? if it's an array

Note also that the error message for the boolean case is misleading, since it recommends a .at method but we need where instead:

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

It would be useful to have an improved error message here, and mention the boolean indexing case in https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.

I'm actually curious what your preference here is. It seems like the at_set/at_add etc. implementations I posted above are still non-JAX-like. But replacing a line like x[mask] += 2 with xp.where(mask, x + 2, x) is both harder to read and a lot less efficient with NumPy & co. And for x[mask] += y it gets worse (I don't even know how to properly do that in JAX, and it'd lose jit-ability since now you're forced to break up the op, and then it does use dynamic shapes). So I don't really know what the best way forward here is.
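For reference, the scalar-valued translation side by side (a small sketch of the pattern under discussion):

import numpy as np
import jax.numpy as jnp

x_np = np.arange(4.0)
mask = x_np > 1.0
x_np[mask] += 2.0                                   # NumPy: genuine in-place update

x_jax = jnp.arange(4.0)
x_jax = jnp.where(x_jax > 1.0, x_jax + 2.0, x_jax)  # JAX: functional rewrite, jit-safe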

@jakevdp (Member) commented Aug 9, 2024

To me there's a fundamental difference between the semantics of a function, and implementation details of it. For b = a.at[mask].set(vals), it doesn't matter whether vals is scalar or an array with a compatible shape, or how many True values there are in mask. The shape of b is the same as that of a in all cases, so the semantics of the function do not include any dynamic shapes.

Agreed that the output shapes are the same no matter what. But still, there are dynamic value-based semantics at play here in general, because if mask.sum() != len(vals), the code should raise a ValueError. The only time you don't have dynamic value-based semantics are when vals is a scalar or a length-1 array.

Alternatively, you could relax the requirement that mask.sum() == len(vals) and instead fill and/or truncate if needed, but I suspect that would cause surprises for the user.
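For illustration, the value-dependent check in NumPy terms (my example, not from the thread):

import numpy as np

x = np.zeros(4)
mask = np.array([True, False, True, False])

x[mask] = [1.0, 2.0]        # fine: exactly two True entries
x[mask] = [1.0, 2.0, 3.0]   # ValueError: cannot assign 3 values to 2 True positions;
                            # validity depends on the mask's values, not its shape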

@jakevdp (Member) commented Aug 9, 2024

I like your ideas for improved error messages, but I'm not sure how to discover in __setitem__ whether the call came from something like x[mask] += 1. Note that we can't override __iadd__ because we need Python to execute x += y in terms of __add__, and the more general case of x[mask] = val cannot in general be re-expressed in terms of where (unless val is a scalar, which relates to my previous point).
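For illustration, Python expands the augmented assignment before JAX ever sees it (an eager-mode sketch):

import jax.numpy as jnp

x = jnp.arange(4.0)
mask = x > 1.0

# Python expands `x[mask] += 1` into three steps; JAX only sees the last one:
tmp = x[mask]       # __getitem__: works eagerly, the mask's values are concrete
tmp = tmp + 1       # __add__: this is where the `+=` actually happens
x[mask] = tmp       # __setitem__: raises TypeError, JAX arrays are immutable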

@jakevdp (Member) commented Aug 9, 2024

Another complication is that x.at[bool_mask].add(val) is just fine, as long as you're not within a JIT context.

@jakevdp (Member) commented Sep 20, 2024

@jakevdp would you be willing to open a PR so we have something concrete to discuss? I'd be happy to contribute to the branch by using the function in stats functions and showing that it allows us to remove test skips without performance costs to non-JAX arrays.

I'm not sure what exactly you're referring to (what repository would this PR be against, and what would this PR do?).

@mdhaber (Contributor) commented Sep 20, 2024

In #20085 (comment) above, Ralf experimented with a function that mutates an array at boolean indices where possible and copies when necessary (depending on the backend). In #20085 (comment), I asked again for it to be added to SciPy so we can use it in places that are currently blocking JAX support. In #20085 (comment), you seemed to have some reservations about the particular approach, so in #20085 (comment) I asked if you would be willing to help with it by opening a PR that implements at_set while addressing your concerns.

We would use it only in places where mutation is beneficial for performance but copies can be tolerated. (When copies are desired, we can use e.g. where as before; when mutation is "needed", we mutate and skip JAX in tests.)

This would be against SciPy; specifically, it would be an addition to _lib/_array_api.py. It should probably be named xp_at_set for consistency.

@jakevdp (Member) commented Sep 20, 2024

I think there were two conversations going on at the same time, so things got muddled. I don't have any reservations with the approach of a helper that dispatches array updates differently in different contexts. That seems like a great idea, and I'd been under the impression that scipy was already doing that!

If you'd like me to put together a prototype, I can. Where would be the appropriate place to define it?

@mdhaber (Contributor) commented Sep 21, 2024

Great! I think it would go with our other helpers in https://github.com/scipy/scipy/blob/main/scipy/_lib/_array_api.py. Thank you! It will hopefully be simple enough that I can test it just by using it (i.e. to enable JAX support in a function that currently skips JAX because of the immutability). If you think it needs other tests, they can go in https://github.com/scipy/scipy/blob/main/scipy/_lib/tests/test_array_api.py.

@jakevdp (Member) commented Sep 23, 2024

Update: I'm having trouble building scipy locally due to some gfortran issues; I'm sure it's something I could resolve with a few hours' work, but I have more pressing priorities for blocks of time that large at the moment. So a PR is not on the table in the next few days.

Just to get something concrete down, what I have in mind is something like this:

if is_jax(x):
  x = xp.where(xp.isnan(x), 0, x)
else:
  x[xp.isnan(x)] = 0

It would certainly not cover all instances of __setitem__-based mutation, but it would be a solution for many of them that have come up (and specifically would apply in the case I commented on in #21597, when I was referred back to here).

See also the related discussion I started at data-apis/array-api#845.
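Extrapolating slightly, the helper could take the shape below (a hedged sketch only: xp_at_set is the name suggested earlier in the thread, the jax check is a crude stand-in for a real is_jax helper, and value is assumed scalar as in the snippet above):

def xp_at_set(x, mask, value, xp):
    """Set x[mask] = value: in place where the backend allows mutation,
    via a where-based copy where it does not (e.g. immutable JAX arrays)."""
    if "jax" in type(x).__module__:      # stand-in for an is_jax helper
        return xp.where(mask, value, x)  # functional update: returns a new array
    x[mask] = value                      # NumPy & co.: true in-place update, no copy
    return x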

@mdhaber (Contributor) commented Sep 23, 2024

OK, sounds good. We can include it in a PR that enables logsumexp tests to run with JAX. Thanks @jakevdp!
@lucascolley thought you might be interested to either write or review.

@lucascolley (Member, Author)

@lucascolley thought you might be interested to either write or review.

It's getting dangerously close to the start of the university term for me so I think I'll start saying no to writing new PRs :) happy to review though!

Labels: array types · Cython · enhancement · scipy.cluster · scipy.fft · scipy._lib · scipy.special · scipy.stats