RFC/API (Array API) mixing devices and data types with estimators #26083
The current state of thinking regarding estimator attributes and their conversion is in the documentation: https://scikit-learn.org/dev/modules/array_api.html. TL;DR: the user has to convert them explicitly. One reason not to offer something more "built in" is that conversion from array library X to NumPy (and the other direction) depends on the array library; for CuPy to NumPy you need a CuPy-specific call.

This means that if you want to fit with one array library and predict with another, you will have to convert your model's attributes. You'd have to try out what happens with your array library of choice regarding "fit with data on a GPU" and "predict with data on a CPU". For example, PyTorch will raise an exception if you attempt to mix and match GPU and CPU (I haven't tested what happens if you mix and match GPU1 and GPU2).

I think it makes sense to ask for explicit action from the user because of the performance cost of moving data from one device to the next. The chances are high that a mistake happened if you fit a model on a GPU and then predict on a CPU without being explicit about it. For example, I can imagine that a frequent case would be fitting on a GPU but then, in production, predicting on single samples with a CPU. Swapping from development/fitting to production is such a big step, requiring lots of explicit "we are now in production!" signals, that I think we'd be in good company by also making things explicit. Conversion when persisting a model or when loading it sounds like a good feature for a library that makes it easy to share your models :D

For the plugin architecture we are going with similar thinking. Plugins should provide a function that lets the user convert a fitted estimator to a fitted numpy-based estimator. Similarly, if you fit with plugin X, then plugin X will be used to predict. Plugins can accept input of any type if they want to, so a GPU plugin can accept numpy arrays: you can fit with CuPy arrays on your GPU and then call `predict` with numpy arrays.
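A minimal sketch of what such an explicit conversion helper could look like (hypothetical name and logic; scikit-learn does not ship a public version of this, and the duck-typed branches only illustrate the library-specific part):

```python
import numpy as np


def estimator_to_numpy(estimator):
    """Hypothetical helper: convert the fitted array attributes of an
    estimator (public names ending in a single underscore) to numpy
    arrays, in place.

    The per-library conversion is the part that cannot be generic:
    CuPy and PyTorch each need their own idiom, hence the branches.
    """
    def to_numpy(value):
        if hasattr(value, "__cuda_array_interface__"):  # e.g. a CuPy array
            return value.get()
        if hasattr(value, "detach"):                    # e.g. a torch tensor
            return value.detach().cpu().numpy()
        return np.asarray(value)

    for name, value in vars(estimator).items():
        is_fitted_attr = name.endswith("_") and not name.startswith("_")
        looks_like_array = hasattr(value, "dtype") and hasattr(value, "shape")
        if is_fitted_attr and looks_like_array:
            setattr(estimator, name, to_numpy(value))
    return estimator
```

Non-array fitted attributes (e.g. an integer `n_features_in_`) are left untouched by the duck-typing check.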
Thinking about this a bit more, I think an estimator should raise an explicit exception when it was fitted with, say, a torch tensor and the user is now trying to predict with a CuPy array. Right now, if you do that, the error you get depends on the libraries involved. For example, a fitted estimator could check the array namespace of its fitted attributes and compare it to that of the input.
We discussed this a bit when designing a generic tool to convert everything to numpy.
I think it's safe to do so. If we change our minds in the future and decide to implement some kind of automated conversion, we can always do so without breaking user code, while it's not possible to do it the other way around.
Note that since that conversion tool is currently private, users cannot rely on it yet. In the future, we might want to expose a public version of it.
Note that it might be possible to write a generic device-aware array converter for all the array namespace implementations that support the DLPack protocol.
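For libraries that implement DLPack, the core of such a converter could be as small as this (a sketch; real code would also need to handle the cross-device copy, which plain DLPack interchange does not do by itself):

```python
import numpy as np


def convert(x, target_xp):
    """Move an array into the target namespace via the DLPack protocol.

    Works between any two libraries whose arrays expose ``__dlpack__``
    and whose namespaces expose ``from_dlpack`` (numpy, torch, cupy,
    ...), as long as the data is reachable from the consumer's device.
    """
    return target_xp.from_dlpack(x)
```

For example, `convert(torch_tensor, np)` and `convert(np_array, torch)` would both go through the same code path.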
I had two more thoughts about this, and it would be interesting to hear what y'all think.
Let's try. I assume you have a CUDA machine at hand. If it works, let's add a non-regression test (with the appropriate skip decorators for machines without a GPU).
It's probably a good idea. Let's open a PR with your suggested changes to see if it simplifies the message as expected. We can always mention that it's possible to scope the effect using the `config_context` context manager.
Right now, if the user fits an estimator using a `pandas.DataFrame` but passes a `numpy.ndarray` during `predict`, they get a warning due to missing feature names. The situation is only going to get more complicated as we add support for more types via the array API.
Some related issues here are:

- the data used during `fit` sits on a GPU, but a CPU is used for `predict` (with the same data type)
- the user passes one type during `fit` and another type during `predict`: how do we handle this, both in terms of device and of data type? Do we let the operator figure out whether they can coerce the data into a type that can be used?

I vaguely remember us talking about some of these issues, but I don't see any active discussion. I might have missed something.
Related: in a library like `pytorch`, you can decide which device is going to be used when you load a model's weights.

cc @thomasjpfan
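For comparison, PyTorch's loading API makes the target device an explicit argument at load time (guarded here since torch may not be installed):

```python
import io

try:
    import torch

    # Save a tensor to an in-memory buffer to stand in for a
    # checkpoint file on disk.
    buf = io.BytesIO()
    torch.save(torch.ones(3), buf)
    buf.seek(0)

    # map_location forces the weights onto a chosen device at load
    # time, regardless of where they lived when they were saved.
    weights = torch.load(buf, map_location="cpu")
    assert weights.device.type == "cpu"
except ImportError:
    torch = None  # torch not available; nothing to demonstrate
```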