-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
ENH Allow fitting PCA on sparse X with arpack solvers #18689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
df6232d
14a8954
ed87334
ce97d14
f5158a1
88492fb
fd4733b
6cbeca4
b3b1f5c
74d7afe
602e3bf
0b0d4cd
f51f6c5
f731195
80d9541
1b69b48
9143651
ae0ea16
bcfe316
5c9627b
858f437
0db8d54
a664999
757ed37
5329041
67a40dd
236478e
5a15364
af56492
ce5e8d4
0ce6d50
a201440
05df49c
bfa4faa
9cbd76b
7c8bece
0ed4d69
7fe38e8
e374fe2
1bfad54
3ca40e2
f10cc33
0d91e1c
4bcac1f
e6daeec
bf81430
e920da0
c7eda6a
f9c5329
8000
e416879
a519798
a647d8d
1ccb9a8
5bf6134
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,11 +21,28 @@ | |
_get_check_estimator_ids, | ||
check_array_api_input_and_values, | ||
) | ||
from sklearn.utils.fixes import CSR_CONTAINERS | ||
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS | ||
|
||
# Dataset loaded once at import time and shared by the tests in this module.
iris = datasets.load_iris()
# Values of `svd_solver` iterated over by the solver-parametrized tests.
PCA_SOLVERS = ["full", "arpack", "randomized", "auto"]

# Shape of the random sparse matrices used by the sparse PCA tests.
# Larger values are possible, but keep in mind that:
# * generating a random sparse matrix with SciPy can be costly
# * a dense (SPARSE_M, SPARSE_N) array is allocated to compare against
SPARSE_M = 1000  # arbitrary
SPARSE_N = 300  # arbitrary
SPARSE_MAX_COMPONENTS = min(SPARSE_M, SPARSE_N)
ivirshup marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def _check_fitted_pca_close(pca1, pca2, rtol):
    """Assert that two fitted PCA estimators hold matching fitted attributes.

    Parameters
    ----------
    pca1, pca2 : objects
        Fitted estimators exposing ``components_``, ``explained_variance_``,
        ``singular_values_``, ``mean_``, ``noise_variance_``,
        ``n_components_``, ``n_samples_`` and ``n_features_in_``.
    rtol : float
        Relative tolerance forwarded to ``assert_allclose`` for the
        floating-point attributes.

    Raises
    ------
    AssertionError
        If any attribute differs; the message names the offending attribute.
    """
    # Floating-point results are compared up to the caller-supplied tolerance.
    float_attributes = (
        "components_",
        "explained_variance_",
        "singular_values_",
        "mean_",
        "noise_variance_",
    )
    for attr in float_attributes:
        assert_allclose(
            getattr(pca1, attr),
            getattr(pca2, attr),
            rtol=rtol,
            err_msg=f"PCA attribute {attr!r} differs",
        )

    # Counts are integers: a relative tolerance could hide an off-by-one on
    # large values, so they must match exactly.
    for attr in ("n_components_", "n_samples_", "n_features_in_"):
        value1, value2 = getattr(pca1, attr), getattr(pca2, attr)
        assert (
            value1 == value2
        ), f"PCA attribute {attr!r} differs: {value1!r} != {value2!r}"
