ENH Adds PyTorch support to LinearDiscriminantAnalysis (#25956) · scikit-learn/scikit-learn@66a3905 · GitHub
[go: up one dir, main page]

Skip to content

Commit 66a3905

Browse files
thomasjpfan, ogrisel, betatim
authored
ENH Adds PyTorch support to LinearDiscriminantAnalysis (#25956)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Tim Head <betatim@gmail.com>
1 parent a187758 commit 66a3905

17 files changed

+457
-162
lines changed

build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock

Lines changed: 35 additions & 40 deletions
Large diffs are not rendered by default.

build_tools/azure/pylatest_conda_forge_mkl_linux-64_environment.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ dependencies:
2020
- pytest-cov
2121
- coverage
2222
- ccache
23+
- pytorch=1.13
24+
- pytorch-cpu
25+
- array-api-compat

build_tools/update_environments_and_lock_files.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,15 @@ def remove_from(alist, to_remove):
8888
"folder": "build_tools/azure",
8989
"platform": "linux-64",
9090
"channel": "conda-forge",
91-
"conda_dependencies": common_dependencies + ["ccache"],
91+
"conda_dependencies": common_dependencies + [
92+
"ccache",
93+
"pytorch",
94+
"pytorch-cpu",
95+
"array-api-compat",
96+
],
9297
"package_constraints": {
9398
"blas": "[build=mkl]",
99+
"pytorch": "1.13",
94100
},
95101
},
96102
{

doc/modules/array_api.rst

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ Array API support (experimental)
1212

1313
The `Array API <https://data-apis.org/array-api/latest/>`_ specification defines
1414
a standard API for all array manipulation libraries with a NumPy-like API.
15+
Scikit-learn's Array API support requires
16+
`array-api-compat <https://github.com/data-apis/array-api-compat>`__ to be installed.
1517

1618
Some scikit-learn estimators that primarily rely on NumPy (as opposed to using
1719
Cython) to implement the algorithmic logic of their `fit`, `predict` or
@@ -23,8 +25,8 @@ At this stage, this support is **considered experimental** and must be enabled
2325
explicitly as explained in the following.
2426

2527
.. note::
26-
Currently, only `cupy.array_api` and `numpy.array_api` are known to work
27-
with scikit-learn's estimators.
28+
Currently, only `cupy.array_api`, `numpy.array_api`, `cupy`, and `PyTorch`
29+
are known to work with scikit-learn's estimators.
2830

2931
Example usage
3032
=============
@@ -36,11 +38,11 @@ Here is an example code snippet to demonstrate how to use `CuPy
3638
>>> from sklearn.datasets import make_classification
3739
>>> from sklearn import config_context
3840
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
39-
>>> import cupy.array_api as xp
41+
>>> import cupy
4042

4143
>>> X_np, y_np = make_classification(random_state=0)
42-
>>> X_cu = xp.asarray(X_np)
43-
>>> y_cu = xp.asarray(y_np)
44+
>>> X_cu = cupy.asarray(X_np)
45+
>>> y_cu = cupy.asarray(y_np)
4446
>>> X_cu.device
4547
<CUDA Device 0>
4648

@@ -57,12 +59,30 @@ GPU. We provide a experimental `_estimator_with_converted_arrays` utility that
5759
transfers an estimator's attributes from Array API to an ndarray::
5860

5961
>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
60-
>>> cupy_to_ndarray = lambda array : array._array.get()
62+
>>> cupy_to_ndarray = lambda array : array.get()
6163
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
6264
>>> X_trans = lda_np.transform(X_np)
6365
>>> type(X_trans)
6466
<class 'numpy.ndarray'>
6567

68+
PyTorch Support
69+
---------------
70+
71+
PyTorch Tensors are supported by setting `array_api_dispatch=True` and passing in
72+
the tensors directly::
73+
74+
>>> import torch
75+
>>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)
76+
>>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32)
77+
78+
>>> with config_context(array_api_dispatch=True):
79+
... lda = LinearDiscriminantAnalysis()
80+
... X_trans = lda.fit_transform(X_torch, y_torch)
81+
>>> type(X_trans)
82+
<class 'torch.Tensor'>
83+
>>> X_trans.device.type
84+
'cuda'
85+
6686
.. _array_api_estimators:
6787

6888
Estimators with support for `Array API`-compatible inputs

doc/whats_new/v1.3.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,13 @@ Changelog
230230
:class:`decomposition.MiniBatchNMF` which can produce different results than previous
231231
versions. :pr:`25438` by :user:`Yotam Avidar-Constantini <yotamcons>`.
232232

233+
:mod:`sklearn.discriminant_analysis`
234+
....................................
235+
236+
- |Enhancement| :class:`discriminant_analysis.LinearDiscriminantAnalysis` now
237+
supports `PyTorch <https://pytorch.org/>`__ tensors. See
238+
:ref:`array_api` for more details. :pr:`25956` by `Thomas Fan`_.
239+
233240
:mod:`sklearn.ensemble`
234241
.......................
235242

sklearn/_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ def set_config(
154154
if enable_cython_pairwise_dist is not None:
155155
local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
156156
if array_api_dispatch is not None:
157+
from .utils._array_api import _check_array_api_dispatch
158+
159+
_check_array_api_dispatch(array_api_dispatch)
157160
local_config["array_api_dispatch"] = array_api_dispatch
158161
if transform_output is not None:
159162
local_config["transform_output"] = transform_output

sklearn/discriminant_analysis.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from .covariance import ledoit_wolf, empirical_covariance, shrunk_covariance
2222
from .utils.multiclass import unique_labels
2323
from .utils.validation import check_is_fitted
24-
from .utils._array_api import get_namespace, _expit
24+
from .utils._array_api import get_namespace, _expit, device, size
2525
from .utils.multiclass import check_classification_targets
2626
from .utils.extmath import softmax
2727
from .utils._param_validation import StrOptions, Interval, HasMethods
@@ -107,11 +107,11 @@ def _class_means(X, y):
107107
means : array-like of shape (n_classes, n_features)
108108
Class means.
109109
"""
110-
xp, is_array_api = get_namespace(X)
110+
xp, is_array_api_compliant = get_namespace(X)
111111
classes, y = xp.unique_inverse(y)
112-
means = xp.zeros(shape=(classes.shape[0], X.shape[1]))
112+
means = xp.zeros((classes.shape[0], X.shape[1]), device=device(X), dtype=X.dtype)
113113

114-
if is_array_api:
114+
if is_array_api_compliant:
115115
for i in range(classes.shape[0]):
116116
means[i, :] = xp.mean(X[y == i], axis=0)
117117
else:
@@ -483,9 +483,9 @@ def _solve_svd(self, X, y):
483483
y : array-like of shape (n_samples,) or (n_samples, n_targets)
484484
Target values.
485485
"""
486-
xp, is_array_api = get_namespace(X)
486+
xp, is_array_api_compliant = get_namespace(X)
487487

488-
if is_array_api:
488+
if is_array_api_compliant:
489489
svd = xp.linalg.svd
490490
else:
491491
svd = scipy.linalg.svd
@@ -586,9 +586,9 @@ def fit(self, X, y):
586586

587587
if self.priors is None: # estimate priors from sample
588588
_, cnts = xp.unique_counts(y) # non-negative ints
589-
self.priors_ = xp.astype(cnts, xp.float64) / float(y.shape[0])
589+
self.priors_ = xp.astype(cnts, X.dtype) / float(y.shape[0])
590590
else:
591-
self.priors_ = xp.asarray(self.priors)
591+
self.priors_ = xp.asarray(self.priors, dtype=X.dtype)
592592

593593
if xp.any(self.priors_ < 0):
594594
raise ValueError("priors must be non-negative")
@@ -634,7 +634,7 @@ def fit(self, X, y):
634634
shrinkage=self.shrinkage,
635635
covariance_estimator=self.covariance_estimator,
636636
)
637-
if self.classes_.size == 2: # treat binary case as a special case
637+
if size(self.classes_) == 2: # treat binary case as a special case
638638
coef_ = xp.asarray(self.coef_[1, :] - self.coef_[0, :], dtype=X.dtype)
639639
self.coef_ = xp.reshape(coef_, (1, -1))
640640
intercept_ = xp.asarray(
@@ -688,9 +688,9 @@ def predict_proba(self, X):
688688
Estimated probabilities.
689689
"""
690690
check_is_fitted(self)
691-
xp, is_array_api = get_namespace(X)
691+
xp, is_array_api_compliant = get_namespace(X)
692692
decision = self.decision_function(X)
693-
if self.classes_.size == 2:
693+
if size(self.classes_) == 2:
694694
proba = _expit(decision)
695695
return xp.stack([1 - proba, proba], axis=1)
696696
else:

sklearn/linear_model/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def predict(self, X):
449449
else:
450450
indices = xp.argmax(scores, axis=1)
451451

452-
return xp.take(self.classes_, indices, axis=0)
452+
return xp.take(self.classes_, indices)
453453

454454
def _predict_proba_lr(self, X):
455455
"""Probability estimation for OvR logistic regression.

sklearn/tests/test_config.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import builtins
12
import time
23
from concurrent.futures import ThreadPoolExecutor
34

45
import pytest
56

67
from sklearn import get_config, set_config, config_context
8+
import sklearn
79
from sklearn.utils.parallel import delayed, Parallel
810

911

@@ -145,3 +147,45 @@ def test_config_threadsafe():
145147
]
146148

147149
assert items == [False, True, False, True]
150+
151+
152+
def test_config_array_api_dispatch_error(monkeypatch):
153+
"""Check error is raised when array_api_compat is not installed."""
154+
155+
# Hide array_api_compat import
156+
orig_import = builtins.__import__
157+
158+
def mocked_import(name, *args, **kwargs):
159+
if name == "array_api_compat":
160+
raise ImportError
161+
return orig_import(name, *args, **kwargs)
162+
163+
monkeypatch.setattr(builtins, "__import__", mocked_import)
164+
165+
with pytest.raises(ImportError, match="array_api_compat is required"):
166+
with config_context(array_api_dispatch=True):
167+
pass
168+
169+
with pytest.raises(ImportError, match="array_api_compat is required"):
170+
set_config(array_api_dispatch=True)
171+
172+
173+
def test_config_array_api_dispatch_error_numpy(monkeypatch):
174+
"""Check error when NumPy is too old"""
175+
# Pretend that array_api_compat is installed.
176+
orig_import = builtins.__import__
177+
178+
def mocked_import(name, *args, **kwargs):
179+
if name == "array_api_compat":
180+
return object()
181+
return orig_import(name, *args, **kwargs)
182+
183+
monkeypatch.setattr(builtins, "__import__", mocked_import)
184+
monkeypatch.setattr(sklearn.utils._array_api.numpy, "__version__", "1.20")
185+
186+
with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"):
187+
with config_context(array_api_dispatch=True):
188+
pass
189+
190+
with pytest.raises(ImportError, match="NumPy must be 1.21 or newer"):
191+
set_config(array_api_dispatch=True)

sklearn/tests/test_discriminant_analysis.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sklearn.utils._testing import assert_almost_equal
1414
from sklearn.utils._array_api import _convert_to_numpy
1515
from sklearn.utils._testing import _convert_container
16+
from sklearn.utils._testing import skip_if_array_api_compat_not_configured
1617

1718
from sklearn.datasets import make_blobs
1819
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
@@ -676,6 +677,7 @@ def test_get_feature_names_out():
676677
assert_array_equal(names_out, expected_names_out)
677678

678679

680+
@skip_if_array_api_compat_not_configured
679681
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
680682
def test_lda_array_api(array_namespace):
681683
"""Check that the array_api Array gives the same results as ndarrays."""
@@ -725,6 +727,66 @@ def test_lda_array_api(array_namespace):
725727

726728
result_xp_np = _convert_to_numpy(result_xp, xp=xp)
727729

730+
assert_allclose(
731+
result,
732+
result_xp_np,
733+
err_msg=f"{method} did not the return the same result",
734+
atol=1e-5,
735+
)
736+
737+
738+
@skip_if_array_api_compat_not_configured
739+
@pytest.mark.parametrize("device", ["cuda", "cpu"])
740+
@pytest.mark.parametrize("dtype", ["float32", "float64"])
741+
def test_lda_array_torch(device, dtype):
742+
"""Check running on PyTorch Tensors gives the same results as NumPy"""
743+
torch = pytest.importorskip("torch")
744+
if device == "cuda" and not torch.has_cuda:
745+
pytest.skip("test requires cuda")
746+
747+
lda = LinearDiscriminantAnalysis()
748+
X_np = X6.astype(dtype)
749+
y_np = y6.astype(dtype)
750+
lda.fit(X_np, y_np)
751+
752+
X_torch = torch.asarray(X_np, device=device)
753+
y_torch = torch.asarray(y_np, device=device)
754+
lda_xp = clone(lda)
755+
with config_context(array_api_dispatch=True):
756+
lda_xp.fit(X_torch, y_torch)
757+
758+
array_attributes = {
759+
key: value for key, value in vars(lda).items() if isinstance(value, np.ndarray)
760+
}
761+
762+
for key, attribute in array_attributes.items():
763+
lda_xp_param = getattr(lda_xp, key)
764+
assert isinstance(lda_xp_param, torch.Tensor)
765+
assert lda_xp_param.device.type == device
766+
767+
lda_xp_param_np = _convert_to_numpy(lda_xp_param, xp=torch)
768+
assert_allclose(
769+
attribute, lda_xp_param_np, err_msg=f"{key} not the same", atol=1e-3
770+
)
771+
772+
# Check predictions are the same
773+
methods = (
774+
"decision_function",
775+
"predict",
776+
"predict_log_proba",
777+
"predict_proba",
778+
"transform",
779+
)
780+
for method in methods:
781+
result = getattr(lda, method)(X_np)
782+
with config_context(array_api_dispatch=True):
783+
result_xp = getattr(lda_xp, method)(X_torch)
784+
785+
assert isinstance(result_xp, torch.Tensor)
786+
assert result_xp.device.type == device
787+
788+
result_xp_np = _convert_to_numpy(result_xp, xp=torch)
789+
728790
assert_allclose(
729791
result,
730792
result_xp_np,

0 commit comments

Comments
 (0)
0