8000 TST Extend tests for `scipy.sparse.*array` in `sklearn/tests/test_ker… · scikit-learn/scikit-learn@611cffe · GitHub
[go: up one dir, main page]

Skip to content

Commit 611cffe

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/tests/test_kernel_ridge.py (#27270)
1 parent a0bdf60 commit 611cffe

File tree

1 file changed

+8
-17
lines changed

1 file changed

+8
-17
lines changed

sklearn/tests/test_kernel_ridge.py

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import numpy as np
2-
import scipy.sparse as sp
2+
import pytest
33

44
from sklearn.datasets import make_regression
55
from sklearn.kernel_ridge import KernelRidge
66
from sklearn.linear_model import Ridge
77
from sklearn.metrics.pairwise import pairwise_kernels
88
from sklearn.utils._testing import assert_array_almost_equal, ignore_warnings
9+
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
910

1011
X, y = make_regression(n_features=10, random_state=0)
11-
Xcsr = sp.csr_matrix(X)
12-
Xcsc = sp.csc_matrix(X)
1312
Y = np.array([y, y]).T
1413

1514

@@ -19,23 +18,15 @@ def test_kernel_ridge():
1918
assert_array_almost_equal(pred, pred2)
2019

2120

22-
def test_kernel_ridge_csr():
21+
@pytest.mark.parametrize("sparse_container", [*CSR_CONTAINERS, *CSC_CONTAINERS])
22+
def test_kernel_ridge_sparse(sparse_container):
23+
X_sparse = sparse_container(X)
2324
pred = (
2425
Ridge(alpha=1, fit_intercept=False, solver="cholesky")
25-
.fit(Xcsr, y)
26-
.predict(Xcsr)
26+
.fit(X_sparse, y)
27+
.predict(X_sparse)
2728
)
28-
pred2 = KernelRidge(kernel="linear", alpha=1).fit(Xcsr, y).predict(Xcsr)
29-
assert_array_almost_equal(pred, pred2)
30-
31-
32-
def test_kernel_ridge_csc():
33-
pred = (
34-
Ridge(alpha=1, fit_intercept=False, solver="cholesky")
35-
.fit(Xcsc, y)
36-
.predict(Xcsc)
37-
)
38-
pred2 = KernelRidge(kernel="linear", alpha=1).fit(Xcsc, y).predict(Xcsc)
29+
pred2 = KernelRidge(kernel="linear", alpha=1).fit(X_sparse, y).predict(X_sparse)
3930
assert_array_almost_equal(pred, pred2)
4031

4132

0 commit comments

Comments
 (0)
0