Add `_asarray_fn` override to `check_array` by fcharras · Pull Request #25434 · scikit-learn/scikit-learn · GitHub

Add _asarray_fn override to check_array #25434


Open. fcharras wants to merge 1 commit into main.

Conversation

fcharras
Contributor

Closes #25433

Member
@glemaitre glemaitre left a comment


In a future where we would like estimators to be compatible with the Array API, we will need to replace all np.asarray calls with our helper _asarray_with_order, I assume.

What is not clear to me is how one would provide the asarray_fn in this case. Would it not be better to have a backend registered somewhere and pick up the asarray_fn it provides, if one exists?

I am wondering if adding the keyword to check_array would become obsolete in the future.

I am obviously missing some context because I did not review any of the plugin PRs.
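
To make the "registered backend" alternative concrete, here is a minimal sketch of what such a lookup could look like. Everything in it is hypothetical and not part of scikit-learn: the registry dict, `register_asarray`, `get_asarray`, and the `"demo_plugin"` name are assumptions made only for illustration, and a plain NumPy conversion stands in for a real device-array constructor.

```python
import numpy as np

# Hypothetical registry: maps a backend name to its asarray-like callable.
_ASARRAY_REGISTRY = {}


def register_asarray(backend_name, asarray_fn):
    """Register a conversion callable for a (hypothetical) backend."""
    _ASARRAY_REGISTRY[backend_name] = asarray_fn


def get_asarray(backend_name=None):
    """Return the registered callable, falling back to np.asarray."""
    return _ASARRAY_REGISTRY.get(backend_name, np.asarray)


def as_float32(array, dtype=None):
    # Stand-in for a plugin conversion; a real plugin might return a
    # cupy or dpctl array here instead of a NumPy array.
    return np.asarray(array, dtype=np.float32)


register_asarray("demo_plugin", as_float32)

X = [[1, 2], [3, 4]]
default_result = get_asarray()(X)  # no backend registered: np.asarray
plugin_result = get_asarray("demo_plugin")(X)  # picked up from the registry
```

With a lookup like this, validation code would not need an extra keyword argument, at the cost of introducing global state that has to be kept in sync with whichever plugin is active.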

@@ -728,6 +729,17 @@ def check_array(

.. versionadded:: 1.1.0

_asarray_fn : callable or None, default=None
Member


It is a bit weird to add a "private" parameter to a public function.

If not None, this callable will be used in place of calls to
`np.asarray` and `xp.asarray` (where `xp` can be any array namespace
implementing the Array API) when the data is converted to an array
object. Its signature must conform to the Array API specification for
Member


Suggested change
object. Its signature must conform to the Array API specification for
object. Its signature must conform to the Array API specification for

implement a superset of the Array API specifications and need some of
the extra arguments for input conversion (such as `order`).

.. versionadded:: 1.3.0
Member


Suggested change
.. versionadded:: 1.3.0
.. versionadded:: 1.3

def _asarray_with_order(array, dtype=None, order=None, copy=None, xp=None):
"""Helper to support the order kwarg only for NumPy-backed arrays
def _asarray_with_order(
array, dtype=None, order=None, copy=None, _asarray_fn=None, xp=None
Member


Since we don't enforce keyword-only arguments, it would be safer to place it in the last position, even if this is a private function.
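
The risk described here can be shown with a small, self-contained sketch. The three functions below are simplified stand-ins written for this illustration (they only echo their arguments and are not the real helper); they mirror the old signature, the signature proposed in the diff, and the variant with the new parameter placed last.

```python
def helper_old(array, dtype=None, order=None, copy=None, xp=None):
    # Original parameter order: xp is the fifth positional argument.
    return {"copy": copy, "xp": xp}


def helper_new_middle(
    array, dtype=None, order=None, copy=None, _asarray_fn=None, xp=None
):
    # New parameter inserted before xp, as in the diff above.
    return {"copy": copy, "_asarray_fn": _asarray_fn, "xp": xp}


def helper_new_last(
    array, dtype=None, order=None, copy=None, xp=None, _asarray_fn=None
):
    # New parameter appended at the end, as suggested in the review.
    return {"copy": copy, "_asarray_fn": _asarray_fn, "xp": xp}


# An existing caller that passes xp positionally as the fifth argument.
args = ([1, 2], None, None, None, "some_namespace")

old = helper_old(*args)
middle = helper_new_middle(*args)  # xp silently lands in _asarray_fn
last = helper_new_last(*args)  # xp still lands in xp
```

Appending the parameter keeps existing positional calls working; inserting it in the middle silently rebinds the fifth argument. Adding a bare `*` to the signature to make the trailing parameters keyword-only would be another way to rule out this class of mistake.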

@betatim
Member
betatim commented Jan 19, 2023

What is not clear to me is how one would provide the asarray_fn in this case. Would it not be better to have a backend registered somewhere and pick up the asarray_fn it provides, if one exists?

I am wondering if adding the keyword to check_array would become obsolete in the future.

I am obviously missing some context because I did not review any of the plugin PRs.

I'll try to add a bit of context.

As a plugin, you get passed the input the user provided when they call fit(X, y); the same goes for transform or predict. Take this input X: in scikit-learn, check_array() does a lot of work converting, sanity checking, and validating that user input. It would be nice if plugins could reuse check_array instead of having to maintain a copy of it.

As part of its work, check_array can call (or always calls) asarray. So when check_array is called from the plugin code, it would be nice to be able to pass in an asarray that directly creates a cupy, dpctl, etc. array.

I need to think about the "plugin registers an asarray that gets picked up" idea.
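
As a rough illustration of the use case described above, the sketch below shows a validation helper that accepts a conversion callable. It is a toy stand-in, not scikit-learn's `check_array`: the names `validate_input`, `asarray_fn`, `DeviceLikeArray`, and `plugin_asarray` are hypothetical, and an `ndarray` subclass plays the role of a cupy or dpctl array so the example runs with NumPy alone.

```python
import numpy as np


def validate_input(X, dtype=None, asarray_fn=None):
    """Toy stand-in for a check_array-like helper (hypothetical)."""
    # Use the caller-provided conversion if given, else plain NumPy.
    convert = np.asarray if asarray_fn is None else asarray_fn
    X = convert(X, dtype=dtype)
    # The shared validation logic runs on whatever array type came back.
    if X.ndim != 2:
        raise ValueError(f"Expected 2D input, got {X.ndim}D instead.")
    if X.shape[0] == 0:
        raise ValueError("Found array with 0 samples.")
    return X


class DeviceLikeArray(np.ndarray):
    """Stand-in for a device array type such as a cupy or dpctl array."""


def plugin_asarray(X, dtype=None):
    # A real plugin would call its own device-array constructor here.
    return np.asarray(X, dtype=dtype).view(DeviceLikeArray)


X_user = [[1.0, 2.0], [3.0, 4.0]]
X_default = validate_input(X_user)
X_plugin = validate_input(X_user, asarray_fn=plugin_asarray)
```

The point of the design is that the plugin reuses the validation checks unchanged and only swaps the conversion step, so the input is created directly as the plugin's array type rather than converted a second time afterwards.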


Successfully merging this pull request may close these issues.

Open up check_array and BaseEstimator._validate_data to overriding xp.asarray with an additional callable parameter asarray_fn
3 participants