8000 Support sparse matrices in KernelRidge. · raghavrv/scikit-learn@7fa9b0e · GitHub
[go: up one dir, main page]

Skip to content

Commit 7fa9b0e

Browse files
committed
Support sparse matrices in KernelRidge.
Fixes scikit-learn#4384.
1 parent 2b28b27 commit 7fa9b0e

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

sklearn/kernel_ridge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def fit(self, X, y=None, sample_weight=None):
141141
self : returns an instance of self.
142142
"""
143143
# Convert data
144-
X, y = check_X_y(X, y, multi_output=True)
144+
X, y = check_X_y(X, y, accept_sparse=("csr", "csc"), multi_output=True)
145145

146146
n_samples = X.shape[0]
147147
K = self._get_kernel(X)

sklearn/tests/test_kernel_ridge.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import scipy.sparse as sp
23

34
from sklearn.datasets import make_regression
45
from sklearn.linear_model import Ridge
@@ -10,6 +11,8 @@
1011

1112

1213
X, y = make_regression(n_features=10)
14+
Xcsr = sp.csr_matrix(X)
15+
Xcsc = sp.csc_matrix(X)
1316
Y = np.array([y, y]).T
1417

1518

@@ -19,6 +22,20 @@ def test_kernel_ridge():
1922
assert_array_almost_equal(pred, pred2)
2023

2124

25+
def test_kernel_ridge_csr():
26+
pred = Ridge(alpha=1, fit_intercept=False,
27+
solver="cholesky").fit(Xcsr, y).predict(Xcsr)
28+
pred2 = KernelRidge(kernel="linear", alpha=1).fit(Xcsr, y).predict(Xcsr)
29+
assert_array_almost_equal(pred, pred2)
30+
31+
32+
def test_kernel_ridge_csc():
33+
pred = Ridge(alpha=1, fit_intercept=False,
34+
solver="cholesky").fit(Xcsc, y).predict(Xcsc)
35+
pred2 = KernelRidge(kernel="linear", alpha=1).fit(Xcsc, y).predict(Xcsc)
36+
assert_array_almost_equal(pred, pred2)
37+
38+
2239
def test_kernel_ridge_singular_kernel():
2340
# alpha=0 causes a LinAlgError in computing the dual coefficients,
2441
# which causes a fallback to a lstsq solver. This is tested here.

0 commit comments

Comments
 (0)
0