Use scipy sparse svd · scikit-learn/scikit-learn@e3d5bba · GitHub
Commit e3d5bba

Use scipy sparse svd
1 parent 8d815af commit e3d5bba

File tree

1 file changed: +7 −6 lines changed


sklearn/linear_model/ridge.py

Lines changed: 7 additions & 6 deletions
@@ -910,6 +910,7 @@ def __init__(self, alphas=(0.1, 1.0, 10.0),
 
     def _pre_compute(self, X, y, centered_kernel=True):
         # even if X is very sparse, K is usually very dense
+        n_samples, n_features = X.shape
         K = safe_sparse_dot(X, X.T, dense_output=True)
         # the following emulates an additional constant regressor
         # corresponding to fit_intercept=True
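
The kernel precompute above uses scikit-learn's `safe_sparse_dot` helper, which multiplies sparse or dense operands and can force a dense result. A minimal standalone sketch (not part of the commit; the random matrix is purely illustrative) of the point made in the comment, that K = X · X.T is dense even when X is very sparse:

    import numpy as np
    from scipy import sparse
    from sklearn.utils.extmath import safe_sparse_dot

    # A 1%-dense X still yields a mostly dense Gram matrix, because any pair
    # of rows sharing a single nonzero column contributes a nonzero entry.
    X = sparse.random(100, 1000, density=0.01, format='csr', random_state=0)
    K = safe_sparse_dot(X, X.T, dense_output=True)  # plain ndarray, (100, 100)
    print(type(K), K.shape)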
@@ -960,13 +961,14 @@ def _values(self, alpha, y, v, Q, QT_y):
         return y - (c / G_diag), c
 
     def _pre_compute_svd(self, X, y, centered_kernel=True):
-        if sparse.issparse(X):
-            raise TypeError("SVD not supported for sparse matrices")
         if centered_kernel:
             X = np.hstack((X, np.ones((X.shape[0], 1))))
         # to emulate fit_intercept=True situation, add a column of ones
         # Note that by centering, the other columns are orthogonal to that one
-        U, s, _ = linalg.svd(X, full_matrices=0)
+        if sparse.issparse(X):
+            U, s, _ = sp_linalg.svds(X)
+        else:
+            U, s, _ = linalg.svd(X, full_matrices=0)
         v = s ** 2
         UT_y = np.dot(U.T, y)
         return v, U, UT_y
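
The new branch calls `scipy.sparse.linalg.svds` (available here as `sp_linalg`), which differs from `scipy.linalg.svd` in two ways: it computes only `k` singular triplets (`k=6` unless specified) rather than the full decomposition, and it does not return them in the descending order that `linalg.svd` guarantees. A hedged comparison sketch, separate from the commit, with illustrative sizes:

    import numpy as np
    from scipy import linalg, sparse
    from scipy.sparse import linalg as sp_linalg

    rng = np.random.RandomState(0)
    X = rng.rand(50, 8)

    # Full SVD: all 8 singular values, in descending order.
    U, s, _ = linalg.svd(X, full_matrices=0)

    # Truncated SVD: svds(X) with no k computes only the 6 largest triplets,
    # so the smallest singular values are silently dropped.
    Uk, sk, _ = sp_linalg.svds(sparse.csr_matrix(X))

    print(np.sort(s)[2:])  # six largest from the full SVD
    print(np.sort(sk))     # matches up to numerical precision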
@@ -1027,7 +1029,7 @@ def fit(self, X, y, sample_weight=None):
         with_sw = len(np.shape(sample_weight))
 
         if gcv_mode is None or gcv_mode == 'auto':
-            if sparse.issparse(X) or n_features > n_samples or with_sw:
+            if n_features > n_samples or with_sw:
                 gcv_mode = 'eigen'
             else:
                 gcv_mode = 'svd'
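
Restated outside the diff, the 'auto' resolution after this change no longer forces the eigendecomposition path just because X is sparse; a minimal sketch (hypothetical helper name, mirroring only the branch shown above):

    def resolve_gcv_mode(gcv_mode, n_samples, n_features, with_sw):
        # Sparsity alone no longer selects 'eigen', since the SVD path can
        # now handle sparse X via sp_linalg.svds.
        if gcv_mode is None or gcv_mode == 'auto':
            if n_features > n_samples or with_sw:
                return 'eigen'  # kernel path, cheaper when n_features > n_samples
            return 'svd'        # SVD path, cheaper when n_samples >= n_features
        return gcv_mode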
@@ -1283,8 +1285,7 @@ class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
     See glossary entry for :term:`cross-validation estimator`.
 
     By default, it performs Generalized Cross-Validation, which is a form of
-    efficient Leave-One-Out cross-validation. Currently, only the n_features >
-    n_samples case is handled efficiently.
+    efficient Leave-One-Out cross-validation.
 
     Read more in the :ref:`User Guide <ridge_regression>`.
