ENH Allow fitting PCA on sparse X with arpack solvers (#18689) · scikit-learn/scikit-learn@2d9fa48 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2d9fa48

Browse files
ivirshup, andportnoy, ogrisel, jjerphan
authored
ENH Allow fitting PCA on sparse X with arpack solvers (#18689)
Co-authored-by: Andrey Portnoy <aportnoy@fastmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 098356d commit 2d9fa48

File tree

9 files changed

+261
-37
lines changed

9 files changed

+261
-37
lines changed

doc/modules/decomposition.rst

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ out-of-core Principal Component Analysis either by:
7474
* Using its ``partial_fit`` method on chunks of data fetched sequentially
7575
from the local hard drive or a network database.
7676

77-
* Calling its fit method on a sparse matrix or a memory mapped file using
77+
* Calling its fit method on a memory mapped file using
7878
``numpy.memmap``.
7979

8080
:class:`IncrementalPCA` only stores estimates of component and noise variances,
@@ -420,10 +420,6 @@ in that the matrix :math:`X` does not need to be centered.
420420
When the columnwise (per-feature) means of :math:`X`
421421
are subtracted from the feature values,
422422
truncated SVD on the resulting matrix is equivalent to PCA.
423-
In practical terms, this means
424-
that the :class:`TruncatedSVD` transformer accepts ``scipy.sparse``
425-
matrices without the need to densify them,
426-
as densifying may fill up memory even for medium-sized document collections.
427423

428424
While the :class:`TruncatedSVD` transformer
429425
works with any feature matrix,

doc/modules/manifold.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,7 +644,7 @@ Barnes-Hut method improves on the exact method where t-SNE complexity is
644644
or less. The 2D case is typical when building visualizations.
645645
* Barnes-Hut only works with dense input data. Sparse data matrices can only be
646646
embedded with the exact method or can be approximated by a dense low rank
647-
projection for instance using :class:`~sklearn.decomposition.TruncatedSVD`
647+
projection for instance using :class:`~sklearn.decomposition.PCA` with ``svd_solver="arpack"``
648648
* Barnes-Hut is an approximation of the exact method. The approximation is
649649
parameterized with the angle parameter, therefore the angle parameter is
650650
unused when method="exact"

doc/whats_new/v1.4.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,13 @@ Changelog
271271
:pr:`26315` and :pr:`27098` by :user:`Mateusz Sokół <mtsokol>`,
272272
:user:`Olivier Grisel <ogrisel>` and :user:`Edoardo Abati <EdAbati>`.
273273

274+
- |Feature| :class:`decomposition.PCA` now supports :class:`scipy.sparse.sparray`
275+
and :class:`scipy.sparse.spmatrix` inputs when using the `arpack` solver.
276+
When used on sparse data like :func:`datasets.fetch_20newsgroups_vectorized` this
277+
can lead to speed-ups of 100x (single threaded) and 70x lower memory usage.
278+
Based on :user:`Alexander Tarashansky <atarashansky>`'s implementation in `scanpy <https://github.com/scverse/scanpy>`_.
279+
:pr:`18689` by :user:`Isaac Virshup <ivirshup>` and :user:`Andrey Portnoy <andportnoy>`.
280+
274281
:mod:`sklearn.ensemble`
275282
.......................
276283

examples/compose/plot_column_transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from sklearn.compose import ColumnTransformer
2828
from sklearn.datasets import fetch_20newsgroups
29-
from sklearn.decomposition import TruncatedSVD
29+
from sklearn.decomposition import PCA
3030
from sklearn.feature_extraction import DictVectorizer
3131
from sklearn.feature_extraction.text import TfidfVectorizer
3232
from sklearn.metrics import classification_report
@@ -141,7 +141,7 @@ def text_stats(posts):
141141
Pipeline(
142142
[
143143
("tfidf", TfidfVectorizer()),
144-
("best", TruncatedSVD(n_components=50)),
144+
("best", PCA(n_components=50, svd_solver="arpack")),
145145
]
146146
),
147147
1,

sklearn/decomposition/_base.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212

1313
import numpy as np
1414
from scipy import linalg
15+
from scipy.sparse import issparse
1516

1617
from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
1718
from ..utils._array_api import _add_to_diagonal, device, get_namespace
19+
from ..utils.sparsefuncs import _implicit_column_offset
1820
from ..utils.validation import check_is_fitted
1921

2022

@@ -126,7 +128,7 @@ def transform(self, X):
126128
127129
Parameters
128130
----------
129-
X : array-like of shape (n_samples, n_features)
131+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
130132
New data, where `n_samples` is the number of samples
131133
and `n_features` is the number of features.
132134
@@ -140,9 +142,14 @@ def transform(self, X):
140142

141143
check_is_fitted(self)
142144

143-
X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
145+
X = self._validate_data(
146+
X, accept_sparse=("csr", "csc"), dtype=[xp.float64, xp.float32], reset=False
147+
)
144148
if self.mean_ is not None:
145-
X = X - self.mean_
149+
if issparse(X):
150+
X = _implicit_column_offset(X, self.mean_)
151+
else:
152+
X = X - self.mean_
146153
X_transformed = X @ self.components_.T
147154
if self.whiten:
148155
X_transformed /= xp.sqrt(self.explained_variance_)

sklearn/decomposition/_pca.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ..utils._param_validation import Interval, RealNotInt, StrOptions
2727
from ..utils.deprecation import deprecated
2828
from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip
29+
from ..utils.sparsefuncs import _implicit_column_offset, mean_variance_axis
2930
from ..utils.validation import check_is_fitted
3031
from ._base import _BasePCA
3132

@@ -422,7 +423,7 @@ def fit(self, X, y=None):
422423
423424
Parameters
424425
----------
425-
X : array-like of shape (n_samples, n_features)
426+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
426427
Training data, where `n_samples` is the number of samples
427428
and `n_features` is the number of features.
428429
@@ -443,7 +444,7 @@ def fit_transform(self, X, y=None):
443444
444445
Parameters
445446
----------
446-
X : array-like of shape (n_samples, n_features)
447+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
447448
Training data, where `n_samples` is the number of samples
448449
and `n_features` is the number of features.
449450
@@ -476,12 +477,12 @@ def _fit(self, X):
476477
"""Dispatch to the right submethod depending on the chosen solver."""
477478
xp, is_array_api_compliant = get_namespace(X)
478479

479-
# Raise an error for sparse input.
480-
# This is more informative than the generic one raised by check_array.
481-
if issparse(X):
480+
# Raise an error for sparse input and unsupported svd_solver
481+
if issparse(X) and self.svd_solver != "arpack":
482482
raise TypeError(
483-
"PCA does not support sparse input. See "
484-
"TruncatedSVD for a possible alternative."
483+
'PCA only support sparse inputs with the "arpack" solver, while '
484+
f'"{self.svd_solver}" was passed. See TruncatedSVD for a possible'
485+
" alternative."
485486
)
486487
# Raise an error for non-Numpy input and arpack solver.
487488
if self.svd_solver == "arpack" and is_array_api_compliant:
@@ -490,7 +491,11 @@ def _fit(self, X):
490491
)
491492

492493
X = self._validate_data(
493-
X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=self.copy
494+
X,
495+
dtype=[xp.float64, xp.float32],
496+
accept_sparse=("csr", "csc"),
497+
ensure_2d=True,
498+
copy=self.copy,
494499
)
495500

496501
# Handle n_components==None
@@ -622,8 +627,14 @@ def _fit_truncated(self, X, n_components, svd_solver):
622627
random_state = check_random_state(self.random_state)
623628

624629
# Center data
625-
self.mean_ = xp.mean(X, axis=0)
626-
X -= self.mean_
630+
total_var = None
631+
if issparse(X):
632+
self.mean_, var = mean_variance_axis(X, axis=0)
633+
total_var = var.sum() * n_samples / (n_samples - 1) # ddof=1
634+
X = _implicit_column_offset(X, self.mean_)
635+
else:
636+
self.mean_ = xp.mean(X, axis=0)
637+
X -= self.mean_
627638

628639
if svd_solver == "arpack":
629640
v0 = _init_arpack_v0(min(X.shape), random_state)
@@ -655,9 +666,15 @@ def _fit_truncated(self, X, n_components, svd_solver):
655666

656667
# Workaround in-place variance calculation since at the time numpy
657668
# did not have a way to calculate variance in-place.
658-
N = X.shape[0] - 1
659-
X **= 2
660-
total_var = xp.sum(xp.sum(X, axis=0) / N)
669+
#
670+
# TODO: update this code to either:
671+
# * Use the array-api variance calculation, unless memory usage suffers
672+
# * Update sklearn.utils.extmath._incremental_mean_and_var to support array-api
673+
# See: https://github.com/scikit-learn/scikit-learn/pull/18689#discussion_r1335540991
674+
if total_var is None:
675+
N = X.shape[0] - 1
676+
X **= 2
677+
total_var = xp.sum(X) / N
661678

662679
self.explained_variance_ratio_ = self.explained_variance_ / total_var
663680
self.singular_values_ = xp.asarray(S, copy=True) # Store the singular values.

sklearn/decomposition/tests/test_pca.py

Lines changed: 130 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,28 @@
2121
_get_check_estimator_ids,
2222
check_array_api_input_and_values,
2323
)
24-
from sklearn.utils.fixes import CSR_CONTAINERS
24+
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
2525

