8000 TST Extend tests for `scipy.sparse/*array` in `sklearn/model_selectio… · REDVM/scikit-learn@8d45d7a · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 8d45d7a

Browse files
Charlie-XIAOREDVM
authored andcommitted
TST Extend tests for scipy.sparse/*array in sklearn/model_selection/tests/test_split (scikit-learn#27246)
1 parent fe7944c commit 8d45d7a

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

sklearn/model_selection/tests/test_split.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,7 @@
66
import numpy as np
77
import pytest
88
from scipy import stats
9-
from scipy.sparse import (
10-
coo_matrix,
11-
csc_matrix,
12-
csr_matrix,
13-
issparse,
14-
)
9+
from scipy.sparse import issparse
1510
from scipy.special import comb
1611

1712
from sklearn import config_context
@@ -63,6 +58,7 @@
6358
from sklearn.utils.estimator_checks import (
6459
_array_api_for_tests,
6560
)
61+
from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
6662
from sklearn.utils.validation import _num_samples
6763

6864
NO_GROUP_SPLITTERS = [
@@ -90,7 +86,6 @@
9086

9187
X = np.ones(10)
9288
y = np.arange(10) // 2
93-
P_sparse = coo_matrix(np.eye(5))
9489
test_groups = (
9590
np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
9691
np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
@@ -1335,9 +1330,10 @@ def test_array_api_train_test_split(shuffle, stratify, array_namepsace, device,
13351330
)
13361331

13371332

1338-
def test_train_test_split():
1333+
@pytest.mark.parametrize("coo_container", COO_CONTAINERS)
1334+
def test_train_test_split(coo_container):
13391335
X = np.arange(100).reshape((10, 10))
1340-
X_s = coo_matrix(X)
1336+
X_s = coo_container(X)
13411337
y = np.arange(10)
13421338

13431339
# simple test
@@ -1423,16 +1419,17 @@ def test_train_test_split_pandas():
14231419
assert isinstance(X_test, InputFeatureType)
14241420

14251421

1426-
def test_train_test_split_sparse():
1422+
@pytest.mark.parametrize(
1423+
"sparse_container", COO_CONTAINERS + CSC_CONTAINERS + CSR_CONTAINERS
1424+
)
1425+
def test_train_test_split_sparse(sparse_container):
14271426
# check that train_test_split converts scipy sparse matrices
14281427
# to csr, as stated in the documentation
14291428
X = np.arange(100).reshape((10, 10))
1430-
sparse_types = [csr_matrix, csc_matrix, coo_matrix]
1431-
for InputFeatureType in sparse_types:
1432-
X_s = InputFeatureType(X)
1433-
X_train, X_test = train_test_split(X_s)
1434-
assert issparse(X_train) and X_train.format == "csr"
1435-
assert issparse(X_test) and X_test.format == "csr"
1429+
X_s = sparse_container(X)
1430+
X_train, X_test = train_test_split(X_s)
1431+
assert issparse(X_train) and X_train.format == "csr"
1432+
assert issparse(X_test) and X_test.format == "csr"
14361433

14371434

14381435
def test_train_test_split_mock_pandas():

0 commit comments

Comments
 (0)
0