Array API backends support for MLX #29673
Does MLX plan to support the array API? I don't think it'd be a good idea for us to support non-array-API-compatible libraries, since it would add quite a bit of maintenance burden and too many parallel code execution paths.
cc @ogrisel
Indeed, MLX plans to support the array API with two exceptions:
I think you'll find those exceptions are not uncommon (e.g. PyTorch MPS doesn't have 1, JAX doesn't have 2). So basically the point of this issue is to discuss how to handle array API support in scikit-learn for frameworks which don't support 1 and 2.
Thanks for creating the issue and summarising the thread that led to it! I think we can address the "no `float64`" point. On the second point, I have no thoughts other than "hmm, interesting. Need to think about that". I don't think anyone has tried using JAX with scikit-learn (yet). Ideally we can find one solution for all libraries.

Overall our experience has been that there are small differences/things that need to be dealt with, so right now we prefer to explicitly list the libraries that are supported. Yes, a bit weird maybe given the goal of the array API, but :-/

This means it is great to see someone come and lobby for MLX. It is a reason to think about the things that need to be done to support it, as well as things we might be doing right now that maybe we shouldn't.
It's not exactly the same though. At this time we do have library-specific code to handle the torch MPS constraint: scikit-learn/sklearn/utils/_array_api.py, lines 650 to 659 at d9deffe.
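The idea behind that kind of library-specific dtype handling can be sketched as follows. This is a hypothetical helper, not scikit-learn's actual implementation, and it is simplified: it only checks whether the namespace exposes the attribute at all (which covers MLX), whereas the real torch MPS case needs device-aware logic since torch exposes `float64` but cannot use it on MPS.

```python
# Hypothetical sketch (not scikit-learn's actual helper): list the float
# dtypes a namespace actually exposes, widest first, so callers can build
# a dtype=[...] request without touching missing attributes such as
# mlx.core.float64.
import numpy as np


def supported_float_dtypes(xp):
    """Return the float dtypes `xp` exposes, in decreasing precision."""
    dtypes = []
    for name in ("float64", "float32", "float16"):
        dtype = getattr(xp, name, None)  # mlx.core has no `float64` attribute
        if dtype is not None:
            dtypes.append(dtype)
    return tuple(dtypes)


print(supported_float_dtypes(np))  # NumPy exposes all three
```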
As Tim said, we could extend that code to support MLX if that were the only problem.
At this time we have several operations in scikit-learn that require data-dependent output shapes. We have the same problem with Dask and JAX support. More details in the following stalled experimental PRs:
I cannot suggest an easy way forward at this point. More investigation/experiments are required to propose solutions to unlock this class of problems, ideally in a library-agnostic fashion.

EDIT: actually the JAX in-place assignment limitation is different.
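To make the data-dependent output shape problem concrete: the size of a `unique_values` result depends on the array's contents, not only on its shape, so lazy or compiling backends cannot allocate the output ahead of time. A small NumPy illustration:

```python
import numpy as np

# Two arrays with identical shape but different contents...
a = np.array([1, 1, 1, 1])
b = np.array([1, 2, 3, 4])

# ...produce unique-value results of different shapes. A framework that
# traces or compiles on abstract shapes (e.g. JAX under jit, or MLX's
# lazy graphs) cannot represent an output shape it only learns at runtime.
print(np.unique(a).shape)  # (1,)
print(np.unique(b).shape)  # (4,)
```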
It would be useful to take a look at wrapping MLX in array-api-compat (see e.g. data-apis/array-api-compat#76 for Dask). Then you can make similar changes for MLX.
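A minimal sketch of what such a compat wrapper could look like. This is hypothetical and not part of array-api-compat; a toy namespace stands in for `mlx.core` (which only runs on Apple Silicon), and the `isdtype` fallback only covers a few dtype kinds:

```python
# Hypothetical sketch of a thin compat shim that fills in array API
# functions a backend is missing, in the spirit of array-api-compat.
import types

import numpy as np


def wrap_namespace(xp):
    """Delegate to `xp`, adding fallbacks for missing array API functions."""
    compat = types.SimpleNamespace(
        **{name: getattr(xp, name) for name in dir(xp) if not name.startswith("_")}
    )
    if not hasattr(compat, "isdtype"):
        def isdtype(dtype, kind):
            # Minimal stand-in covering a few dtype kinds only.
            kind_map = {"real floating": "f", "integral": "iu", "bool": "b"}
            return np.dtype(dtype).kind in kind_map.get(kind, "")
        compat.isdtype = isdtype
    if not hasattr(compat, "asarray"):
        compat.asarray = xp.array  # fall back to the library's constructor
    return compat


# A toy namespace missing `isdtype`/`asarray`, like MLX at the time:
toy = types.SimpleNamespace(array=np.array, float32=np.float32)
xp = wrap_namespace(toy)
print(xp.isdtype(np.float32, "real floating"))  # True
```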
For the dtype problem, the user should always feed float32 data to scikit-learn if they want to use MLX on Apple Silicon:

```python
import numpy as np

from sklearn.datasets import make_classification
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

import mlx.core as mx

X_np, y_np = make_classification(random_state=0)
X_np = X_np.astype(np.float32)
y_np = y_np.astype(np.float32)

X_mx = mx.array(X_np)
y_mx = mx.array(y_np)

with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis()
    X_trans = lda.fit_transform(X_mx, y_mx)

print(type(X_trans))
```

but that still fails with:

```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 15
     13 with config_context(array_api_dispatch=True):
     14     lda = LinearDiscriminantAnalysis()
---> 15     X_trans = lda.fit_transform(X_mx, y_mx)
     17 print(type(X_trans))

File ~/code/scikit-learn/sklearn/utils/_set_output.py:319, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
    317 @wraps(f)
    318 def wrapped(self, X, *args, **kwargs):
--> 319     data_to_wrap = f(self, X, *args, **kwargs)
    320     if isinstance(data_to_wrap, tuple):
    321         # only wrap the first output for cross decomposition
    322         return_tuple = (
    323             _wrap_data_with_container(method, data_to_wrap[0], X, self),
    324             *data_to_wrap[1:],
    325         )

File ~/code/scikit-learn/sklearn/base.py:921, in TransformerMixin.fit_transform(self, X, y, **fit_params)
    918     return self.fit(X, **fit_params).transform(X)
    919 else:
    920     # fit method of arity 2 (supervised transformation)
--> 921     return self.fit(X, y, **fit_params).transform(X)

File ~/code/scikit-learn/sklearn/base.py:1389, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1382 estimator._validate_params()
   1384 with config_context(
   1385     skip_parameter_validation=(
   1386         prefer_skip_nested_validation or global_skip_validation
   1387     )
   1388 ):
-> 1389     return fit_method(estimator, *args, **kwargs)

File ~/code/scikit-learn/sklearn/discriminant_analysis.py:661, in LinearDiscriminantAnalysis.fit(self, X, y)
    640 """Fit the Linear Discriminant Analysis model.
    641
    642 .. versionchanged:: 0.19
    (...)
    656     Fitted estimator.
    657 """
    658 xp, _ = get_namespace(X)
    660 X, y = validate_data(
--> 661     self, X, y, ensure_min_samples=2, dtype=[xp.float64, xp.float32]
    662 )
    663 self.classes_ = unique_labels(y)
    664 n_samples, _ = X.shape

AttributeError: module 'mlx.core' has no attribute 'float64'
```

because the attribute `float64` does not exist in `mlx.core`. Exposing a fake `mx.float64` (aliased to `mx.float32`) could work around this, but that would be better handled in a compatibility layer.
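For illustration, the "fake `float64`" workaround amounts to aliasing the missing attribute. A toy namespace stands in for `mlx.core` here so the sketch runs anywhere; note that silently downcasting to float32 can hide precision issues, which is part of why a compatibility layer is preferable:

```python
import types

# Stand-in for mlx.core, which exposes float32 but no float64 attribute.
mx = types.SimpleNamespace(float32="float32")

# The hack: alias the missing dtype so lookups such as
# dtype=[xp.float64, xp.float32] no longer raise AttributeError.
# Any "float64" computation will silently run in float32.
if not hasattr(mx, "float64"):
    mx.float64 = mx.float32

print(mx.float64)  # "float32"
```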
👍. @j-emberton are you still interested in working on a compat layer for MLX?
That would be great! And if there are functions that are part of the array API but missing from MLX, just let us know and we can add them.
It would be great to get the scikit-learn Array API back-end to be compatible with MLX (which is mostly conformant with the array API).

Here is an example which currently does not work for a few reasons:

The reasons it does not work:

1. MLX does not have a `float64` data type (similar to the PyTorch MPS backend). It's a bit hacky to set `mx.float64 = mx.float32`, so maybe good to handle this in scikit-learn or in a compatibility layer.
2. MLX does not support operations with data-dependent output shapes, e.g. `unique_values`. Since these are optional in the array API, should we attempt to avoid using them in scikit-learn to get maximal compatibility with other frameworks?
3. There are still a couple of functions missing in MLX, like `mx.asarray` and `mx.isdtype` (those are pretty easy for us to add).

Relevant discussion in MLX: ml-explore/mlx#1289

CC @betatim
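One possible direction for point 2, borrowed from how JAX handles it (`jnp.unique` accepts `size` and `fill_value` arguments): compute a fixed-size result by padding to a static upper bound. A NumPy sketch of the idea, with a hypothetical helper name:

```python
import numpy as np


def unique_fixed_size(x, size, fill_value):
    """Like unique_values, but with a static output shape.

    Pads (or truncates) the unique values to exactly `size` entries so the
    output shape depends only on `size`, never on the data. Hypothetical
    helper illustrating the pattern behind jnp.unique(..., size=, fill_value=).
    """
    u = np.unique(x)
    out = np.full(size, fill_value, dtype=u.dtype)
    out[: min(size, u.size)] = u[:size]
    return out


y = np.array([3, 1, 3, 2, 1])
print(unique_fixed_size(y, size=4, fill_value=-1))  # [ 1  2  3 -1]
```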