2626
iris = datasets.load_iris()
2727
PCA_SOLVERS = ["full", "arpack", "randomized", "auto"]
2828

29+
# `SPARSE_M` and `SPARSE_N` could be larger, but be aware:
30+
# * SciPy's generation of random sparse matrix can be costly
31+
# * A (SPARSE_M, SPARSE_N) dense array is allocated to compare against
32+
SPARSE_M, SPARSE_N = 1000, 300 # arbitrary
33+
SPARSE_MAX_COMPONENTS = min(SPARSE_M, SPARSE_N)
34+
35+
36+
def _check_fitted_pca_close(pca1, pca2, rtol):
37+
assert_allclose(pca1.components_, pca2.components_, rtol=rtol)
38+
assert_allclose(pca1.explained_variance_, pca2.explained_variance_, rtol=rtol)
39+
assert_allclose(pca1.singular_values_, pca2.singular_values_, rtol=rtol)
40+
assert_allclose(pca1.mean_, pca2.mean_, rtol=rtol)
41+
assert_allclose(pca1.n_components_, pca2.n_components_, rtol=rtol)
42+
assert_allclose(pca1.n_samples_, pca2.n_samples_, rtol=rtol)
43+
assert_allclose(pca1.noise_variance_, pca2.noise_variance_, rtol=rtol)
44+
assert_allclose(pca1.n_features_in_, pca2.n_features_in_, rtol=rtol)
45+
2946