|
||
|
||
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS) | ||
@pytest.mark.parametrize("n_components", range(1, iris.data.shape[1])) | ||
|
@@ -49,6 +66,118 @@ def test_pca(svd_solver, n_components): | |
assert_allclose(np.dot(cov, precision), np.eye(X.shape[1]), atol=1e-12) | ||
|
||
|
||
@pytest.mark.parametrize("density", [0.01, 0.1, 0.30]) | ||
@pytest.mark.parametrize("n_components", [1, 2, 10]) | ||
@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS) | ||
@pytest.mark.parametrize("svd_solver", ["arpack"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't this work with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I gave it a quick try (by allowing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A strategy to implement this and identify the source of the discrepancy would be to first write tests to check that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we introduce There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is what I'd like to do. This test was originally implemented in #24415 with support for randomized and lobpcg. There are numerical stability issues, but I think it's worth getting just arpack through right now since the performance improvement is so great. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the future PR: the problem might be that the default choice of |
||
@pytest.mark.parametrize("scale", [1, 10, 100])
def test_pca_sparse(
    global_random_seed, svd_solver, sparse_container, n_components, density, scale
):
    """Check that PCA fitted on sparse data matches PCA fitted on its dense copy.

    A random sparse matrix of shape (SPARSE_M, SPARSE_N) is generated with the
    requested `density` and its columns are rescaled by random factors drawn
    in [0, scale). Two PCA models with identical parameters are then fitted,
    one on the sparse matrix and one on `X.toarray()`. Their fitted attributes
    and their transforms of a second random matrix must agree within the
    tolerances defined below.
    """
    # Make sure any tolerance changes pass with SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"
    rtol = 5e-07
    transform_rtol = 3e-05

    random_state = np.random.default_rng(global_random_seed)
    X = sparse_container(
        sp.sparse.random(
            SPARSE_M,
            SPARSE_N,
            random_state=random_state,
            density=density,
        )
    )
    # Scale the data + vary the column means
    scale_vector = random_state.random(X.shape[1]) * scale
    # `multiply` with a dense vector may hand back a different sparse format
    # (e.g. COO); convert back so the fit really exercises `sparse_container`.
    X = sparse_container(X.multiply(scale_vector))

    pca = PCA(
        n_components=n_components,
        svd_solver=svd_solver,
        random_state=global_random_seed,
    )
    pca.fit(X)

    # Reference model: same parameters, fitted on the densified data.
    Xd = X.toarray()
    pcad = PCA(
        n_components=n_components,
        svd_solver=svd_solver,
        random_state=global_random_seed,
    )
    pcad.fit(Xd)

    # Fitted attributes equality
    _check_fitted_pca_close(pca, pcad, rtol=rtol)

    # Test transform
    X2 = sparse_container(
        sp.sparse.random(
            SPARSE_M,
            SPARSE_N,
            random_state=random_state,
            density=density,
        )
    )
    X2d = X2.toarray()

    # Sparse vs. dense input to the same model, then sparse model vs. dense model.
    assert_allclose(pca.transform(X2), pca.transform(X2d), rtol=transform_rtol)
    assert_allclose(pca.transform(X2), pcad.transform(X2d), rtol=transform_rtol)
|
||
ivirshup marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
def test_pca_sparse_fit_transform(global_random_seed, sparse_container):
    """Check that `fit` and `fit_transform` agree on sparse input with arpack.

    Two identically configured PCA models are trained on the same sparse
    matrix, one through `fit` and one through `fit_transform`. Their fitted
    attributes and their projections of both the training matrix and a second
    random matrix must coincide within tight tolerances.
    """
    rng = np.random.default_rng(global_random_seed)
    # Both matrices come from the same generator; the training matrix is
    # drawn first, the held-out matrix second.
    X_train, X_new = (
        sparse_container(
            sp.sparse.random(SPARSE_M, SPARSE_N, random_state=rng, density=0.01)
        )
        for _ in range(2)
    )

    pca_params = dict(
        n_components=10, svd_solver="arpack", random_state=global_random_seed
    )
    fitted_only = PCA(**pca_params)
    fitted_and_projected = PCA(**pca_params)

    fitted_only.fit(X_train)
    projected = fitted_and_projected.fit_transform(X_train)

    _check_fitted_pca_close(fitted_only, fitted_and_projected, rtol=1e-10)

    transform_rtol = 2e-9
    assert_allclose(
        projected, fitted_and_projected.transform(X_train), rtol=transform_rtol
    )
    assert_allclose(projected, fitted_only.transform(X_train), rtol=transform_rtol)
    assert_allclose(
        fitted_only.transform(X_new),
        fitted_and_projected.transform(X_new),
        rtol=transform_rtol,
    )
|
||
|
||
@pytest.mark.parametrize("svd_solver", ["randomized", "full", "auto"]) | ||
ivirshup marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@pytest.mark.parametrize("sparse_container", CSR_CONTAINERS + CSC_CONTAINERS)
def test_sparse_pca_solver_error(global_random_seed, svd_solver, sparse_container):
    """Check that fitting PCA on sparse input raises for the given `svd_solver`.

    The fit must fail with a `TypeError` whose message names the solver that
    was passed.
    """
    rng = np.random.RandomState(global_random_seed)
    random_sparse = sp.sparse.random(SPARSE_M, SPARSE_N, random_state=rng)
    X = sparse_container(random_sparse)

    expected_message = (
        f'PCA only support sparse inputs with the "arpack" solver, while "{svd_solver}"'
        " was passed"
    )
    estimator = PCA(n_components=30, svd_solver=svd_solver)
    with pytest.raises(TypeError, match=expected_message):
        estimator.fit(X)
|
||
|
||
def test_no_empty_slice_warning(): | ||
# test if we avoid numpy warnings for computing over empty arrays | ||
n_components = 10 | ||
|
@@ -502,18 +631,6 @@ def test_pca_svd_solver_auto(data, n_components, expected_solver): | |
assert_allclose(pca_auto.components_, pca_test.components_) | ||
|
||
|
||
# NOTE(review): `test_pca_sparse` in this module fits sparse input with the
# "arpack" solver successfully, so that solver cannot be expected to raise
# here. This test may be redundant with `test_sparse_pca_solver_error` -
# confirm whether it should be kept at all.
@pytest.mark.parametrize(
    "svd_solver", [solver for solver in PCA_SOLVERS if solver != "arpack"]
)
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_pca_sparse_input(svd_solver, csr_container):
    """Check that PCA rejects sparse input for solvers other than "arpack"."""
    X = csr_container(np.random.RandomState(0).rand(5, 4))
    assert sp.sparse.issparse(X)

    pca = PCA(n_components=3, svd_solver=svd_solver)
    with pytest.raises(TypeError):
        pca.fit(X)
|
||
|
||
@pytest.mark.parametrize("svd_solver", PCA_SOLVERS) | ||
def test_pca_deterministic_output(svd_solver): | ||
rng = np.random.RandomState(0) | ||
|
Uh oh!
There was an error while loading. Please reload this page.