RFC/API (Array API) mixing devices and data types with estimators #26083

adrinjalali opened this issue Apr 4, 2023 · 7 comments

@adrinjalali (Member)

Right now, if the user fits an estimator using a pandas.DataFrame, but passes a numpy.ndarray during predict, they get a warning due to missing feature names.

The situation will only get more complicated as we add support for more array types via the Array API.

Some related issues here are:

  • device: data during fit sits on a GPU, but a CPU is used for predict (with the same data type)
  • types: using one type to fit and another during predict: how do we handle this, both in terms of device and type? Do we let the operator figure out whether the data can be coerced into a type that can be used?
  • persistence: how do we let users fit on one device, but load on another device
  • estimator conversion: do we let users convert an estimator which is fit using one type/device, to an estimator compatible with another type/device?

I vaguely remember us talking about some of these issues, but I don't see any active discussion. I might have missed something.

Related: in a library like PyTorch, you can decide which device will be used when you load a model's weights.
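For reference, this is what that looks like in PyTorch (the file name is a placeholder):

```python
import torch

# Weights saved on a GPU machine can be remapped onto the CPU at load
# time, regardless of the device they were trained on.
state_dict = torch.load("model_weights.pt", map_location=torch.device("cpu"))
```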

cc @thomasjpfan

@betatim (Member) commented Apr 4, 2023

The current state of thinking regarding estimator attributes and their conversion is in the documentation https://scikit-learn.org/dev/modules/array_api.html. TL;DR: the user has to convert it explicitly.

A reason not to offer something more "built in" is that conversion from array library X to NumPy (and the other direction) depends on the array library. For CuPy to NumPy you need cp_arr.get(), for PyTorch to NumPy you need something like torch_arr.cpu().numpy(), etc. And if you offer array lib X to array lib Y conversion, you quickly end up with a combinatorial explosion of possibilities.
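For concreteness, a sketch of the library-specific round trips mentioned above; each pair needs its own code path, which is where the combinatorial explosion comes from:

```python
import cupy as cp
import torch

cp_arr = cp.asarray([1.0, 2.0])
torch_arr = torch.tensor([1.0, 2.0], device="cuda")

np_from_cupy = cp_arr.get()              # CuPy -> NumPy
np_from_torch = torch_arr.cpu().numpy()  # PyTorch -> NumPy (move to CPU first)
# CuPy -> PyTorch, PyTorch -> CuPy, ...: every additional library
# multiplies the number of pairwise conversions to maintain.
```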

This means if you want to fit with one array library and predict with another you will have to convert your model's attributes.

You'd have to try out what happens with your array library of choice regarding "fit with data on a GPU" and "predict with data on a CPU". For example, PyTorch will raise an exception if you attempt to mix and match GPU and CPU (haven't tested what happens if you mix and match GPU1 and GPU2).

I think it makes sense to ask for explicit action from the user because of the performance cost of moving data from one device to another. The chances seem high that a mistake has happened if you fit a model on a GPU and then predict on a CPU without being explicit about it. For example, I can imagine a frequent case would be fitting on a GPU but then predicting on single samples with a CPU in production. Swapping from development/fitting to production is such a big step, requiring lots of explicit "we are now in production!!" signals, that I think we'd be in good company by also making things explicit.

Conversion when persisting a model or when loading it sounds like a good feature for a library that makes it easy to share your models :D


For the plugin architecture we are following similar thinking. Plugins should provide a function that lets the user convert a fitted estimator to a fitted NumPy-based estimator. Similarly, if you fit with plugin X, then plugin X will be used to predict. Plugins can accept input of any type if they want to, so a GPU plugin can accept NumPy arrays; you can fit with CuPy arrays on your GPU and then call predict() with a NumPy array (if your plugin supports this). What is still a bit under discussion is whether you need to explicitly activate the plugin during predict or whether it is implicitly activated because it was used for fitting.

@betatim (Member) commented Apr 6, 2023

Thinking about this a bit more I think an estimator should raise an explicit exception when it was fitted with, say, a torch tensor and the user is now trying to predict with a cupy array.

Right now, if you fit with torch and then try to transform a numpy array (with array API dispatch on or off) you get the following exception: `TypeError: unsupported operand type(s) for -: 'numpy.ndarray' and 'Tensor'`. That is technically correct but, I suspect, pretty useless for the unsuspecting user.

For example a fitted estimator could check the array namespace of its fitted attributes and compare them to that of the input X.
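A rough sketch of such a check (the function name is made up here), using array_api_compat's array_namespace; scikit-learn keeps a private equivalent in sklearn.utils._array_api:

```python
from array_api_compat import array_namespace

def _check_matching_namespace(fitted_attr, X):
    # Compare the namespace the estimator was fitted with against the
    # namespace of the incoming data and fail with an actionable message.
    xp_fit = array_namespace(fitted_attr)
    xp_X = array_namespace(X)
    if xp_fit is not xp_X:
        raise TypeError(
            f"Estimator was fitted with {xp_fit.__name__} arrays but "
            f"predict/transform received {xp_X.__name__} arrays; convert "
            "the input or the estimator's fitted attributes first."
        )
```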

@ogrisel (Member) commented Apr 6, 2023

> I vaguely remember us talking about some of these issues, but I don't see any active discussion. I might have missed something.

We discussed this a bit when designing a generic tool to convert everything to NumPy, namely sklearn.utils._array_api._estimator_with_converted_arrays, as documented in:

> Thinking about this a bit more I think an estimator should raise an explicit exception when it was fitted with, say, a torch tensor and the user is now trying to predict with a cupy array.

I think it's safe to do so. If we change our minds in the future and decide to implement some kind of automated conversion, we can always do so without breaking user code while it's not possible to do it the other way around.

@ogrisel (Member) commented Apr 6, 2023

Note that since _estimator_with_converted_arrays takes a callable as an argument, it provides the user with the flexibility to control device placement by leveraging the device support of the Array API:
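For example, a sketch using the private helper, assuming a CUDA-capable PyTorch install (the variable names are illustrative):

```python
import torch
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.utils._array_api import _estimator_with_converted_arrays

X = torch.randn(100, 5, device="cuda")
y = (X[:, 0] > 0).to(torch.int64)

with config_context(array_api_dispatch=True):
    lda = LinearDiscriminantAnalysis().fit(X, y)

# The converter callable sees every fitted array attribute, so device
# placement stays entirely in the user's hands.
lda_cpu = _estimator_with_converted_arrays(lda, lambda attr: attr.to("cpu"))
```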

In the future, we might want to expose a to_device method directly on estimators with Array API support for convenience (e.g. to move a fitted estimator from GPU to CPU or the converse) without changing the Array API container types of the fitted attributes.

@ogrisel (Member) commented Apr 6, 2023

Note that it might be possible to write a generic device-aware array converter for all the array namespace implementations that support the __dlpack__ protocol:
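A hedged sketch of what such a converter could look like (the function name is made up; from_dlpack is part of the Array API standard):

```python
def convert_array(array, target_xp):
    # Any namespace implementing the standard can ingest arrays from any
    # library that exposes __dlpack__, without pairwise special cases.
    # Cross-device transfers are still constrained by what the producing
    # and consuming libraries support.
    return target_xp.from_dlpack(array)

# e.g. a CPU torch tensor into NumPy:
# import numpy as np, torch
# np_arr = convert_array(torch.ones(3), np)
```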

@betatim (Member) commented Apr 11, 2023

I had two more thoughts about this for which it would be interesting to hear what y'all think.

  1. do we even need a torch to cupy (as an example) conversion? Or would it "just work" because they are both Array API compatible?
  2. maybe we should recommend using set_config() instead of the context manager in our docs. It might help people avoid mistakes like fitting inside a context manager but predicting outside of it (see the sketch after this list). At least I find it hard to predict whether fitting with the context manager means I can keep using the estimator without it, if I put myself in the shoes of someone who has not looked at the code. Having the docs present this as something you turn on/off globally for your whole program conveys the idea that it is a global flag. Of course we should also write it in words in the docs, but people read/copy&paste code snippets ...
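A sketch of the mismatch point 2 is worried about (array_api_dispatch=True assumes array_api_compat is installed; with a NumPy input the mismatch is harmless, but with a non-NumPy input it is exactly the fit/predict inconsistency described above):

```python
import numpy as np
from sklearn import config_context, set_config
from sklearn.decomposition import PCA

X = np.random.default_rng(0).standard_normal((100, 5))

# Scoped: dispatch is only active inside the with-block...
with config_context(array_api_dispatch=True):
    pca = PCA(n_components=2).fit(X)
# ...so this call runs with dispatch turned off again.
pca.transform(X)

# Global alternative: flip the switch once for the whole program.
set_config(array_api_dispatch=True)
```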

@ogrisel (Member) commented Apr 12, 2023
> 1. do we even need a torch to cupy (as an example) conversion? Or would it "just work" because they are both Array API compatible?

Let's try. I assume you have a CUDA machine at hand. If it works, let's add a non-regression test (with the appropriate double pytest.importorskip).
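The test shape would be something like this (a sketch; the test name and body are placeholders):

```python
import pytest

# Skip unless both libraries are importable -- the "double importorskip".
torch = pytest.importorskip("torch")
cupy = pytest.importorskip("cupy")

def test_fit_torch_then_predict_cupy():
    # Placeholder: fit an Array-API-enabled estimator with torch tensors,
    # then call predict with cupy arrays and assert the results agree.
    ...
```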

> 2. maybe we should recommend using set_config() instead of the context manager in our docs. (...)

It's probably a good idea. Let's open a PR with your suggested changes to see if it simplifies the message as expected. We can always mention at the end of the page that it's possible to scope the effect using the config_context manager as an alternative to set_config.