3047
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
3148
@pytest.mark.parametrize("n_components", range(1, iris.data.shape[1]))
@@ -49,6 +66,118 @@ def test_pca(svd_solver, n_components):
4966
assert_allclose(np.dot(cov, precision), np.eye(X.shape[1]), atol=1e-12)
5067

5168

69+
@pytest.mark.parametrize("density", [0.01, 0.1, 0.30])
70+
@pytest.mark.parametrize("n_components", [1, 2, 10])
71+
@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
72+
@pytest.mark.parametrize("svd_solver", ["arpack"])
73+
@pytest.mark.parametrize("scale", [1, 10, 100])
74+
def test_pca_sparse(
75+
global_random_seed, svd_solver, sparse_container, n_components, density, scale
76+
):
77+
# Make sure any tolerance changes pass with SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"
78+
rtol = 5e-07
79+
transform_rtol = 3e-05
80+
81+
random_state = np.random.default_rng(global_random_seed)
82+
X = sparse_container(
83+
sp.sparse.random(
84+
SPARSE_M,
85+
SPARSE_N,
86+
random_state=random_state,
87+
density=density,
88+
)
89+
)
90+
# Scale the data + vary the column means
91+
scale_vector = random_state.random(X.shape[1]) * scale
92+
X = X.multiply(scale_vector)
93+
94+
pca = PCA(
95+
n_components=n_components,
96+
svd_solver=svd_solver,
97+
random_state=global_random_seed,
98+
)
99+
pca.fit(X)
100+
101+
Xd = X.toarray()
102+
pcad = PCA(
103+
n_components=n_components,
104+
svd_solver=svd_solver,
105+
random_state=global_random_seed,
106+
)
107+
pcad.fit(Xd)
108+
109+
# Fitted attributes equality
110+
_check_fitted_pca_close(pca, pcad, rtol=rtol)
111+
112+
# Test transform
113+
X2 = sparse_container(
114+
sp.sparse.random(
115+
SPARSE_M,
116+
SPARSE_N,
117+
random_state=random_state,
118+
density=density,
119+
)
120+
)
121+
X2d = X2.toarray()
122+
123+
assert_allclose(pca.transform(X2), pca.transform(X2d), rtol=transform_rtol)
124+
assert_allclose(pca.transform(X2), pcad.transform(X2d), rtol=transform_rtol)
125+
126+
127+
@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
128+
def test_pca_sparse_fit_transform(global_random_seed, sparse_container):
129+
random_state = np.random.default_rng(global_random_seed)
130+
X = sparse_container(
131+
sp.sparse.random(
132+
SPARSE_M,
133+
SPARSE_N,
134+
random_state=random_state,
135+
density=0.01,
136+
)
137+
)
138+
X2 = sparse_container(
139+
sp.sparse.random(
140+
SPARSE_M,
141+
SPARSE_N,
142+
random_state=random_state,
143+
density=0.01,
144+
)
145+
)
146+
147+
pca_fit = PCA(n_components=10, svd_solver="arpack", random_state=global_random_seed)
148+
pca_fit_transform = PCA(
149+
n_components=10, svd_solver="arpack", random_state=global_random_seed
150+
)
151+
152+
pca_fit.fit(X)
153+
transformed_X = pca_fit_transform.fit_transform(X)
154+
155+
_check_fitted_pca_close(pca_fit, pca_fit_transform, rtol=1e-10)
156+
assert_allclose(transformed_X, pca_fit_transform.transform(X), rtol=2e-9)
157+
assert_allclose(transformed_X, pca_fit.transform(X), rtol=2e-9)
158+
assert_allclose(pca_fit.transform(X2), pca_fit_transform.transform(X2), rtol=2e-9)
159+
160+
161+
@pytest.mark.parametrize("svd_solver", ["randomized", "full", "auto"])
162+
@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
163+
def test_sparse_pca_solver_error(global_random_seed, svd_solver, sparse_container):
164+
random_state = np.random.RandomState(global_random_seed)
165+
X = sparse_container(
166+
sp.sparse.random(
167+
SPARSE_M,
168+
SPARSE_N,
169+
random_state=random_state,
170+
)
171+
)
172+
pca = PCA(n_components=30, svd_solver=svd_solver)
173+
error_msg_pattern = (
174+
f'PCA only support sparse inputs with the "arpack" solver, while "{svd_solver}"'
175+
" was passed"
176+
)
177+
with pytest.raises(TypeError, match=error_msg_pattern):
178+
pca.fit(X)
179+
180+
52181
def test_no_empty_slice_warning():
53182
# test if we avoid numpy warnings for computing over empty arrays
54183
n_components = 10
@@ -502,18 +631,6 @@ def test_pca_svd_solver_auto(data, n_components, expected_solver):
502631
assert_allclose(pca_auto.components_, pca_test.components_)
503632

