TST Extend tests for `scipy.sparse/*array` in `sklearn/ensemble/tests/test_iforest` (#27218) · scikit-learn/scikit-learn@e1c3813 · GitHub
[go: up one dir, main page]

Skip to content

Commit e1c3813

Browse files
authored
TST Extend tests for scipy.sparse/*array in sklearn/ensemble/tests/test_iforest (#27218)
1 parent 0169bde commit e1c3813

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

sklearn/ensemble/tests/test_iforest.py

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111

1212
import numpy as np
1313
import pytest
14-
from scipy.sparse import csc_matrix, csr_matrix
1514

1615
from sklearn.datasets import load_diabetes, load_iris, make_classification
1716
from sklearn.ensemble import IsolationForest
@@ -25,6 +24,7 @@
2524
assert_array_equal,
2625
ignore_warnings,
2726
)
27+
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
2828

2929
# load iris & diabetes dataset
3030
iris = load_iris()
@@ -47,30 +47,30 @@ def test_iforest(global_random_seed):
4747
).predict(X_test)
4848

4949

50-
def test_iforest_sparse(global_random_seed):
50+
@pytest.mark.parametrize("sparse_container", CSC_CONTAINERS + CSR_CONTAINERS)
51+
def test_iforest_sparse(global_random_seed, sparse_container):
5152
"""Check IForest for various parameter settings on sparse input."""
5253
rng = check_random_state(global_random_seed)
5354
X_train, X_test = train_test_split(diabetes.data[:50], random_state=rng)
5455
grid = ParameterGrid({"max_samples": [0.5, 1.0], "bootstrap": [True, False]})
5556

56-
for sparse_format in [csc_matrix, csr_matrix]:
57-
X_train_sparse = sparse_format(X_train)
58-
X_test_sparse = sparse_format(X_test)
57+
X_train_sparse = sparse_container(X_train)
58+
X_test_sparse = sparse_container(X_test)
5959

60-
for params in grid:
61-
# Trained on sparse format
62-
sparse_classifier = IsolationForest(
63-
n_estimators=10, random_state=global_random_seed, **params
64-
).fit(X_train_sparse)
65-
sparse_results = sparse_classifier.predict(X_test_sparse)
60+
for params in grid:
61+
# Trained on sparse format
62+
sparse_classifier = IsolationForest(
63+
n_estimators=10, random_state=global_random_seed, **params
64+
).fit(X_train_sparse)
65+
sparse_results = sparse_classifier.predict(X_test_sparse)
6666

67-
# Trained on dense format
68-
dense_classifier = IsolationForest(
69-
n_estimators=10, random_state=global_random_seed, **params
70-
).fit(X_train)
71-
dense_results = dense_classifier.predict(X_test)
67+
# Trained on dense format
68+
dense_classifier = IsolationForest(
69+
n_estimators=10, random_state=global_random_seed, **params
70+
).fit(X_train)
71+
dense_results = dense_classifier.predict(X_test)
7272

73-
assert_array_equal(sparse_results, dense_results)
73+
assert_array_equal(sparse_results, dense_results)
7474

7575

7676
def test_iforest_error():
@@ -314,13 +314,14 @@ def test_iforest_with_uniform_data():
314314
assert all(iforest.predict(np.ones((100, 10))) == 1)
315315

316316

317-
def test_iforest_with_n_jobs_does_not_segfault():
317+
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
318+
def test_iforest_with_n_jobs_does_not_segfault(csc_container):
318319
"""Check that Isolation Forest does not segfault with n_jobs=2
319320
320321
Non-regression test for #23252
321322
"""
322323
X, _ = make_classification(n_samples=85_000, n_features=100, random_state=0)
323-
X = csc_matrix(X)
324+
X = csc_container(X)
324325
IsolationForest(n_estimators=10, max_samples=256, n_jobs=2).fit(X)
325326

326327

0 commit comments

Comments
 (0)
0