|
6 | 6 | import numpy as np
|
7 | 7 | import pytest
|
8 | 8 | from scipy import stats
|
9 |
| -from scipy.sparse import ( |
10 |
| - coo_matrix, |
11 |
| - csc_matrix, |
12 |
| - csr_matrix, |
13 |
| - issparse, |
14 |
| -) |
| 9 | +from scipy.sparse import issparse |
15 | 10 | from scipy.special import comb
|
16 | 11 |
|
17 | 12 | from sklearn import config_context
|
|
63 | 58 | from sklearn.utils.estimator_checks import (
|
64 | 59 | _array_api_for_tests,
|
65 | 60 | )
|
| 61 | +from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS |
66 | 62 | from sklearn.utils.validation import _num_samples
|
67 | 63 |
|
68 | 64 | NO_GROUP_SPLITTERS = [
|
|
90 | 86 |
|
91 | 87 | X = np.ones(10)
|
92 | 88 | y = np.arange(10) // 2
|
93 |
| -P_sparse = coo_matrix(np.eye(5)) |
94 | 89 | test_groups = (
|
95 | 90 | np.array([1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3]),
|
96 | 91 | np.array([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]),
|
@@ -1335,9 +1330,10 @@ def test_array_api_train_test_split(shuffle, stratify, array_namepsace, device,
|
1335 | 1330 | )
|
1336 | 1331 |
|
1337 | 1332 |
|
1338 |
| -def test_train_test_split(): |
| 1333 | +@pytest.mark.parametrize("coo_container", COO_CONTAINERS) |
| 1334 | +def test_train_test_split(coo_container): |
1339 | 1335 | X = np.arange(100).reshape((10, 10))
|
1340 |
| - X_s = coo_matrix(X) |
| 1336 | + X_s = coo_container(X) |
1341 | 1337 | y = np.arange(10)
|
1342 | 1338 |
|
1343 | 1339 | # simple test
|
@@ -1423,16 +1419,17 @@ def test_train_test_split_pandas():
|
1423 | 1419 | assert isinstance(X_test, InputFeatureType)
|
1424 | 1420 |
|
1425 | 1421 |
|
1426 |
| -def test_train_test_split_sparse(): |
| 1422 | +@pytest.mark.parametrize( |
| 1423 | + "sparse_container", COO_CONTAINERS + CSC_CONTAINERS + CSR_CONTAINERS |
| 1424 | +) |
| 1425 | +def test_train_test_split_sparse(sparse_container): |
1427 | 1426 | # check that train_test_split converts scipy sparse matrices
|
1428 | 1427 | # to csr, as stated in the documentation
|
1429 | 1428 | X = np.arange(100).reshape((10, 10))
|
1430 |
| - sparse_types = [csr_matrix, csc_matrix, coo_matrix] |
1431 |
| - for InputFeatureType in sparse_types: |
1432 |
| - X_s = InputFeatureType(X) |
1433 |
| - X_train, X_test = train_test_split(X_s) |
1434 |
| - assert issparse(X_train) and X_train.format == "csr" |
1435 |
| - assert issparse(X_test) and X_test.format == "csr" |
| 1429 | + X_s = sparse_container(X) |
| 1430 | + X_train, X_test = train_test_split(X_s) |
| 1431 | + assert issparse(X_train) and X_train.format == "csr" |
| 1432 | + assert issparse(X_test) and X_test.format == "csr" |
1436 | 1433 |
|
1437 | 1434 |
|
1438 | 1435 | def test_train_test_split_mock_pandas():
|
|
0 commit comments