ENH Adds Array API support to LinearDiscriminantAnalysis (#22554) · thomasjpfan/scikit-learn@2710a9e · GitHub
[go: up one dir, main page]

Skip to content

Commit 2710a9e

Browse files
thomasjpfan, ogrisel, jjerphan
authored
ENH Adds Array API support to LinearDiscriminantAnalysis (scikit-learn#22554)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 0c79efd commit 2710a9e

File tree

15 files changed

+779
-95
lines changed

15 files changed

+779
-95
lines changed

doc/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,14 @@ def skip_if_matplotlib_not_installed(fname):
107107
raise SkipTest(f"Skipping doctests for {basename}, matplotlib not installed")
108108

109109

110+
def skip_if_cupy_not_installed(fname):
111+
try:
112+
import cupy # noqa
113+
except ImportError:
114+
basename = os.path.basename(fname)
115+
raise SkipTest(f"Skipping doctests for {basename}, cupy not installed")
116+
117+
110118
def pytest_runtest_setup(item):
111119
fname = item.fspath.strpath
112120
# normalize filename to use forward slashes on Windows for easier handling
@@ -147,6 +155,9 @@ def pytest_runtest_setup(item):
147155
if fname.endswith(each):
148156
skip_if_matplotlib_not_installed(fname)
149157

158+
if fname.endswith("array_api.rst"):
159+
skip_if_cupy_not_installed(fname)
160+
150161

151162
def pytest_configure(config):
152163
# Use matplotlib agg backend during the tests including doctests

doc/dispatching.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
.. Places parent toc into the sidebar
2+
3+
:parenttoc: True
4+
5+
.. include:: includes/big_toc_css.rst
6+
7+
===========
8+
Dispatching
9+
===========
10+
11+
.. toctree::
12+
:maxdepth: 2
13+
14+
modules/array_api

doc/modules/array_api.rst

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
.. Places parent toc into the sidebar
2+
3+
:parenttoc: True
4+
5+
.. _array_api:
6+
7+
================================
8+
Array API support (experimental)
9+
================================
10+
11+
.. currentmodule:: sklearn
12+
13+
The `Array API <https://data-apis.org/array-api/latest/>`_ specification defines
14+
a standard API for all array manipulation libraries with a NumPy-like API.
15+
16+
Some scikit-learn estimators that primarily rely on NumPy (as opposed to using
17+
Cython) to implement the algorithmic logic of their `fit`, `predict` or
18+
`transform` methods can be configured to accept any Array API compatible input
19+
data structures and automatically dispatch operations to the underlying namespace
20+
instead of relying on NumPy.
21+
22+
At this stage, this support is **considered experimental** and must be enabled
23+
explicitly as explained in the following.
24+
25+
.. note::
26+
Currently, only `cupy.array_api` and `numpy.array_api` are known to work
27+
with scikit-learn's estimators.
28+
29+
Example usage
30+
=============
31+
32+
Here is an example code snippet to demonstrate how to use `CuPy
33+
<https://cupy.dev/>`_ to run
34+
:class:`~discriminant_analysis.LinearDiscriminantAnalysis` on a GPU::
35+
36+
>>> from sklearn.datasets import make_classification
37+
>>> from sklearn import config_context
38+
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
39+
>>> import cupy.array_api as xp
40+
41+
>>> X_np, y_np = make_classification(random_state=0)
42+
>>> X_cu = xp.asarray(X_np)
43+
>>> y_cu = xp.asarray(y_np)
44+
>>> X_cu.device
45+
<CUDA Device 0>
46+
47+
>>> with config_context(array_api_dispatch=True):
48+
... lda = LinearDiscriminantAnalysis()
49+
... X_trans = lda.fit_transform(X_cu, y_cu)
50+
>>> X_trans.device
51+
<CUDA Device 0>
52+
53+
After the model is trained, fitted attributes that are arrays will also be
54+
from the same Array API namespace as the training data. For example, if CuPy's
55+
Array API namespace was used for training, then fitted attributes will be on the
56+
GPU. We provide an experimental `_estimator_with_converted_arrays` utility that
57+
transfers an estimator's attributes from Array API to an ndarray::
58+
59+
>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
60+
>>> cupy_to_ndarray = lambda array : array._array.get()
61+
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
62+
>>> X_trans = lda_np.transform(X_np)
63+
>>> type(X_trans)
64+
<class 'numpy.ndarray'>
65+
66+
.. _array_api_estimators:
67+
68+
Estimators with support for `Array API`-compatible inputs
69+
=========================================================
70+
71+
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
72+
73+
Coverage for more estimators is expected to grow over time. Please follow the
74+
dedicated `meta-issue on GitHub
75+
<https://github.com/scikit-learn/scikit-learn/issues/22352>`_ to track progress.

doc/user_guide.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,4 @@ User Guide
3030
computing.rst
3131
model_persistence.rst
3232
common_pitfalls.rst
33+
dispatching.rst

doc/whats_new/v1.2.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,12 @@ Changelog
223223
:mod:`sklearn.discriminant_analysis`
224224
....................................
225225

226+
- |MajorFeature| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now
227+
supports the `Array API <https://data-apis.org/array-api/latest/>`_ for
228+
`solver="svd"`. Array API support is considered experimental and might evolve
229+
without being subjected to our usual rolling deprecation cycle policy. See
230+
:ref:`array_api` for more details. :pr:`22554` by `Thomas Fan`_.
231+
226232
- |Fix| Validate parameters only in `fit` and not in `__init__`
227233
for :class:`discriminant_analysis.QuadraticDiscriminantAnalysis`.
228234
:pr:`24218` by :user:`Stefanie Molin <stefmolin>`.

sklearn/_config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
1414
),
1515
"enable_cython_pairwise_dist": True,
16+
"array_api_dispatch": False,
1617
}
1718
_threadlocal = threading.local()
1819

