|
1 | 1 | import numpy as np
|
2 | 2 | import pytest
|
3 | 3 | from numpy.testing import assert_array_almost_equal, assert_array_equal
|
4 |
| -from scipy import sparse |
5 | 4 |
|
6 | 5 | from sklearn.datasets import make_regression
|
7 | 6 | from sklearn.exceptions import ConvergenceWarning
|
|
14 | 13 | from sklearn.linear_model._ransac import _dynamic_max_trials
|
15 | 14 | from sklearn.utils import check_random_state
|
16 | 15 | from sklearn.utils._testing import assert_allclose
|
| 16 | +from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS |
17 | 17 |
|
18 | 18 | # Generate coordinates of line
|
19 | 19 | X = np.arange(-200, 200)
|
@@ -248,38 +248,11 @@ def is_data_valid(X, y):
|
248 | 248 | assert ransac_estimator.n_skips_invalid_model_ == 0
|
249 | 249 |
|
250 | 250 |
|
251 |
| -def test_ransac_sparse_coo(): |
252 |
| - X_sparse = sparse.coo_matrix(X) |
253 |
| - |
254 |
| - estimator = LinearRegression() |
255 |
| - ransac_estimator = RANSACRegressor( |
256 |
| - estimator, min_samples=2, residual_threshold=5, random_state=0 |
257 |
| - ) |
258 |
| - ransac_estimator.fit(X_sparse, y) |
259 |
| - |
260 |
| - ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_) |
261 |
| - ref_inlier_mask[outliers] = False |
262 |
| - |
263 |
| - assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask) |
264 |
| - |
265 |
| - |
266 |
| -def test_ransac_sparse_csr(): |
267 |
| - X_sparse = sparse.csr_matrix(X) |
268 |
| - |
269 |
| - estimator = LinearRegression() |
270 |
| - ransac_estimator = RANSACRegressor( |
271 |
| - estimator, min_samples=2, residual_threshold=5, random_state=0 |
272 |
| - ) |
273 |
| - ransac_estimator.fit(X_sparse, y) |
274 |
| - |
275 |
| - ref_inlier_mask = np.ones_like(ransac_estimator.inlier_mask_).astype(np.bool_) |
276 |
| - ref_inlier_mask[outliers] = False |
277 |
| - |
278 |
| - assert_array_equal(ransac_estimator.inlier_mask_, ref_inlier_mask) |
279 |
| - |
280 |
| - |
281 |
| -def test_ransac_sparse_csc(): |
282 |
| - X_sparse = sparse.csc_matrix(X) |
| 251 | +@pytest.mark.parametrize( |
| 252 | + "sparse_container", COO_CONTAINERS + CSR_CONTAINERS + CSC_CONTAINERS |
| 253 | +) |
| 254 | +def test_ransac_sparse(sparse_container): |
| 255 | + X_sparse = sparse_container(X) |
283 | 256 |
|
284 | 257 | estimator = LinearRegression()
|
285 | 258 | ransac_estimator = RANSACRegressor(
|
|
0 commit comments