504633

505-
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
506-
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
507-
def test_pca_sparse_input(svd_solver, csr_container):
508-
X = np.random.RandomState(0).rand(5, 4)
509-
X = csr_container(X)
510-
assert sp.sparse.issparse(X)
511-
512-
pca = PCA(n_components=3, svd_solver=svd_solver)
513-
with pytest.raises(TypeError):
514-
pca.fit(X)
515-
516-
517634
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
518635
def test_pca_deterministic_output(svd_solver):
519636
rng = np.random.RandomState(0)

sklearn/utils/sparsefuncs.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# License: BSD 3 clause
1111
import numpy as np
1212
import scipy.sparse as sp
13+
from scipy.sparse.linalg import LinearOperator
1314

1415
from ..utils.fixes import _sparse_min_max, _sparse_nan_min_max
1516
from ..utils.validation import _check_sample_weight
@@ -568,3 +569,30 @@ def csc_median_axis_0(X):
568569
median[f_ind] = _get_median(data, nz)
569570

570571
return median
572+
573+
574+
def _implicit_column_offset(X, offset):
575+
"""Create an implicitly offset linear operator.
576+
577+
This is used by PCA on sparse data to avoid densifying the whole data
578+
matrix.
579+
580+
Params
581+
------
582+
X : sparse matrix of shape (n_samples, n_features)
583+
offset : ndarray of shape (n_features,)
584+
585+
Returns
586+
-------
587+
centered : LinearOperator
588+
"""
589+
offset = offset[None, :]
590+
XT = X.T
591+
return LinearOperator(
592+
matvec=lambda x: X @ x - offset @ x,
593+
matmat=lambda x: X @ x - offset @ x,
594+
rmatvec=lambda x: XT @ x - (offset * x.sum()),
595+
rmatmat=lambda x: XT @ x - offset.T @ x.sum(axis=0)[None, :],
596+
dtype=X.dtype,
597+
shape=X.shape,
598+
)

0 commit comments

Comments
 (0)
0