diff --git a/pca-sparse-debug/scripts/mismatch-csv.sh b/pca-sparse-debug/scripts/mismatch-csv.sh new file mode 100644 index 0000000000000..2a7ee54a3a8cd --- /dev/null +++ b/pca-sparse-debug/scripts/mismatch-csv.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +LOGFILE=$1 +paste -d ',' \ + <(echo 'seed,rtol,solver,layout,k,density'; grep -Po '(?<=test_pca_sparse\[).+?(?=\])' $LOGFILE | sed -e 's/1e-/1e@/g' -e 's/-/,/g' -e 's/@/-/g') \ + <(echo 'bad,total'; grep -Po '(?<=Mismatched elements: )\d+ / \d+' $LOGFILE | sed -e 's/ //g' -e 's/\//,/g') diff --git a/pca-sparse-debug/scripts/mismatch-log.sh b/pca-sparse-debug/scripts/mismatch-log.sh new file mode 100644 index 0000000000000..68554244485be --- /dev/null +++ b/pca-sparse-debug/scripts/mismatch-log.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all" OMP_NUM_THREADS=1 pytest --color=no -n "$(nproc --all)" sklearn/decomposition/tests/test_pca.py::test_pca_sparse diff --git a/pca-sparse-debug/scripts/mismatch-main.sh b/pca-sparse-debug/scripts/mismatch-main.sh new file mode 100644 index 0000000000000..4938547298674 --- /dev/null +++ b/pca-sparse-debug/scripts/mismatch-main.sh @@ -0,0 +1,19 @@ +set -euxo pipefail + +SCRIPTDIR=$(dirname "$0") +DATADIR=$SCRIPTDIR/../data +PLOTDIR=$SCRIPTDIR/../plots +TIMESTAMP=$(date +"%Y%m%d-%H%M%S") +GITHASH=$(git rev-parse --short HEAD) + +BASENAME=pca-sparse-mismatch-$GITHASH-$TIMESTAMP +LOGFILE=$DATADIR/$BASENAME.log +CSVFILE=$DATADIR/$BASENAME.csv +PLOTFILE=$PLOTDIR/$BASENAME.png + +mkdir -p $DATADIR +mkdir -p $PLOTDIR + +bash $SCRIPTDIR/mismatch-log.sh > $LOGFILE || true +bash $SCRIPTDIR/mismatch-csv.sh $LOGFILE > $CSVFILE +python $SCRIPTDIR/mismatch-plot.py $CSVFILE $PLOTFILE diff --git a/pca-sparse-debug/scripts/mismatch-plot.py b/pca-sparse-debug/scripts/mismatch-plot.py new file mode 100644 index 0000000000000..c3ba555dc91bf --- /dev/null +++ b/pca-sparse-debug/scripts/mismatch-plot.py @@ -0,0 +1,52 @@ +import sys +import pandas as pd +import matplotlib.pyplot as plt + +csv = sys.argv[1] +plot = sys.argv[2] + +df = pd.read_csv(csv) +df = df[df.solver != 'auto'] + +df['rate'] = df.bad/df.total + +def mismatch_by(x): + gb = df.groupby(x) + return gb['rate'].mean() + +fig, axes = plt.subplots(2, 2, figsize=(10, 10), dpi=300) + +ax=axes[0][0] +seed = mismatch_by('seed').hist(ax=ax) +seed.set_title('mismatch rate by seed (histogram)') +seed.set_xlabel('mismatch rate') +seed.set_ylabel('seed count') +seed.set_ylim(top=100) +seed.set_xlim(right=1) + +ax=axes[0][1] +solver = mismatch_by('solver').plot.bar(ax=ax) +solver.set_title('mismatch rate by solver') +solver.set_xlabel('solver') +solver.set_ylabel('mismatch rate') +ax.bar_label(ax.containers[0], fmt="%.3f") + +ax=axes[1][0] +density = mismatch_by('density').plot.bar(ax=ax) +density.set_title('mismatch rate by density') +density.set_xlabel('density') +density.set_ylabel('mismatch rate') +ax.bar_label(ax.containers[0], fmt="%.3f") + +ax=axes[1][1] +ncomp = mismatch_by('k').plot.bar(ax=ax) +ncomp.set_title('mismatch rate by number of components') +ncomp.set_xlabel('# components') +ncomp.set_ylabel('mismatch rate') +ax.bar_label(ax.containers[0], fmt="%.3f") + +for bp in (solver, density, ncomp): + bp.set_xticklabels(bp.get_xticklabels(), rotation=0) + bp.set_ylim(top=1) +fig.tight_layout() +fig.savefig(plot, facecolor='white', transparent=False) diff --git a/pca-sparse-debug/scripts/mismatch-tolerance-plot.py b/pca-sparse-debug/scripts/mismatch-tolerance-plot.py new file mode 100644 index 0000000000000..c980fa7be2a2f --- /dev/null +++ b/pca-sparse-debug/scripts/mismatch-tolerance-plot.py @@ -0,0 +1,40 @@ +import sys +import pandas as pd +import matplotlib as mpl +import matplotlib.pyplot as plt + +csv = sys.argv[1] +plot = sys.argv[2] + +df = pd.read_csv(csv) +df = df[df.solver != 'auto'] + +df['rate'] = df.bad/df.total + +def mismatch_by(x): + gb = df.groupby(x) + return gb['rate'].mean() + +gb = mismatch_by(['solver', 'rtol']) +gb = gb.reindex(pd.MultiIndex.from_product(gb.index.levels)).fillna(0) + +def formatter(x): + if x == 1: + return '1e-00' + else: + return format(x, ".0e") +gb.index = gb.index.set_levels(map(formatter, gb.index.levels[1]), level=1) +solvers = gb.index.levels[0] +fig, axes = plt.subplots(1, len(solvers), figsize=(24, 4), dpi=300) +fig.suptitle('Elementwise mismatch rate by solver and relative tolerance', fontsize=20) +for solver, ax in zip(solvers, axes.flat): + bar = gb[solver].plot.bar(ax=ax) + bar.set_title(solver, pad=15, fontsize=20) + bar.set_xlabel('relative tolerance', fontsize=14) + bar.set_ylabel('mismatch rate', fontsize=14) + bar.set_ylim(bottom=0, top=1) + bar.tick_params(axis='both', which='major', labelsize=14) + ax.bar_label(ax.containers[0], fmt="%.2g") + +fig.tight_layout() +fig.savefig(plot, facecolor='white', transparent=False) diff --git a/pca-sparse-debug/scripts/passrate-csv.sh b/pca-sparse-debug/scripts/passrate-csv.sh new file mode 100644 index 0000000000000..e473123777743 --- /dev/null +++ b/pca-sparse-debug/scripts/passrate-csv.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +LOGFILE=$1 +echo 'seed,rtol,solver,format,k,density,outcome' +grep -P 'PASSED|FAILED' $LOGFILE | sed -E -e 's/^.*(FAILED|PASSED).*\[(.*)\]/\2 \1/' -e 's/1e-/1e@/g' -e 's/-/ /g' -e 's/@/-/g' -e 's/ $//' -e 's/ /,/g' diff --git a/pca-sparse-debug/scripts/passrate-log.sh b/pca-sparse-debug/scripts/passrate-log.sh new file mode 100644 index 0000000000000..54e21e510a4e4 --- /dev/null +++ b/pca-sparse-debug/scripts/passrate-log.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +# writes pca sparse pass/fail results to stdout + +SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all" OMP_NUM_THREADS=1 pytest --color=no -v --tb=no -n "$(nproc --all)" sklearn/decomposition/tests/test_pca.py::test_pca_sparse diff --git a/pca-sparse-debug/scripts/passrate-main.sh b/pca-sparse-debug/scripts/passrate-main.sh new file mode 100644 index 0000000000000..1c1a2fa89e80e --- /dev/null +++ b/pca-sparse-debug/scripts/passrate-main.sh @@ -0,0 +1,19 @@ +set -euxo pipefail + +SCRIPTDIR=$(dirname "$0") +DATADIR=$SCRIPTDIR/../data +PLOTDIR=$SCRIPTDIR/../plots +TIMESTAMP=$(date +"%Y%m%d-%H%M%S") +GITHASH=$(git rev-parse --short HEAD) + +BASENAME=pca-sparse-passrate-$GITHASH-$TIMESTAMP +LOGFILE=$DATADIR/$BASENAME.log +CSVFILE=$DATADIR/$BASENAME.csv +PLOTFILE=$PLOTDIR/$BASENAME.png + +mkdir -p $DATADIR +mkdir -p $PLOTDIR + +bash $SCRIPTDIR/passrate-log.sh > $LOGFILE || true +bash $SCRIPTDIR/passrate-csv.sh $LOGFILE > $CSVFILE +python $SCRIPTDIR/passrate-plot.py $CSVFILE $PLOTFILE diff --git a/pca-sparse-debug/scripts/passrate-plot.py b/pca-sparse-debug/scripts/passrate-plot.py new file mode 100644 index 0000000000000..fe20d5f2522eb --- /dev/null +++ b/pca-sparse-debug/scripts/passrate-plot.py @@ -0,0 +1,55 @@ +import sys +import pandas as pd +import matplotlib as mpl +import matplotlib.pyplot as plt + +csv = sys.argv[1] +plot = sys.argv[2] + +df = pd.read_csv(csv) +df = df[df.solver != 'auto'] + +df['pass'] = df.outcome.apply(lambda x: True if x=='PASSED' else False) + +def passrate_by(x): + passes = df.groupby(x)['pass'] + counts = passes.count() + sums = passes.sum() + return sums / counts + +fig, axes = plt.subplots(2, 2, figsize=(10, 10), dpi=300) + +ax=axes[0][0] +seed = passrate_by('seed').hist(ax=ax) +seed.set_title('pass rate by seed (histogram)') +seed.set_xlabel('pass rate') +seed.set_ylabel('seed count') +seed.set_ylim(top=100) +seed.set_xlim(right=1) + +ax=axes[0][1] +solver = passrate_by('solver').plot.bar(ax=ax) +solver.set_title('pass rate by solver') +solver.set_xlabel('solver') +solver.set_ylabel('pass rate') +ax.bar_label(ax.containers[0], fmt="%.3f") + +ax=axes[1][0] +density = passrate_by('density').plot.bar(ax=ax) +density.set_title('pass rate by density') +density.set_xlabel('density') +density.set_ylabel('pass rate') +ax.bar_label(ax.containers[0], fmt="%.3f") + +ax=axes[1][1] +ncomp = passrate_by('k').plot.bar(ax=ax) +ncomp.set_title('pass rate by number of components') +ncomp.set_xlabel('# components') +ncomp.set_ylabel('pass rate') +ax.bar_label(ax.containers[0], fmt="%.3f") + +for bp in (solver, density, ncomp): + bp.set_xticklabels(bp.get_xticklabels(), rotation=0) + bp.set_ylim(top=1) +fig.tight_layout() +fig.savefig(plot, facecolor='white', transparent=False) diff --git a/pca-sparse-debug/scripts/passrate-tolerance-plot.py b/pca-sparse-debug/scripts/passrate-tolerance-plot.py new file mode 100644 index 0000000000000..e308ba37b8724 --- /dev/null +++ b/pca-sparse-debug/scripts/passrate-tolerance-plot.py @@ -0,0 +1,38 @@ +import sys +import pandas as pd +import matplotlib as mpl +import matplotlib.pyplot as plt + +csv = sys.argv[1] +plot = sys.argv[2] + +df = pd.read_csv(csv) +df['pass'] = df.outcome.apply(lambda x: True if x=='PASSED' else False) +df = df[df.solver != 'auto'] + +def passrate_by(x): + passes = df.groupby(x)['pass'] + counts = passes.count() + sums = passes.sum() + return sums / counts + +gb = passrate_by(['solver', 'rtol']) +def formatter(x): + if x == 1: + return '1e-00' + else: + return format(x, ".0e") +gb.index = gb.index.set_levels(map(formatter, gb.index.levels[1]), level=1) +fig, axes = plt.subplots(1, 4, figsize=(24, 4), dpi=300) +fig.suptitle('Test pass rate by solver and relative tolerance', fontsize=20) +for solver, ax in zip(gb.index.levels[0], axes.flat): + bar = gb[solver].plot.bar(ax=ax) + bar.set_title(solver, pad=15, fontsize=20) + bar.set_xlabel('relative tolerance', fontsize=14) + bar.set_ylabel('pass rate', fontsize=14) + bar.set_ylim(bottom=0, top=1) + bar.tick_params(axis='both', which='major', labelsize=14) + ax.bar_label(ax.containers[0], fmt="%.2g") + +fig.tight_layout() +fig.savefig(plot, facecolor='white', transparent=False) diff --git a/pca-sparse-debug/scripts/search-mismatch-stats.sh b/pca-sparse-debug/scripts/search-mismatch-stats.sh new file mode 100644 index 0000000000000..d00888917a590 --- /dev/null +++ b/pca-sparse-debug/scripts/search-mismatch-stats.sh @@ -0,0 +1,4 @@ +LOGFILE=$1 +paste -d ',' \ + <(echo 'seed,solver,format,k,density'; grep -Po '(?<=_ test_pca_sparse\[).+?(?=\])' $LOGFILE | sed 's/-/,/g') \ + <(echo 'bad,total'; grep -Po '(?<=Mismatched elements: )\d+ / \d+' $LOGFILE | sed -e 's/ //g' -e 's/\//,/g') diff --git a/sklearn/decomposition/_pca.py b/sklearn/decomposition/_pca.py index e8c302fc47129..629a48ed40445 100644 --- a/sklearn/decomposition/_pca.py +++ b/sklearn/decomposition/_pca.py @@ -17,10 +17,10 @@ from scipy import linalg from scipy.special import gammaln from scipy.sparse import issparse -from scipy.sparse.linalg import svds +from scipy.sparse.linalg import svds, LinearOperator from ._base import _BasePCA -from ..utils import check_random_state +from ..utils import check_random_state, sparsefuncs from ..utils._arpack import _init_arpack_v0 from ..utils.deprecation import deprecated from ..utils.extmath import fast_logdet, randomized_svd, svd_flip @@ -116,6 +116,36 @@ def _infer_dimension(spectrum, n_samples): return ll.argmax() +def _center_implicitly(X, row_mean): + """Create an implicitly centered LinearOperator out of a matrix.""" + + assert X.ndim == 2 + m, n = X.shape + r = row_mean.reshape(1, -1) + ones = np.ones((m, 1), dtype=X.dtype) + + def matvec(y): + return X @ y - r @ y + + def rmatvec(y): + return X.T @ y - r.T @ ones.T @ y + + def matmat(Y): + return X @ Y - r @ Y + + def rmatmat(Y): + return X.T @ Y - r.T @ ones.T @ Y + + return LinearOperator( + shape=X.shape, + dtype=X.dtype, + matvec=matvec, + rmatvec=rmatvec, + matmat=matmat, + rmatmat=rmatmat, + ) + + class PCA(_BasePCA): """Principal component analysis (PCA). @@ -370,7 +400,7 @@ class PCA(_BasePCA): ], "copy": ["boolean"], "whiten": ["boolean"], - "svd_solver": [StrOptions({"auto", "full", "arpack", "randomized"})], + "svd_solver": [StrOptions({"auto", "full", "arpack", "randomized", "lobpcg"})], "tol": [Interval(Real, 0, None, closed="left")], "iterated_power": [ StrOptions({"auto"}), @@ -475,16 +505,12 @@ def fit_transform(self, X, y=None): def _fit(self, X): """Dispatch to the right submethod depending on the chosen solver.""" - # Raise an error for sparse input. - # This is more informative than the generic one raised by check_array. - if issparse(X): - raise TypeError( - "PCA does not support sparse input. See " - "TruncatedSVD for a possible alternative." - ) - X = self._validate_data( - X, dtype=[np.float64, np.float32], ensure_2d=True, copy=self.copy + X, + dtype=[np.float64, np.float32], + ensure_2d=True, + copy=self.copy, + accept_sparse=["csr", "csc"], ) # Handle n_components==None @@ -511,7 +537,7 @@ def _fit(self, X): # Call different fits for either full or truncated SVD if self._fit_svd_solver == "full": return self._fit_full(X, n_components) - elif self._fit_svd_solver in ["arpack", "randomized"]: + elif self._fit_svd_solver in ["arpack", "randomized", "lobpcg"]: return self._fit_truncated(X, n_components, self._fit_svd_solver) def _fit_full(self, X, n_components): @@ -602,18 +628,33 @@ def _fit_truncated(self, X, n_components, svd_solver): random_state = check_random_state(self.random_state) # Center data - self.mean_ = np.mean(X, axis=0) - X -= self.mean_ + have_total_var = False + if issparse(X): + # emulate behavior of a centered X without constructing it explicitly + row_mean, row_variance = sparsefuncs.mean_variance_axis(X, axis=0) + self.mean_ = row_mean + total_var = row_variance.sum() * n_samples / (n_samples - 1) # ddof=1 + have_total_var = True + X = _center_implicitly(X, row_mean) + else: + self.mean_ = np.mean(X, axis=0) + X -= self.mean_ if svd_solver == "arpack": v0 = _init_arpack_v0(min(X.shape), random_state) - U, S, Vt = svds(X, k=n_components, tol=self.tol, v0=v0) + U, S, Vt = svds(X, k=n_components, tol=self.tol, v0=v0, solver=svd_solver) + # svds doesn't abide by scipy.linalg.svd/randomized_svd + # conventions, so reverse its outputs. + S = S[::-1] + # flip eigenvectors' sign to enforce deterministic output + U, Vt = svd_flip(U[:, ::-1], Vt[::-1]) + elif svd_solver == "lobpcg": + U, S, Vt = svds(X, k=n_components, tol=self.tol, solver=svd_solver) # svds doesn't abide by scipy.linalg.svd/randomized_svd # conventions, so reverse its outputs. S = S[::-1] # flip eigenvectors' sign to enforce deterministic output U, Vt = svd_flip(U[:, ::-1], Vt[::-1]) - elif svd_solver == "randomized": # sign flipping is done inside U, S, Vt = randomized_svd( @@ -633,12 +674,12 @@ def _fit_truncated(self, X, n_components, svd_solver): # Get variance explained by singular values self.explained_variance_ = (S**2) / (n_samples - 1) - # Workaround in-place variance calculation since at the time numpy - # did not have a way to calculate variance in-place. - N = X.shape[0] - 1 - np.square(X, out=X) - np.sum(X, axis=0, out=X[0]) - total_var = (X[0] / N).sum() + if not have_total_var: + # Workaround in-place variance calculation since at the time numpy + # did not have a way to calculate variance in-place. + np.square(X, out=X) + np.sum(X, axis=0, out=X[0]) + total_var = (X[0] / (n_samples - 1)).sum() self.explained_variance_ratio_ = self.explained_variance_ / total_var self.singular_values_ = S.copy() # Store the singular values. diff --git a/sklearn/decomposition/tests/test_pca.py b/sklearn/decomposition/tests/test_pca.py index 5bf893f92fd16..5ac312cf4b4fa 100644 --- a/sklearn/decomposition/tests/test_pca.py +++ b/sklearn/decomposition/tests/test_pca.py @@ -1,6 +1,7 @@ import numpy as np import scipy as sp from numpy.testing import assert_array_equal +from scipy.sparse.linalg import LinearOperator import pytest import warnings @@ -14,8 +15,10 @@ from sklearn.decomposition._pca import _infer_dimension iris = datasets.load_iris() -PCA_SOLVERS = ["full", "arpack", "randomized", "auto"] +PCA_SOLVERS = ["full", "arpack", "randomized", "auto", "lobpcg"] +SPARSE_M, SPARSE_N = 400, 300 # arbitrary +SPARSE_MAX_COMPONENTS = min(SPARSE_M, SPARSE_N) @pytest.mark.parametrize("svd_solver", PCA_SOLVERS) @pytest.mark.parametrize("n_components", range(1, iris.data.shape[1])) @@ -39,6 +42,56 @@ def test_pca(svd_solver, n_components): assert_allclose(np.dot(cov, precision), np.eye(X.shape[1]), atol=1e-12) +def linear_operator_from_matrix(A): + return LinearOperator( + shape=A.shape, + dtype=A.dtype, + matvec=lambda x: A @ x, + rmatvec=lambda x: A.T @ x, + matmat=lambda X: A @ X, + rmatmat=lambda X: A.T @ X, + ) + + +def test_linear_operator_matmul(): + A = np.array([[1, 2, 3], [4, 5, 6]]) + B = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + B = linear_operator_from_matrix(B) + expected_error = "Input operand 1 does not have enough dimensions" + with pytest.raises(ValueError, match=expected_error): + A @ B + + +def test_linear_operator_reversed_matmul(): + A = np.array([[1, 2, 3], [4, 5, 6]]) + B = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + B = linear_operator_from_matrix(B) + result = (B.T @ A.T).T + assert np.allclose(result, [[38, 44, 50, 56], [83, 98, 113, 128]]) + +@pytest.mark.parametrize("density", [0.01, 0.05, 0.10, 0.30]) +@pytest.mark.parametrize("n_components", [1, 2, 3, 10, SPARSE_MAX_COMPONENTS]) +@pytest.mark.parametrize("format", ["csr", "csc"]) +@pytest.mark.parametrize("svd_solver", PCA_SOLVERS) +@pytest.mark.parametrize("rtol", [1e-07])#, 1e-06, 1e-05, 1e-04, 1e-03, 1e-02, 1e-01, 1e-00]) +def test_pca_sparse(global_random_seed, rtol, svd_solver, format, n_components, density): + if svd_solver in ["lobpcg", "arpack"] and n_components==SPARSE_MAX_COMPONENTS: + pytest.skip("lobpcg and arpack don't support full solves") + random_state = np.random.RandomState(global_random_seed) + X = sp.sparse.random( + SPARSE_M, SPARSE_N, format=format, random_state=random_state, density=density + ) + pca = PCA(n_components=n_components, svd_solver=svd_solver) + pca.fit(X) + + Xd = np.asarray(X.todense()) + pcad = PCA(n_components=n_components, svd_solver=svd_solver) + pcad.fit(Xd) + + assert_allclose(pca.components_, pcad.components_, rtol=rtol) + assert_allclose(pca.singular_values_, pcad.singular_values_, rtol=rtol) + + def test_no_empty_slice_warning(): # test if we avoid numpy warnings for computing over empty arrays n_components = 10 @@ -492,17 +545,6 @@ def test_pca_svd_solver_auto(data, n_components, expected_solver): assert_allclose(pca_auto.components_, pca_test.components_) -@pytest.mark.parametrize("svd_solver", PCA_SOLVERS) -def test_pca_sparse_input(svd_solver): - X = np.random.RandomState(0).rand(5, 4) - X = sp.sparse.csr_matrix(X) - assert sp.sparse.issparse(X) - - pca = PCA(n_components=3, svd_solver=svd_solver) - with pytest.raises(TypeError): - pca.fit(X) - - @pytest.mark.parametrize("svd_solver", PCA_SOLVERS) def test_pca_deterministic_output(svd_solver): rng = np.random.RandomState(0) diff --git a/sklearn/utils/extmath.py b/sklearn/utils/extmath.py index 577ed28f3f1b5..1c98a11a1f412 100644 --- a/sklearn/utils/extmath.py +++ b/sklearn/utils/extmath.py @@ -173,7 +173,11 @@ def safe_sparse_dot(a, b, *, dense_output=False): dot_product : {ndarray, sparse matrix} Sparse if ``a`` and ``b`` are sparse and ``dense_output=False``. """ - if a.ndim > 2 or b.ndim > 2: + if isinstance(b, sparse.linalg.LinearOperator): + # LinearOperator cannot be the RHS operand of a matmul + # so we use a linear algebra identity to make it the LHS + ret = (b.T @ a.T).T + elif a.ndim > 2 or b.ndim > 2: if sparse.issparse(a): # sparse is always 2D. Implies b is 3D+ # [i, j] @ [k, ..., l, m, n] -> [i, k, ..., l, n]