
Array API backends support for MLX #29673


Open · awni opened this issue Aug 14, 2024 · 9 comments

awni commented Aug 14, 2024

It would be great for the scikit-learn Array API backend to be compatible with MLX (which is mostly conformant with the array API).

Here is an example which currently does not work for a few reasons:

from sklearn.datasets import make_classification
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import mlx.core as mx

X_np, y_np = make_classification(random_state=0)
X_mx = mx.array(X_np)
y_mx = mx.array(y_np)

with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis()
    X_trans = lda.fit_transform(X_mx, y_mx)

print(type(X_trans))

The reasons it does not work:

  • MLX does not have a float64 data type (similar to the PyTorch MPS backend). It's a bit hacky to set mx.float64 = mx.float32, so maybe good to handle this in scikit-learn or in a compatibility layer (a throwaway version of that monkeypatch is sketched after this list).

  • MLX does not support operations with data-dependent output shapes, e.g. unique_values. Since these are optional in the array API, should we attempt to avoid using them in scikit-learn to get maximal compatibility with other frameworks?

  • There are still a couple of functions missing in MLX, like mx.asarray and mx.isdtype (those are pretty easy for us to add)
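For illustration, the hack from the first bullet is a throwaway monkeypatch along these lines (not a recommended fix):

import mlx.core as mx

# Throwaway monkeypatch: alias the missing float64 symbol so attribute
# lookups like xp.float64 succeed; computation still happens in float32.
if not hasattr(mx, "float64"):
    mx.float64 = mx.float32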

Relevant discussion in MLX: ml-explore/mlx#1289

CC @betatim

adrinjalali (Member) commented

Does MLX plan to support the array API? I don't think it'd be a good idea for us to support non-array-API-compatible libraries, since it'll add quite a bit of maintenance burden and too many parallel code execution paths.

adrinjalali (Member) commented

cc @ogrisel

awni (Author) commented Aug 14, 2024

Indeed, MLX plans to support the array API with two exceptions:

  1. float64
  2. Ops with data-dependent output shapes

I think you'll find those exceptions are not uncommon (e.g. PyTorch MPS doesn't have 1, JAX doesn't have 2). So basically the point of this issue is to discuss how to handle array API support in scikit-learn for frameworks which don't support 1 and 2.
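To make point 2 concrete, here is a plain NumPy illustration of why such ops are hard for shape-static frameworks:

import numpy as np

# Two inputs with identical shapes ...
a = np.array([1, 1, 2])
b = np.array([1, 2, 3])

# ... but the output shape of unique depends on the values themselves,
# so a framework that fixes shapes ahead of time cannot express it.
print(np.unique(a).shape)  # (2,)
print(np.unique(b).shape)  # (3,)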

betatim (Member) commented Aug 15, 2024

Thanks for creating the issue and summarising the thread that led to it!

I think we can address the "no float64" issue in the same way we addressed it for pytorch on MPS. Though this might be the moment where we should think about doing it "properly" (via the inspection API?) instead of special-casing libraries.
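As a rough sketch, that "proper" route could look something like the following, assuming a namespace that implements the array API v2023.12 inspection API (max_float_dtype is a made-up name, not scikit-learn code):

def max_float_dtype(xp, device=None):
    # Query the supported real floating dtypes for this device via the
    # standard inspection API, instead of special-casing libraries.
    info = xp.__array_namespace_info__()
    floats = info.dtypes(device=device, kind="real floating")
    # Prefer float64 where the device supports it, else fall back.
    return floats.get("float64", floats["float32"])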

On the second point, I have no thoughts other than "hmm, interesting. Need to think about that". I don't think anyone has tried using jax with scikit-learn (yet). Ideally we can find one solution for all libraries.

Overall our experience has been that there are small differences/things that need to be dealt with, so right now we prefer to explicitly list the libraries that are supported. Yes, a bit weird maybe given the goal of the array API, but :-/ This means it is great to see someone come and lobby for mlx. It is a reason to think about the things that need to be done to support it, as well as things we might be doing right now that maybe we shouldn't.

ogrisel (Member) commented Sep 11, 2024

> I think we can address the "no float64" issue in the same way we addressed it for pytorch on MPS. Though this might be the moment where we should think about doing it "properly" (via the inspection API?) instead of special-casing libraries.

It's not exactly the same though. torch.float64 (or array_api_compat.torch.float64 as actually used in scikit-learn) does exist as a public attribute on the namespace module, but torch.float64 operations are not supported on tensors that are allocated on an MPS device (they are on CUDA and CPU devices). For mlx, the mlx.core.float64 symbol does not even exist.

At this time we do have library-specific code to handle the torch MPS constraint:

def _max_precision_float_dtype(xp, device):
    """Return the float dtype with the highest precision supported by the device."""
    # TODO: Update to use `__array_namespace_info__()` from array-api v2023.12
    # when/if that becomes more widespread.
    xp_name = xp.__name__
    if xp_name in {"array_api_compat.torch", "torch"} and (
        str(device).startswith("mps")
    ):  # pragma: no cover
        return xp.float32
    return xp.float64

As Tim said, we could extend that code to support mlx if that were the only problem.
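Such an extension might look roughly like this (a sketch only, not something that exists in scikit-learn today):

def _max_precision_float_dtype(xp, device):
    """Return the float dtype with the highest precision supported by the device."""
    xp_name = xp.__name__
    if xp_name in {"array_api_compat.torch", "torch"} and (
        str(device).startswith("mps")
    ):
        return xp.float32
    if xp_name == "mlx.core":
        # mlx has no float64 at all, on any device.
        return xp.float32
    return xp.float64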

> Ops with data-dependent output shapes

At this time we have several operations in scikit-learn that require data-dependent output shapes. We have the same problem with Dask and JAX support. More details in the following stalled experimental PRs:

I cannot suggest an easy way forward at this point. More investigation/experiments are required to propose solutions to unlock this class of problems, ideally in a library-agnostic fashion.

EDIT: actually the JAX in-place assignment limitation is different. It's possible that the xp.unique_values problem could be worked around more easily in a limited number of places in the scikit-learn code base.
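For instance, a hypothetical helper (names made up for illustration) could skip the data-dependent call whenever the caller already knows the label set:

def _labels_or_unique(y, xp, classes=None):
    if classes is not None:
        # Static, data-independent output shape: friendly to MLX/JAX.
        return xp.asarray(classes)
    # Fallback with a data-dependent output shape.
    return xp.unique_values(y)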

lucascolley (Contributor) commented

> maybe good to handle this in scikit-learn or in a compatibility layer.

It would be useful to take a look at wrapping MLX in array-api-compat (see e.g. data-apis/array-api-compat#76 for Dask). Then you can make changes like float64 = mlx.float32 and see what else fails the array API test suite.
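As a rough sketch, a minimal shim in that spirit could start like this (the module layout and names are made up for illustration):

# mlx_compat.py: hypothetical wrapper namespace for mlx
import mlx.core as mx
from mlx.core import *  # noqa: F401,F403  re-export the real namespace

# Paper over the symbols that the array API test suite (and
# scikit-learn) will touch but mlx does not provide:
float64 = mx.float32  # mlx has no float64; alias so attribute access works

def asarray(obj, dtype=None, device=None, copy=None):
    # mx.array is the closest mlx equivalent of the standard asarray;
    # device and copy are accepted but ignored in this sketch.
    return mx.array(obj, dtype=dtype)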

ogrisel (Member) commented Dec 26, 2024

For the dtype problem, the user should always feed float32 data to scikit-learn if they want to use MLX on Apple Silicon:

import numpy as np
from sklearn.datasets import make_classification
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import mlx.core as mx

X_np, y_np = make_classification(random_state=0)
X_np = X_np.astype(np.float32)
y_np = y_np.astype(np.float32)
X_mx = mx.array(X_np)
y_mx = mx.array(y_np)

with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis()
    X_trans = lda.fit_transform(X_mx, y_mx)

print(type(X_trans))

but that still fails with:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 15
     13 with config_context(array_api_dispatch=True):
     14     lda = LinearDiscriminantAnalysis()
---> 15     X_trans = lda.fit_transform(X_mx, y_mx)
     17 print(type(X_trans))

File ~/code/scikit-learn/sklearn/utils/_set_output.py:319, in _wrap_method_output.<locals>.wrapped(self, X, *args, **kwargs)
    317 @wraps(f)
    318 def wrapped(self, X, *args, **kwargs):
--> 319     data_to_wrap = f(self, X, *args, **kwargs)
    320     if isinstance(data_to_wrap, tuple):
    321         # only wrap the first output for cross decomposition
    322         return_tuple = (
    323             _wrap_data_with_container(method, data_to_wrap[0], X, self),
    324             *data_to_wrap[1:],
    325         )

File ~/code/scikit-learn/sklearn/base.py:921, in TransformerMixin.fit_transform(self, X, y, **fit_params)
    918     return self.fit(X, **fit_params).transform(X)
    919 else:
    920     # fit method of arity 2 (supervised transformation)
--> 921     return self.fit(X, y, **fit_params).transform(X)

File ~/code/scikit-learn/sklearn/base.py:1389, in _fit_context.<locals>.decorator.<locals>.wrapper(estimator, *args, **kwargs)
   1382     estimator._validate_params()
   1384 with config_context(
   1385     skip_parameter_validation=(
   1386         prefer_skip_nested_validation or global_skip_validation
   1387     )
   1388 ):
-> 1389     return fit_method(estimator, *args, **kwargs)

File ~/code/scikit-learn/sklearn/discriminant_analysis.py:661, in LinearDiscriminantAnalysis.fit(self, X, y)
    640 """Fit the Linear Discriminant Analysis model.
    641 
    642 .. versionchanged:: 0.19
   (...)
    656     Fitted estimator.
    657 """
    658 xp, _ = get_namespace(X)
    660 X, y = validate_data(
--> 661     self, X, y, ensure_min_samples=2, dtype=[xp.float64, xp.float32]
    662 )
    663 self.classes_ = unique_labels(y)
    664 n_samples, _ = X.shape

AttributeError: module 'mlx.core' has no attribute 'float64'

This is because the attribute xp.float64 does not exist, even though we never actually try to use it.

Exposing a fake xp.float64 symbol in a compat layer might work but feels weird. Still, it might be a pragmatic first step toward exploring what else would break in practice once this technical detail is sidestepped.

lucascolley (Contributor) commented

> Exposing a fake xp.float64 symbol in a compat layer might work but feels weird. Still, it might be a pragmatic first step toward exploring what else would break in practice once this technical detail is sidestepped.

👍.

@j-emberton are you still interested in working on a compat layer for MLX?

awni (Author) commented Dec 26, 2024

That would be great! And if there are functions that are part of the array API but missing from MLX which you think it should have, we can add them. Just let us know.