@@ -50,6 +51,7 @@ def set_config(
5051
display=None,
5152
pairwise_dist_chunk_size=None,
5253
enable_cython_pairwise_dist=None,
54+
array_api_dispatch=None,
5355
):
5456
"""Set global scikit-learn configuration
5557
@@ -110,6 +112,14 @@ def set_config(
110112
111113
.. versionadded:: 1.1
112114
115+
array_api_dispatch : bool, default=None
116+
Use Array API dispatching when inputs follow the Array API standard.
117+
Default is False.
118+
119+
See the :ref:`User Guide <array_api>` for more details.
120+
121+
.. versionadded:: 1.2
122+
113123
See Also
114124
--------
115125
config_context : Context manager for global scikit-learn configuration.
@@ -129,6 +139,8 @@ def set_config(
129139
local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size
130140
if enable_cython_pairwise_dist is not None:
131141
local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
142+
if array_api_dispatch is not None:
143+
local_config["array_api_dispatch"] = array_api_dispatch
132144

133145

134146
@contextmanager
@@ -140,6 +152,7 @@ def config_context(
140152
display=None,
141153
pairwise_dist_chunk_size=None,
142154
enable_cython_pairwise_dist=None,
155+
array_api_dispatch=None,
143156
):
144157
"""Context manager for global scikit-learn configuration.
145158
@@ -199,6 +212,14 @@ def config_context(
199212
200213
.. versionadded:: 1.1
201214
215+
array_api_dispatch : bool, default=None
216+
Use Array API dispatching when inputs follow the Array API standard.
217+
Default is False.
218+
219+
See the :ref:`User Guide <array_api>` for more details.
220+
221+
.. versionadded:: 1.2
222+
202223
Yields
203224
------
204225
None.
@@ -234,6 +255,7 @@ def config_context(
234255
display=display,
235256
pairwise_dist_chunk_size=pairwise_dist_chunk_size,
236257
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
258+
array_api_dispatch=array_api_dispatch,
237259
)
238260

239261
try:

0 commit comments

Comments
 (0)
0