10000 TST Extend tests for `scipy.sparse/*array` in `sklearn/linear_model/t… · REDVM/scikit-learn@72560c1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 72560c1

Browse files
Charlie-XIAOREDVM
authored andcommitted
TST Extend tests for scipy.sparse/*array in sklearn/linear_model/tests/test_ransac (scikit-learn#27233)
1 parent 5ba4c82 commit 72560c1

File tree

1 file changed

+6
-33
lines changed

1 file changed

+6
-33
lines changed

sklearn/linear_model/tests/test_ransac.py

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22
import pytest
33
from numpy.testing import assert_array_almost_equal, assert_array_equal
4-
from scipy import sparse
54

65
from sklearn.datasets import make_regression
76
from sklearn.exceptions import ConvergenceWarning
@@ -14,6 +13,7 @@
1413
from sklearn.linear_model._ransac import _dynamic_max_trials
1514
from sklearn.utils import check_random_state
1615
from sklearn.utils._testing import assert_allclose
16+
from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
1717

1818
# Generate coordinates of line
1919
X = np.arange(-200, 200)
@@ -248,38 +248,11 @@ def is_data_valid(X, y):
248248
assert ransac_estimator.n_skips_invalid_model_ == 0
249249

250250

251-
def test_ransac_sparse_coo():
252-
X_sparse = sparse.coo_matrix(X)
253-
254-
estimator = LinearRegression()
255-
ransac_estimator = RANSACRegressor(
256-
estimator, min_samples=2, residual_threshold=5, random_state=0
257-
)
258-
ransac_estimator.fit(X_sparse, y)
259-
260-
ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
261-
ref_inlier_mask[outliers] = False
262-
263-
assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
264-
265-
266-
def test_ransac_sparse_csr():
267-
X_sparse = sparse.csr_matrix(X)
268-
269-
estimator = LinearRegression()
270-
ransac_estimator = RANSACRegressor(
271-
estimator, min_samples=2, residual_threshold=5, random_state=0
272-
)
273-
ransac_estimator.fit(X_sparse, y)
274-
275-
ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_)
276-
ref_inlier_mask[outliers] = False
277-
278-
assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask)
279-
280-
281-
def test_ransac_sparse_csc():
282-
X_sparse = sparse.csc_matrix(X)
251+
@pytest.mark.parametrize(
252+
"sparse_container", COO_CONTAINERS + CSR_CONTAINERS + CSC_CONTAINERS
253+
)
254+
def test_ransac_sparse(sparse_container):
255+
X_sparse = sparse_container(X)
283256

284257
estimator = LinearRegression()
285258
ransac_estimator = RANSACRegressor(

0 commit comments

Comments
 (0)
0