Open up `check_array` and `BaseEstimator._validate_data` to overriding `xp.asarray` with an additional callable parameter `asarray_fn` · Issue #25433 · scikit-learn/scikit-learn
@fcharras

Description

Describe the workflow you want to enable

Several of us (including @betatim, @ogrisel, @jjerphan and me) have been designing a plugin system that would open up sklearn estimators to external implementations, in particular implementations with GPU backends (see #22438).

Some of the plugins we're considering can materialize the data in memory with an array library that is compatible with the Array API, namely CuPy and `dpctl.tensor`.

One thing we've found is that, internally, those plugins can benefit from directly using `BaseEstimator._validate_data` and `check_array` from scikit-learn for the input validation and data preparation step.

Describe your proposed solution

To enable this, it would be nice to be able to pass an `asarray_fn` callable to `check_array` and `_validate_data`, which would be called instead of `xp.asarray` in `_asarray_with_order`. This would let the plugin convert the input data directly to an array type that it supports (e.g. `cupy` or `dpctl.tensor`) while still reusing the existing validation code in `check_array`.
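A minimal sketch of how such a parameter could behave, under stated assumptions: `asarray_with_order_sketch` is a simplified stand-in and not scikit-learn's actual `_asarray_with_order`, the `asarray_fn` parameter is the hypothetical one proposed here, and NumPy stands in for the plugin's array library so the example runs without a GPU.

```python
import numpy as np


def asarray_with_order_sketch(array, dtype=None, order=None, xp=np, asarray_fn=None):
    """Simplified stand-in for the conversion step (not scikit-learn's code).

    `asarray_fn` is the hypothetical override proposed in this issue.
    """
    if asarray_fn is not None:
        # The plugin takes over the conversion entirely.
        return asarray_fn(array, dtype=dtype, order=order)
    if xp is np:
        # NumPy understands `order`, so it can be forwarded.
        return np.asarray(array, dtype=dtype, order=order)
    # Other namespaces: `order` is not part of the Array API, so it is dropped.
    return xp.asarray(array, dtype=dtype)


def plugin_asarray(array, dtype=None, order=None):
    # Hypothetical plugin converter. A real plugin would call its own library
    # here (e.g. cupy or dpctl.tensor); NumPy is used only as a placeholder.
    return np.asarray(array, dtype=dtype, order=order)


X = asarray_with_order_sketch(
    [[1, 2], [3, 4]], dtype=np.float32, asarray_fn=plugin_asarray
)
print(type(X).__name__, X.dtype, X.shape)
```

With this shape, the validation code never needs to know which array library the plugin targets; it only calls whatever converter it was handed.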

The override can be necessary when the `asarray` function of the array library implements a superset of the Array API, and the plugin needs part of that superset that `check_array` currently does not use because it is not part of the Array API. For instance, the `order` argument isn't passed to `asarray` for array libraries other than NumPy.
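To illustrate the `order` case, here is a small sketch comparing a converter that only uses Array API arguments with one that forwards `order`. Both function names are hypothetical, and NumPy is used as a stand-in for a library whose `asarray` is assumed to accept `order` beyond what the Array API specifies.

```python
import numpy as np


def array_api_style_asarray(array, dtype=None, order=None):
    # Mimics the generic code path: `order` is accepted but not forwarded,
    # because it is not part of the Array API `asarray` signature.
    return np.asarray(array, dtype=dtype)


def order_aware_asarray(array, dtype=None, order=None):
    # Mimics a plugin-provided override that forwards `order` to a library
    # supporting it (NumPy here, as a placeholder for the plugin's library).
    return np.asarray(array, dtype=dtype, order=order)


data = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

dropped = array_api_style_asarray(data, order="F")
kept = order_aware_asarray(data, order="F")

# The requested Fortran layout only survives when `order` is forwarded.
print(dropped.flags["F_CONTIGUOUS"], kept.flags["F_CONTIGUOUS"])
```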
