diff --git a/doc/modules/decomposition.rst b/doc/modules/decomposition.rst
index b852a6133d542..c1e317d2ff7d3 100644
--- a/doc/modules/decomposition.rst
+++ b/doc/modules/decomposition.rst
@@ -74,7 +74,7 @@ out-of-core Principal Component Analysis either by:
  * Using its ``partial_fit`` method on chunks of data fetched sequentially
    from the local hard drive or a network database.
 
- * Calling its fit method on a sparse matrix or a memory mapped file using
+ * Calling its fit method on a memory mapped file using
    ``numpy.memmap``.
 
 :class:`IncrementalPCA` only stores estimates of component and noise variances,
@@ -420,10 +420,6 @@ in that the matrix :math:`X` does not need to be centered.
 When the columnwise (per-feature) means of :math:`X`
 are subtracted from the feature values,
 truncated SVD on the resulting matrix is equivalent to PCA.
-In practical terms, this means
-that the :class:`TruncatedSVD` transformer accepts ``scipy.sparse``
-matrices without the need to densify them,
-as densifying may fill up memory even for medium-sized document collections.
 
 While the :class:`TruncatedSVD` transformer
 works with any feature matrix,
diff --git a/doc/modules/manifold.rst b/doc/modules/manifold.rst
index 1656c09f1371d..7cc6776e37daa 100644
--- a/doc/modules/manifold.rst
+++ b/doc/modules/manifold.rst
@@ -644,7 +644,7 @@ Barnes-Hut method improves on the exact method where t-SNE complexity is
   or less. The 2D case is typical when building visualizations.
 * Barnes-Hut only works with dense input data. Sparse data matrices can only be
   embedded with the exact method or can be approximated by a dense low rank
-  projection for instance using :class:`~sklearn.decomposition.TruncatedSVD`
+  projection for instance using :class:`~sklearn.decomposition.PCA`
 * Barnes-Hut is an approximation of the exact method. The approximation is
   parameterized with the angle parameter, therefore the angle parameter is
   unused when method="exact"
diff --git a/doc/whats_new/v1.4.rst b/doc/whats_new/v1.4.rst
index 00539260ffed6..613e791b7bb78 100644
--- a/doc/whats_new/v1.4.rst
+++ b/doc/whats_new/v1.4.rst
@@ -271,6 +271,13 @@ Changelog
   :pr:`26315` and :pr:`27098` by :user:`Mateusz Sokół `,
   :user:`Olivier Grisel ` and :user:`Edoardo Abati `.
 
+- |Feature| :class:`decomposition.PCA` now supports :class:`scipy.sparse.sparray`
+  and :class:`scipy.sparse.spmatrix` inputs when using the `arpack` solver.
+  When used on sparse data like :func:`datasets.fetch_20newsgroups_vectorized` this
+  can lead to speed-ups of 100x (single threaded) and 70x lower memory usage.
+  Based on :user:`Alexander Tarashansky `'s implementation in `scanpy `.
+  :pr:`18689` by :user:`Isaac Virshup ` and :user:`Andrey Portnoy `.
+
 :mod:`sklearn.ensemble`
 .......................
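Note on the changelog entry above: a minimal sketch of the new user-facing behavior it describes, assuming scikit-learn 1.4 or later. The shape, density, and `n_components` values are illustrative and not taken from the diff.

```python
# Fit PCA directly on a scipy.sparse matrix with the "arpack" solver.
# Before this change, PCA raised a TypeError for any sparse input.
import numpy as np
import scipy.sparse as sp
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = sp.random(1000, 300, density=0.01, format="csr", random_state=rng)

pca = PCA(n_components=10, svd_solver="arpack")
X_reduced = pca.fit_transform(X)  # the sparse matrix is never densified
print(X_reduced.shape)  # (1000, 10)
```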
diff --git a/examples/compose/plot_column_transformer.py b/examples/compose/plot_column_transformer.py
index 669e817cbf81d..207f7450a2705 100644
--- a/examples/compose/plot_column_transformer.py
+++ b/examples/compose/plot_column_transformer.py
@@ -26,7 +26,7 @@
 from sklearn.compose import ColumnTransformer
 from sklearn.datasets import fetch_20newsgroups
-from sklearn.decomposition import TruncatedSVD
+from sklearn.decomposition import PCA
 from sklearn.feature_extraction import DictVectorizer
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics import classification_report
@@ -141,7 +141,7 @@ def text_stats(posts):
                         Pipeline(
                             [
                                 ("tfidf", TfidfVectorizer()),
-                                ("best", TruncatedSVD(n_components=50)),
+                                ("best", PCA(n_components=50, svd_solver="arpack")),
                             ]
                         ),
                         1,
diff --git a/sklearn/decomposition/_base.py b/sklearn/decomposition/_base.py
index c4ccf92212fe9..9fa720751774f 100644
--- a/sklearn/decomposition/_base.py
+++ b/sklearn/decomposition/_base.py
@@ -12,9 +12,11 @@
 import numpy as np
 from scipy import linalg
+from scipy.sparse import issparse
 
 from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
 from ..utils._array_api import _add_to_diagonal, device, get_namespace
+from ..utils.sparsefuncs import _implicit_column_offset
 from ..utils.validation import check_is_fitted
@@ -126,7 +128,7 @@ def transform(self, X):
 
         Parameters
         ----------
-        X : array-like of shape (n_samples, n_features)
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
             New data, where `n_samples` is the number of samples
             and `n_features` is the number of features.
 
@@ -140,9 +142,14 @@ def transform(self, X):
 
         check_is_fitted(self)
 
-        X = self._validate_data(X, dtype=[xp.float64, xp.float32], reset=False)
+        X = self._validate_data(
+            X, accept_sparse=("csr", "csc"), dtype=[xp.float64, xp.float32], reset=False
+        )
         if self.mean_ is not None:
-            X = X - self.mean_
+            if issparse(X):
+                X = _implicit_column_offset(X, self.mean_)
+            else:
+                X = X - self.mean_
         X_transformed = X @ self.components_.T
         if self.whiten:
             X_transformed /= xp.sqrt(self.explained_variance_)
diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py
index 877baf4d4e81c..046d121ac1934 100644
--- a/sklearn/decomposition/_pca.py
+++ b/sklearn/decomposition/_pca.py
@@ -26,6 +26,7 @@
 from ..utils._param_validation import Interval, RealNotInt, StrOptions
 from ..utils.deprecation import deprecated
 from ..utils.extmath import fast_logdet, randomized_svd, stable_cumsum, svd_flip
+from ..utils.sparsefuncs import _implicit_column_offset, mean_variance_axis
 from ..utils.validation import check_is_fitted
 from ._base import _BasePCA
@@ -422,7 +423,7 @@ def fit(self, X, y=None):
 
         Parameters
         ----------
-        X : array-like of shape (n_samples, n_features)
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
             Training data, where `n_samples` is the number of samples
             and `n_features` is the number of features.
 
@@ -443,7 +444,7 @@ def fit_transform(self, X, y=None):
 
         Parameters
         ----------
-        X : array-like of shape (n_samples, n_features)
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
             Training data, where `n_samples` is the number of samples
             and `n_features` is the number of features.
 
@@ -476,12 +477,12 @@ def _fit(self, X):
         """Dispatch to the right submethod depending on the chosen solver."""
         xp, is_array_api_compliant = get_namespace(X)
 
-        # Raise an error for sparse input.
-        # This is more informative than the generic one raised by check_array.
-        if issparse(X):
+        # Raise an error for sparse input and unsupported svd_solver
+        if issparse(X) and self.svd_solver != "arpack":
             raise TypeError(
-                "PCA does not support sparse input. See "
-                "TruncatedSVD for a possible alternative."
+                'PCA only supports sparse inputs with the "arpack" solver, while '
+                f'"{self.svd_solver}" was passed. See TruncatedSVD for a possible'
+                " alternative."
             )
         # Raise an error for non-Numpy input and arpack solver.
         if self.svd_solver == "arpack" and is_array_api_compliant:
             raise ValueError(
             )
 
         X = self._validate_data(
-            X, dtype=[xp.float64, xp.float32], ensure_2d=True, copy=self.copy
+            X,
+            dtype=[xp.float64, xp.float32],
+            accept_sparse=("csr", "csc"),
+            ensure_2d=True,
+            copy=self.copy,
         )
 
         # Handle n_components==None
@@ -622,8 +627,14 @@ def _fit_truncated(self, X, n_components, svd_solver):
         random_state = check_random_state(self.random_state)
 
         # Center data
-        self.mean_ = xp.mean(X, axis=0)
-        X -= self.mean_
+        total_var = None
+        if issparse(X):
+            self.mean_, var = mean_variance_axis(X, axis=0)
+            total_var = var.sum() * n_samples / (n_samples - 1)  # ddof=1
+            X = _implicit_column_offset(X, self.mean_)
+        else:
+            self.mean_ = xp.mean(X, axis=0)
+            X -= self.mean_
 
         if svd_solver == "arpack":
             v0 = _init_arpack_v0(min(X.shape), random_state)
@@ -655,9 +666,15 @@ def _fit_truncated(self, X, n_components, svd_solver):
 
         # Workaround in-place variance calculation since at the time numpy
         # did not have a way to calculate variance in-place.
-        N = X.shape[0] - 1
-        X **= 2
-        total_var = xp.sum(xp.sum(X, axis=0) / N)
+        #
+        # TODO: update this code to either:
+        # * Use the array-api variance calculation, unless memory usage suffers
+        # * Update sklearn.utils.extmath._incremental_mean_and_var to support array-api
+        # See: https://github.com/scikit-learn/scikit-learn/pull/18689#discussion_r1335540991
+        if total_var is None:
+            N = X.shape[0] - 1
+            X **= 2
+            total_var = xp.sum(X) / N
         self.explained_variance_ratio_ = self.explained_variance_ / total_var
         self.singular_values_ = xp.asarray(S, copy=True)  # Store the singular values.
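Aside on the `_fit_truncated` change above: `mean_variance_axis` returns biased (ddof=0) per-feature variances, so the `var.sum() * n_samples / (n_samples - 1)` line rescales their sum to the unbiased total variance that the dense branch otherwise obtains through the in-place `X **= 2` workaround. A small illustrative check of that correction follows; the array sizes are arbitrary and this snippet is not part of the diff.

```python
import numpy as np
import scipy.sparse as sp
from sklearn.utils.sparsefuncs import mean_variance_axis

rng = np.random.default_rng(0)
X = sp.random(50, 8, density=0.3, format="csr", random_state=rng)
n_samples = X.shape[0]

_, var = mean_variance_axis(X, axis=0)  # per-feature variances with ddof=0
total_var_sparse = var.sum() * n_samples / (n_samples - 1)  # ddof=1 correction
total_var_dense = X.toarray().var(axis=0, ddof=1).sum()

assert np.isclose(total_var_sparse, total_var_dense)
```

This total is the denominator used for `explained_variance_ratio_`.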
diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py
index 64b07c653b1a2..9cbd8936dc1dd 100644
--- a/sklearn/decomposition/tests/test_pca.py
+++ b/sklearn/decomposition/tests/test_pca.py
@@ -21,11 +21,28 @@
     _get_check_estimator_ids,
     check_array_api_input_and_values,
 )
-from sklearn.utils.fixes import CSR_CONTAINERS
+from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
 
 iris = datasets.load_iris()
 PCA_SOLVERS = ["full", "arpack", "randomized", "auto"]
 
+# `SPARSE_M` and `SPARSE_N` could be larger, but be aware:
+# * SciPy's generation of random sparse matrix can be costly
+# * A (SPARSE_M, SPARSE_N) dense array is allocated to compare against
+SPARSE_M, SPARSE_N = 1000, 300  # arbitrary
+SPARSE_MAX_COMPONENTS = min(SPARSE_M, SPARSE_N)
+
+
+def _check_fitted_pca_close(pca1, pca2, rtol):
+    assert_allclose(pca1.components_, pca2.components_, rtol=rtol)
+    assert_allclose(pca1.explained_variance_, pca2.explained_variance_, rtol=rtol)
+    assert_allclose(pca1.singular_values_, pca2.singular_values_, rtol=rtol)
+    assert_allclose(pca1.mean_, pca2.mean_, rtol=rtol)
+    assert_allclose(pca1.n_components_, pca2.n_components_, rtol=rtol)
+    assert_allclose(pca1.n_samples_, pca2.n_samples_, rtol=rtol)
+    assert_allclose(pca1.noise_variance_, pca2.noise_variance_, rtol=rtol)
+    assert_allclose(pca1.n_features_in_, pca2.n_features_in_, rtol=rtol)
+
 
 @pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
 @pytest.mark.parametrize("n_components", range(1, iris.data.shape[1]))
@@ -49,6 +66,118 @@ def test_pca(svd_solver, n_components):
     assert_allclose(np.dot(cov, precision), np.eye(X.shape[1]), atol=1e-12)
 
 
+@pytest.mark.parametrize("density", [0.01, 0.1, 0.30])
+@pytest.mark.parametrize("n_components", [1, 2, 10])
+@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
+@pytest.mark.parametrize("svd_solver", ["arpack"])
+@pytest.mark.parametrize("scale", [1, 10, 100])
+def test_pca_sparse(
+    global_random_seed, svd_solver, sparse_container, n_components, density, scale
+):
+    # Make sure any tolerance changes pass with SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"
+    rtol = 5e-07
+    transform_rtol = 3e-05
+
+    random_state = np.random.default_rng(global_random_seed)
+    X = sparse_container(
+        sp.sparse.random(
+            SPARSE_M,
+            SPARSE_N,
+            random_state=random_state,
+            density=density,
+        )
+    )
+    # Scale the data + vary the column means
+    scale_vector = random_state.random(X.shape[1]) * scale
+    X = X.multiply(scale_vector)
+
+    pca = PCA(
+        n_components=n_components,
+        svd_solver=svd_solver,
+        random_state=global_random_seed,
+    )
+    pca.fit(X)
+
+    Xd = X.toarray()
+    pcad = PCA(
+        n_components=n_components,
+        svd_solver=svd_solver,
+        random_state=global_random_seed,
+    )
+    pcad.fit(Xd)
+
+    # Fitted attributes equality
+    _check_fitted_pca_close(pca, pcad, rtol=rtol)
+
+    # Test transform
+    X2 = sparse_container(
+        sp.sparse.random(
+            SPARSE_M,
+            SPARSE_N,
+            random_state=random_state,
+            density=density,
+        )
+    )
+    X2d = X2.toarray()
+
+    assert_allclose(pca.transform(X2), pca.transform(X2d), rtol=transform_rtol)
+    assert_allclose(pca.transform(X2), pcad.transform(X2d), rtol=transform_rtol)
+
+
+@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
+def test_pca_sparse_fit_transform(global_random_seed, sparse_container):
+    random_state = np.random.default_rng(global_random_seed)
+    X = sparse_container(
+        sp.sparse.random(
+            SPARSE_M,
+            SPARSE_N,
+            random_state=random_state,
+            density=0.01,
+        )
+    )
+    X2 = sparse_container(
+        sp.sparse.random(
+            SPARSE_M,
+            SPARSE_N,
+            random_state=random_state,
+            density=0.01,
+        )
+    )
+
+    pca_fit = PCA(n_components=10, svd_solver="arpack", random_state=global_random_seed)
+    pca_fit_transform = PCA(
+        n_components=10, svd_solver="arpack", random_state=global_random_seed
+    )
+
+    pca_fit.fit(X)
+    transformed_X = pca_fit_transform.fit_transform(X)
+
+    _check_fitted_pca_close(pca_fit, pca_fit_transform, rtol=1e-10)
+    assert_allclose(transformed_X, pca_fit_transform.transform(X), rtol=2e-9)
+    assert_allclose(transformed_X, pca_fit.transform(X), rtol=2e-9)
+    assert_allclose(pca_fit.transform(X2), pca_fit_transform.transform(X2), rtol=2e-9)
+
+
+@pytest.mark.parametrize("svd_solver", ["randomized", "full", "auto"])
+@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
+def test_sparse_pca_solver_error(global_random_seed, svd_solver, sparse_container):
+    random_state = np.random.RandomState(global_random_seed)
+    X = sparse_container(
+        sp.sparse.random(
+            SPARSE_M,
+            SPARSE_N,
+            random_state=random_state,
+        )
+    )
+    pca = PCA(n_components=30, svd_solver=svd_solver)
+    error_msg_pattern = (
+        f'PCA only supports sparse inputs with the "arpack" solver, while "{svd_solver}"'
+        " was passed"
+    )
+    with pytest.raises(TypeError, match=error_msg_pattern):
+        pca.fit(X)
+
+
 def test_no_empty_slice_warning():
     # test if we avoid numpy warnings for computing over empty arrays
     n_components = 10
@@ -502,18 +631,6 @@ def test_pca_svd_solver_auto(data, n_components, expected_solver):
     assert_allclose(pca_auto.components_, pca_test.components_)
 
 
-@pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
-@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
-def test_pca_sparse_input(svd_solver, csr_container):
-    X = np.random.RandomState(0).rand(5, 4)
-    X = csr_container(X)
-    assert sp.sparse.issparse(X)
-
-    pca = PCA(n_components=3, svd_solver=svd_solver)
-    with pytest.raises(TypeError):
-        pca.fit(X)
-
-
 @pytest.mark.parametrize("svd_solver", PCA_SOLVERS)
 def test_pca_deterministic_output(svd_solver):
     rng = np.random.RandomState(0)
diff --git a/sklearn/utils/sparsefuncs.py b/sklearn/utils/sparsefuncs.py
index ac908a5646fba..9eccb8c07676f 100644
--- a/sklearn/utils/sparsefuncs.py
+++ b/sklearn/utils/sparsefuncs.py
@@ -10,6 +10,7 @@
 # License: BSD 3 clause
 import numpy as np
 import scipy.sparse as sp
+from scipy.sparse.linalg import LinearOperator
 
 from ..utils.fixes import _sparse_min_max, _sparse_nan_min_max
 from ..utils.validation import _check_sample_weight
@@ -568,3 +569,30 @@ def csc_median_axis_0(X):
             median[f_ind] = _get_median(data, nz)
 
     return median
+
+
+def _implicit_column_offset(X, offset):
+    """Create an implicitly offset linear operator.
+
+    This is used by PCA on sparse data to avoid densifying the whole data
+    matrix.
+
+    Parameters
+    ----------
+    X : sparse matrix of shape (n_samples, n_features)
+    offset : ndarray of shape (n_features,)
+
+    Returns
+    -------
+    centered : LinearOperator
+    """
+    offset = offset[None, :]
+    XT = X.T
+    return LinearOperator(
+        matvec=lambda x: X @ x - offset @ x,
+        matmat=lambda x: X @ x - offset @ x,
+        rmatvec=lambda x: XT @ x - (offset * x.sum()),
+        rmatmat=lambda x: XT @ x - offset.T @ x.sum(axis=0)[None, :],
+        dtype=X.dtype,
+        shape=X.shape,
+    )
diff --git a/sklearn/utils/tests/test_sparsefuncs.py b/sklearn/utils/tests/test_sparsefuncs.py
index f311305812733..8e3bda13928e4 100644
--- a/sklearn/utils/tests/test_sparsefuncs.py
+++ b/sklearn/utils/tests/test_sparsefuncs.py
@@ -9,6 +9,7 @@
 from sklearn.utils._testing import assert_allclose
 from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS, LIL_CONTAINERS
 from sklearn.utils.sparsefuncs import (
+    _implicit_column_offset,
     count_nonzero,
     csc_median_axis_0,
     incr_mean_variance_axis,
@@ -944,3 +945,54 @@ def test_csr_row_norms(dtype):
     assert norms.dtype == dtype
     rtol = 1e-6 if dtype == np.float32 else 1e-7
     assert_allclose(norms, scipy_norms, rtol=rtol)
+
+
+@pytest.fixture(scope="module", params=CSR_CONTAINERS + CSC_CONTAINERS)
+def centered_matrices(request):
+    """Returns equivalent tuple[sp.linalg.LinearOperator, np.ndarray]."""
+    sparse_container = request.param
+
+    random_state = np.random.default_rng(42)
+
+    X_sparse = sparse_container(
+        sp.random(500, 100, density=0.1, format="csr", random_state=random_state)
+    )
+    X_dense = X_sparse.toarray()
+    mu = np.asarray(X_sparse.mean(axis=0)).ravel()
+
+    X_sparse_centered = _implicit_column_offset(X_sparse, mu)
+    X_dense_centered = X_dense - mu
+
+    return X_sparse_centered, X_dense_centered
+
+
+def test_implicit_center_matmat(global_random_seed, centered_matrices):
+    X_sparse_centered, X_dense_centered = centered_matrices
+    rng = np.random.default_rng(global_random_seed)
+    Y = rng.standard_normal((X_dense_centered.shape[1], 50))
+    assert_allclose(X_dense_centered @ Y, X_sparse_centered.matmat(Y))
+    assert_allclose(X_dense_centered @ Y, X_sparse_centered @ Y)
+
+
+def test_implicit_center_matvec(global_random_seed, centered_matrices):
+    X_sparse_centered, X_dense_centered = centered_matrices
+    rng = np.random.default_rng(global_random_seed)
+    y = rng.standard_normal(X_dense_centered.shape[1])
+    assert_allclose(X_dense_centered @ y, X_sparse_centered.matvec(y))
+    assert_allclose(X_dense_centered @ y, X_sparse_centered @ y)
+
+
+def test_implicit_center_rmatmat(global_random_seed, centered_matrices):
+    X_sparse_centered, X_dense_centered = centered_matrices
+    rng = np.random.default_rng(global_random_seed)
+    Y = rng.standard_normal((X_dense_centered.shape[0], 50))
+    assert_allclose(X_dense_centered.T @ Y, X_sparse_centered.rmatmat(Y))
+    assert_allclose(X_dense_centered.T @ Y, X_sparse_centered.T @ Y)
+
+
+def test_implicit_center_rmatvec(global_random_seed, centered_matrices):
+    X_sparse_centered, X_dense_centered = centered_matrices
+    rng = np.random.default_rng(global_random_seed)
+    y = rng.standard_normal(X_dense_centered.shape[0])
+    assert_allclose(X_dense_centered.T @ y, X_sparse_centered.rmatvec(y))
+    assert_allclose(X_dense_centered.T @ y, X_sparse_centered.T @ y)
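For context, a sketch of how `_implicit_column_offset` is meant to be driven: the returned `LinearOperator` only exposes matrix products, which is all ARPACK needs, so `scipy.sparse.linalg.svds` can compute a truncated SVD of the centered data without densifying it. This snippet is illustrative only; `_implicit_column_offset` is a private helper, and the arpack path in `PCA._fit_truncated` additionally seeds ARPACK with the `v0` vector visible in the hunk above.

```python
import numpy as np
import scipy.sparse as sp
from scipy.sparse.linalg import svds
from sklearn.utils.sparsefuncs import _implicit_column_offset

rng = np.random.default_rng(0)
X = sp.random(200, 30, density=0.1, format="csr", random_state=rng)
mu = np.asarray(X.mean(axis=0)).ravel()

# Behaves like the dense (X - mu) without materializing it.
X_centered = _implicit_column_offset(X, mu)
U, S, Vt = svds(X_centered, k=5)

# Sanity check against the explicitly centered dense computation.
S_dense = np.linalg.svd(X.toarray() - mu, compute_uv=False)
np.testing.assert_allclose(np.sort(S), np.sort(S_dense)[-5:], rtol=1e-6)
```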