ENH Array API support for PCA by mtsokol · Pull Request #26315 · scikit-learn/scikit-learn · GitHub

ENH Array API support for PCA #26315

Merged · 79 commits · Jul 13, 2023

Commits (79)
dd4c9fc
ENH Adds PyTorch support for PCA
mtsokol May 1, 2023
ceb10e3
ENH Support get_precision and get_covariance
mtsokol May 3, 2023
0b8592c
Merge branch 'main' into feature/array_api_compat_pca
mtsokol Jun 16, 2023
2ae83c0
ENH apply review comments
mtsokol Jun 16, 2023
1b4a7cd
Fix multi-fancy indexing by using xp.take on flattened arrays in svd…
ogrisel Jun 16, 2023
cf86c45
Unit test for svd_flip
ogrisel Jun 18, 2023
c84e4ef
Fix stride-related logic in call to xp.take with 1d args
ogrisel Jun 18, 2023
3194b7e
Delete dead code
ogrisel Jun 18, 2023
f1546c4
Run PCA array API tests manually
ogrisel Jun 19, 2023
fd8a217
Do not check concrete values in Array API common test by default
ogrisel Jun 19, 2023
51c9596
Move namespace+parameter generation to _array_api
betatim Jun 19, 2023
8f67f46
Rename function
betatim Jun 19, 2023
e1b7230
Test .score and .score_samples
ogrisel Jun 19, 2023
9c90da9
Progress on fixing .score
ogrisel Jun 20, 2023
697618e
Make accuracy_score and score return float explicitly
ogrisel Jun 20, 2023
5cb75eb
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 20, 2023
a7aba27
Fix docstests and test_check_array_api_input
ogrisel Jun 20, 2023
7b37bca
Fix one more doctest
ogrisel Jun 20, 2023
b457925
Cosmit
ogrisel Jun 20, 2023
76315eb
Fix one more doctest
ogrisel Jun 21, 2023
8291a02
Move and update changelog entry
ogrisel Jun 21, 2023
25e80f4
Update the Array API doc page to mention PCA
ogrisel Jun 21, 2023
73787fe
WIP array api for randomized_svd
ogrisel Jun 21, 2023
3fd48c1
Merge main
ogrisel Jun 22, 2023
f97e1d7
pytest parametrization to run custom Array API checks in test_pca
ogrisel Jun 22, 2023
9afc7dc
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 22, 2023
266e1a9
Pass estimator to check yielder
ogrisel Jun 22, 2023
e1272e2
Update sklearn/decomposition/_base.py
ogrisel Jun 22, 2023
0916c85
Fix check_array_api_input and update PCA accordingly
ogrisel Jun 25, 2023
41fbc6a
Compare namespace names
ogrisel Jun 25, 2023
b50542e
Keep on using scipy.linalg.svd in PCA by default
ogrisel Jun 25, 2023
51b348a
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 25, 2023
f3e6ebf
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 26, 2023
e044dc8
cosmetics
ogrisel Jun 26, 2023
e1c1474
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 27, 2023
6586822
Test errors when calling PCA with unsupported parameter values
ogrisel Jun 28, 2023
e232837
Protect array_api test against missing soft dependency
ogrisel Jun 28, 2023
764e246
More consistent use of scipy.linalg when array api is disabled
ogrisel Jun 28, 2023
a21bac4
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 28, 2023
efe75a0
Improve coverage and simplify code
ogrisel Jun 28, 2023
396752b
Comment about lazy evaluation
ogrisel Jun 29, 2023
3396e46
Simplify power iteration by using @ instead
ogrisel Jun 29, 2023
dfe0b55
Update sklearn/utils/_array_api.py
ogrisel Jun 29, 2023
713f88e
Clean-up left over
ogrisel Jun 29, 2023
3bf6e50
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 29, 2023
6920561
Fix randomized_range_finder docstring to reflect the latest version o…
ogrisel Jun 29, 2023
6ff1ab1
Better not use np.newaxis in Array API code
ogrisel Jun 29, 2023
60e307a
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jun 30, 2023
c182fbc
Preserve device info in svd_flip
ogrisel Jul 1, 2023
c72f586
Update array-api-compat version in build_tools/azure/pylatest_conda_f…
ogrisel Jul 3, 2023
dcea203
Fix xp.take in sklearn.linear_model by passing axis argument to fit a…
ogrisel Jul 3, 2023
5f38eff
Fix LinearDiscriminantAnalysis.score to work with cupy
ogrisel Jul 3, 2023
41164e9
Test get_covariance / get_precision
ogrisel Jul 3, 2023
8193209
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jul 3, 2023
42524bd
Typo
ogrisel Jul 3, 2023
b45547d
Add comment to explain why we keep the scipy linalg.svd code path for…
ogrisel Jul 5, 2023
ee7ab50
Use array_api_compat.to_device
ogrisel Jul 5, 2023
1eb368d
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jul 5, 2023
acdefc0
Improve the to_device helper (with missing docstring)
ogrisel Jul 5, 2023
6c355e6
Use np.newaxis instead of None
ogrisel Jul 5, 2023
5a53dc6
Use _is_numpy_namespace / xp.asarray for all numpy backed xp values
ogrisel Jul 5, 2023
e1cf17d
Apply suggestions from code review
ogrisel Jul 5, 2023
1d7b802
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jul 5, 2023
4dd424b
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jul 6, 2023
2270b10
Revert no longer needed change to parametrize_with_checks
ogrisel Jul 6, 2023
841d4dc
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jul 6, 2023
20e0ecf
Leverage broadcasting to spare a temp allocation
ogrisel Jul 10, 2023
47091b8
Remove changelog merge typo
ogrisel Jul 10, 2023
cb5c03f
Keep on using scipy.linalg.inv in PCA.get_covariance by default
ogrisel Jul 10, 2023
4093a6c
Add note on combined dtype conversion and device move
ogrisel Jul 10, 2023
2216276
Spare one more temporary allocation.
ogrisel Jul 10, 2023
5032d1f
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jul 10, 2023
5abc2bf
Assume to_device is always called under an is_array_api_compliant con…
ogrisel Jul 10, 2023
e30bfa8
Extend the common tests to handle the case when array api is used wit…
ogrisel Jul 11, 2023
bb564bb
Extend the common tests to handle the case when array api is used wit…
ogrisel Jul 11, 2023
3e28c50
Simplify condition to protect scipy.linalg.svd in randomized_svd
ogrisel Jul 12, 2023
ea4fc2e
Merge branch 'main' into feature/array_api_compat_pca
ogrisel Jul 12, 2023
b881ff1
Use xp.asarray + device instead of a new to_device helper
ogrisel Jul 13, 2023
aa9a33a
Fix randomized_range_finder with sparse matrices
ogrisel Jul 13, 2023
79 changes: 40 additions & 39 deletions build_tools/azure/pylatest_conda_forge_mkl_linux-64_conda.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion doc/modules/array_api.rst
@@ -88,6 +88,8 @@ the tensors directly::
Estimators with support for `Array API`-compatible inputs
=========================================================

- :class:`decomposition.PCA` (with `svd_solver="full"`,
`svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)

Coverage for more estimators is expected to grow over time. Please follow the
@@ -107,4 +109,4 @@ To run these checks you need to install
test environment. To run the full set of checks you need to install both
`PyTorch <https://pytorch.org/>`_ and `CuPy <https://cupy.dev/>`_ and have
a GPU. Checks that can not be executed or have missing dependencies will be
automatically skipped.
automatically skipped.
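
For illustration, a minimal sketch of what the newly added support enables (assuming the `array-api-compat` soft dependency is installed and Array API dispatch is turned on via `sklearn.set_config`; only the solver combinations listed above are covered)::

    import sklearn
    import torch

    from sklearn.decomposition import PCA

    sklearn.set_config(array_api_dispatch=True)

    X = torch.rand(100, 10)
    # Only svd_solver="full" and svd_solver="randomized" (with
    # power_iteration_normalizer="QR") go through the Array API code path.
    pca = PCA(n_components=2, svd_solver="full").fit(X)

    # Fitted attributes stay in the input namespace and on the input device.
    print(type(pca.components_))  # <class 'torch.Tensor'>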
6 changes: 3 additions & 3 deletions doc/modules/model_evaluation.rst
@@ -451,7 +451,7 @@ where :math:`1(x)` is the `indicator function
>>> accuracy_score(y_true, y_pred)
0.5
>>> accuracy_score(y_true, y_pred, normalize=False)
2
2.0
Member comment:
This is impacted because we now call float(...) explicitly on the output of accuracy_score and zero_one_loss to return a Python scalar instead of a 0-dim numpy value whose dtype can vary in hard-to-predict ways depending on the inputs.

Member comment:
Note: the docstring of accuracy_score specifies that the return type is float, so this might be considered a bugfix. Not sure if we need a changelog entry for this.

Member comment:
I guess no one has noticed this difference till now because the 0-dim numpy array is mostly indistinguishable from a float?

Member comment:
Yes, unless it has an int dtype, as was the case in this particular code snippet prior to this PR :)
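
To make the behavior change concrete, a quick doctest-style check on toy data where two of the four predictions are correct (before this PR both calls returned 0-dim NumPy values with an integer dtype)::

    >>> from sklearn.metrics import accuracy_score, zero_one_loss
    >>> y_true, y_pred = [0, 1, 2, 3], [0, 2, 1, 3]
    >>> accuracy_score(y_true, y_pred, normalize=False)
    2.0
    >>> zero_one_loss(y_true, y_pred, normalize=False)
    2.0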


In the multilabel case with binary label indicators::

@@ -1696,7 +1696,7 @@ loss can also be computed as :math:`zero-one loss = 1 - accuracy`.
>>> zero_one_loss(y_true, y_pred)
0.25
>>> zero_one_loss(y_true, y_pred, normalize=False)
1
1.0

In the multilabel case with binary label indicators, where the first label
set [0,1] has an error::
@@ -1705,7 +1705,7 @@ set [0,1] has an error::
0.5

>>> zero_one_loss(np.array([[0, 1], [1, 1]]), np.ones((2, 2)), normalize=False)
1
1.0

.. topic:: Example:

7 changes: 7 additions & 0 deletions doc/whats_new/v1.4.rst
@@ -43,6 +43,7 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.


:mod:`sklearn.base`
...................

@@ -61,6 +62,12 @@ Changelog
from `None` to `auto` in version 1.6.
:pr:`26634` by :user:`Alexandre Landeau <AlexL>` and :user:`Alexandre Vigny <avigny>`.

- |Enhancement| :class:`decomposition.PCA` now supports the Array API for the
`full` and `randomized` solvers (with QR power iterations). See
:ref:`array_api` for more details.
:pr:`26315` by :user:`Mateusz Sokół <mtsokol>` and
:user:`Olivier Grisel <ogrisel>`.

:mod:`sklearn.ensemble`
.......................

63 changes: 40 additions & 23 deletions sklearn/decomposition/_base.py
@@ -14,6 +14,7 @@
from scipy import linalg

from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
from ..utils._array_api import _add_to_diagonal, get_namespace
from ..utils.validation import check_is_fitted


@@ -38,13 +39,18 @@ def get_covariance(self):
cov : array of shape=(n_features, n_features)
Estimated covariance of data.
"""
xp, _ = get_namespace(self.components_)

components_ = self.components_
exp_var = self.explained_variance_
if self.whiten:
components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
cov = np.dot(components_.T * exp_var_diff, components_)
cov.flat[:: len(cov) + 1] += self.noise_variance_ # modify diag inplace
components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
exp_var_diff = exp_var - self.noise_variance_
exp_var_diff = xp.where(
exp_var > self.noise_variance_, exp_var_diff, xp.asarray(0.0)
)
cov = (components_.T * exp_var_diff) @ components_
_add_to_diagonal(cov, self.noise_variance_, xp)
return cov
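
The Array API spec has no `ndarray.flat`, so the `cov.flat[:: len(cov) + 1] += ...` idiom above is replaced by the `_add_to_diagonal` helper imported at the top of the file. A rough sketch of the semantics such a helper has to provide (an illustration only, not the actual implementation in `sklearn/utils/_array_api.py`):

    def _add_to_diagonal(array, value, xp):
        """Add `value` (scalar or 1d of length n) to the diagonal of a
        square 2d `array`, modifying it in place."""
        value = xp.asarray(value, dtype=array.dtype)
        for i in range(array.shape[0]):
            array[i, i] += value if value.ndim == 0 else value[i]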

def get_precision(self):
@@ -58,26 +64,36 @@ def get_precision(self):
precision : array, shape=(n_features, n_features)
Estimated precision of data.
"""
xp, is_array_api_compliant = get_namespace(self.components_)

n_features = self.components_.shape[1]

# handle corner cases first
if self.n_components_ == 0:
return np.eye(n_features) / self.noise_variance_
return xp.eye(n_features) / self.noise_variance_

if is_array_api_compliant:
linalg_inv = xp.linalg.inv
else:
linalg_inv = linalg.inv

if np.isclose(self.noise_variance_, 0.0, atol=0.0):
return linalg.inv(self.get_covariance())
if self.noise_variance_ == 0.0:
return linalg_inv(self.get_covariance())

# Get precision using matrix inversion lemma
components_ = self.components_
exp_var = self.explained_variance_
if self.whiten:
components_ = components_ * np.sqrt(exp_var[:, np.newaxis])
exp_var_diff = np.maximum(exp_var - self.noise_variance_, 0.0)
precision = np.dot(components_, components_.T) / self.noise_variance_
precision.flat[:: len(precision) + 1] += 1.0 / exp_var_diff
precision = np.dot(components_.T, np.dot(linalg.inv(precision), components_))
components_ = components_ * xp.sqrt(exp_var[:, np.newaxis])
exp_var_diff = exp_var - self.noise_variance_
exp_var_diff = xp.where(
exp_var > self.noise_variance_, exp_var_diff, xp.asarray(0.0)
)
precision = components_ @ components_.T / self.noise_variance_
_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
precision = components_.T @ linalg_inv(precision) @ components_
precision /= -(self.noise_variance_**2)
precision.flat[:: len(precision) + 1] += 1.0 / self.noise_variance_
_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
return precision
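
For reference, the "matrix inversion lemma" mentioned in the comment above is the Woodbury identity. With `W = components_` (shape `(k, d)`), `v = exp_var_diff` and `s2 = noise_variance_`, the covariance is `C = W.T @ (v[:, None] * W) + s2 * I`, and the identity gives

    inv(C) = I / s2 - (W.T @ inv(diag(1 / v) + W @ W.T / s2) @ W) / s2**2

so only a `(k, k)` matrix is inverted rather than the `(d, d)` covariance; the statements above compute exactly these terms.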

@abstractmethod
@@ -116,14 +132,16 @@ def transform(self, X):
Projection of X in the first principal components, where `n_samples`
is the number of samples and `n_components` is the number of the components.
"""
xp, _ = get_namespace(X)

check_is_fitted(self)

X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
if self.mean_ is not None:
X = X - self.mean_
X_transformed = np.dot(X, self.components_.T)
X_transformed = X @ self.components_.T
if self.whiten:
X_transformed /= np.sqrt(self.explained_variance_)
X_transformed /= xp.sqrt(self.explained_variance_)
return X_transformed

def inverse_transform(self, X):
@@ -148,16 +166,15 @@ def inverse_transform(self, X):
If whitening is enabled, inverse_transform will compute the
exact inverse operation, which includes reversing whitening.
"""
xp, _ = get_namespace(X)

if self.whiten:
return (
np.dot(
X,
np.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_,
)
+ self.mean_
scaled_components = (
xp.sqrt(self.explained_variance_[:, np.newaxis]) * self.components_
)
return X @ scaled_components + self.mean_
else:
return np.dot(X, self.components_) + self.mean_
return X @ self.components_ + self.mean_

@property
def _n_features_out(self):
65 changes: 44 additions & 21 deletions sklearn/decomposition/_pca.py
@@ -22,6 +22,7 @@
from ..base import _fit_context
from ..utils import check_random_state
from ..utils._arpack import _init_arpack_v0
from ..utils._array_api import get_namespace
from ..utils._param_validation import Interval, RealNotInt, StrOptions
from ..utils.deprecation import deprecated
from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip
@@ -108,8 +109,10 @@ def _infer_dimension(spectrum, n_samples):

The returned value will be in [1, n_features - 1].
"""
ll = np.empty_like(spectrum)
ll[0] = -np.inf # we don't want to return n_components = 0
xp, _ = get_namespace(spectrum)

ll = xp.empty_like(spectrum)
ll[0] = -xp.inf # we don't want to return n_components = 0
for rank in range(1, spectrum.shape[0]):
ll[rank] = _assess_dimension(spectrum, rank, n_samples)
return ll.argmax()
@@ -471,6 +474,7 @@ def fit_transform(self, X, y=None):

def _fit(self, X):
"""Dispatch to the right submethod depending on the chosen solver."""
xp, is_array_api_compliant = get_namespace(X)

# Raise an error for sparse input.
# This is more informative than the generic one raised by check_array.
@@ -479,9 +483,14 @@
"PCA does not support sparse input. See "
"TruncatedSVD for a possible alternative."
)
# Raise an error for non-Numpy input and arpack solver.
if self.svd_solver == "arpack" and is_array_api_compliant:
raise ValueError(
"PCA with svd_solver='arpack' is not supported for Array API inputs."
)

X = self._validate_data(
X, dtype=[np.float64, np.float32], ensure_2d=True, copy=self.copy
X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=self.copy
)

# Handle n_components==None
@@ -513,6 +522,8 @@

def _fit_full(self, X, n_components):
"""Fit the model by computing full SVD on X."""
xp, is_array_api_compliant = get_namespace(X)

n_samples, n_features = X.shape

if n_components == "mle":
@@ -528,20 +539,30 @@
)

# Center data
self.mean_ = np.mean(X, axis=0)
self.mean_ = xp.mean(X, axis=0)
X -= self.mean_

U, S, Vt = linalg.svd(X, full_matrices=False)
if not is_array_api_compliant:
# Use scipy.linalg with NumPy/SciPy inputs for the sake of not
# introducing unanticipated behavior changes. In the long run we
# could instead decide to always use xp.linalg.svd for all inputs,
# but that would make this code rely on numpy's SVD instead of
# scipy's. It's not 100% clear whether they use the same LAPACK
# solver by default though (assuming both are built against the
# same BLAS).
U, S, Vt = linalg.svd(X, full_matrices=False)
else:
U, S, Vt = xp.linalg.svd(X, full_matrices=False)
# flip eigenvectors' sign to enforce deterministic output
U, Vt = svd_flip(U, Vt)
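# An SVD is only unique up to the signs of each pair of singular vectors,
# and different LAPACK / Array API backends may return either sign, hence
# the explicit flip above to keep results reproducible.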

components_ = Vt

# Get variance explained by singular values
explained_variance_ = (S**2) / (n_samples - 1)
total_var = explained_variance_.sum()
total_var = xp.sum(explained_variance_)
explained_variance_ratio_ = explained_variance_ / total_var
singular_values_ = S.copy() # Store the singular values.
singular_values_ = xp.asarray(S, copy=True) # Store the singular values.

# Postprocess the number of components required
if n_components == "mle":
@@ -553,16 +574,16 @@
# their variance is always greater than n_components float
# passed. More discussion in issue: #15669
ratio_cumsum = stable_cumsum(explained_variance_ratio_)
n_components = np.searchsorted(ratio_cumsum, n_components, side="right") + 1
n_components = xp.searchsorted(ratio_cumsum, n_components, side="right") + 1
# Compute noise covariance using Probabilistic PCA model
# The sigma2 maximum likelihood (cf. eq. 12.46)
if n_components < min(n_features, n_samples):
self.noise_variance_ = explained_variance_[n_components:].mean()
self.noise_variance_ = xp.mean(explained_variance_[n_components:])
else:
self.noise_variance_ = 0.0

self.n_samples_ = n_samples
self.components_ = components_[:n_components]
self.components_ = components_[:n_components, :]
self.n_components_ = n_components
self.explained_variance_ = explained_variance_[:n_components]
self.explained_variance_ratio_ = explained_variance_ratio_[:n_components]
@@ -574,6 +595,8 @@ def _fit_truncated(self, X, n_components, svd_solver):
"""Fit the model by computing truncated SVD (by ARPACK or randomized)
on X.
"""
xp, _ = get_namespace(X)

n_samples, n_features = X.shape

if isinstance(n_components, str):
@@ -599,7 +622,7 @@
random_state = check_random_state(self.random_state)

# Center data
self.mean_ = np.mean(X, axis=0)
self.mean_ = xp.mean(X, axis=0)
X -= self.mean_

if svd_solver == "arpack":
@@ -633,15 +656,14 @@
# Workaround in-place variance calculation since at the time numpy
# did not have a way to calculate variance in-place.
N = X.shape[0] - 1
np.square(X, out=X)
np.sum(X, axis=0, out=X[0])
total_var = (X[0] / N).sum()
X **= 2
total_var = xp.sum(xp.sum(X, axis=0) / N)
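# With X centered, this equals X.var(axis=0, ddof=1).sum(); squaring X
# in place spares allocating a second array of X's size.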

self.explained_variance_ratio_ = self.explained_variance_ / total_var
self.singular_values_ = S.copy() # Store the singular values.
self.singular_values_ = xp.asarray(S, copy=True) # Store the singular values.

if self.n_components_ < min(n_features, n_samples):
self.noise_variance_ = total_var - self.explained_variance_.sum()
self.noise_variance_ = total_var - xp.sum(self.explained_variance_)
self.noise_variance_ /= min(n_features, n_samples) - n_components
else:
self.noise_variance_ = 0.0
@@ -666,12 +688,12 @@ def score_samples(self, X):
Log-likelihood of each sample under the current model.
"""
check_is_fitted(self)

X = self._validate_data(X, dtype=[np.float64, np.float32], reset=False)
xp, _ = get_namespace(X)
X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
Xr = X - self.mean_
n_features = X.shape[1]
precision = self.get_precision()
log_like = -0.5 * (Xr * (np.dot(Xr, precision))).sum(axis=1)
log_like = -0.5 * xp.sum(Xr * (Xr @ precision), axis=1)
log_like -= 0.5 * (n_features * log(2.0 * np.pi) - fast_logdet(precision))
return log_like

@@ -695,7 +717,8 @@ def score(self, X, y=None):
ll : float
Average log-likelihood of the samples under the current model.
"""
return np.mean(self.score_samples(X))
xp, _ = get_namespace(X)
return float(xp.mean(self.score_samples(X)))

def _more_tags(self):
return {"preserves_dtype": [np.float64, np.float32]}
return {"preserves_dtype": [np.float64, np.float32], "array_api_support": True}