8000 TST Extend tests for `scipy.sparse/*array` in `sklearn/feature_select… · scikit-learn/scikit-learn@f2b5a2d · GitHub
[go: up one dir, main page]

Skip to content

Commit f2b5a2d

Browse files
authored
TST Extend tests for scipy.sparse/*array in sklearn/feature_selection/tests/test_variance_threshold (#27222)
1 parent 749136e commit f2b5a2d

File tree

1 file changed

+30
-19
lines changed

1 file changed

+30
-19
lines changed
Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,39 @@
11
import numpy as np
22
import pytest
3-
from scipy.sparse import bsr_matrix, csc_matrix, csr_matrix
43

54
from sklearn.feature_selection import VarianceThreshold
65
from sklearn.utils._testing import assert_array_equal
6+
from sklearn.utils.fixes import BSR_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS
77

88
data = [[0, 1, 2, 3, 4], [0, 2, 2, 3, 5], [1, 1, 2, 4, 0]]
99

1010
data2 = [[-0.13725701]] * 10
1111

1212

13-
def test_zero_variance():
13+
@pytest.mark.parametrize(
14+
"sparse_container", [None] + BSR_CONTAINERS + CSC_CONTAINERS + CSR_CONTAINERS
15+
)
16+
def test_zero_variance(sparse_container):
1417
# Test VarianceThreshold with default setting, zero variance.
18+
X = data if sparse_container is None else sparse_container(data)
19+
sel = VarianceThreshold().fit(X)
20+
assert_array_equal([0, 1, 3, 4], sel.get_support(indices=True))
1521

16-
for X in [data, csr_matrix(data), csc_matrix(data), bsr_matrix(data)]:
17-
sel = VarianceThreshold().fit(X)
18-
assert_array_equal([0, 1, 3, 4], sel.get_support(indices=True))
1922

23+
def test_zero_variance_value_error():
24+
# Test VarianceThreshold with default setting, zero variance, error cases.
2025
with pytest.raises(ValueError):
2126
VarianceThreshold().fit([[0, 1, 2, 3]])
2227
with pytest.raises(ValueError):
2328
VarianceThreshold().fit([[0, 1], [0, 1]])
2429

2530

26-
def test_variance_threshold():
31+
@pytest.mark.parametrize("sparse_container", [None] + CSR_CONTAINERS)
32+
def test_variance_threshold(sparse_container):
2733
# Test VarianceThreshold with custom variance.
28-
for X in [data, csr_matrix(data)]:
29-
X = VarianceThreshold(threshold=0.4).fit_transform(X)
30-
assert (len(data), 1) == X.shape
34+
X = data if sparse_container is None else sparse_container(data)
35+
X = VarianceThreshold(threshold=0.4).fit_transform(X)
36+
assert (len(data), 1) == X.shape
3137

3238

3339
@pytest.mark.skipif(
@@ -37,25 +43,30 @@ def test_variance_threshold():
3743
"as it relies on numerical instabilities."
3844
),
3945
)
40-
def test_zero_variance_floating_point_error():
46+
@pytest.mark.parametrize(
47+
"sparse_container", [None] + BSR_CONTAINERS + CSC_CONTAINERS + CSR_CONTAINERS
48+
)
49+
def test_zero_variance_floating_point_error(sparse_container):
4150
# Test that VarianceThreshold(0.0).fit eliminates features that have
4251
# the same value in every sample, even when floating point errors
4352
# cause np.var not to be 0 for the feature.
4453
# See #13691
54+
X = data2 if sparse_container is None else sparse_container(data2)
55+
msg = "No feature in X meets the variance threshold 0.00000"
56+
with pytest.raises(ValueError, match=msg):
57+
VarianceThreshold().fit(X)
4558

46-
for X in [data2, csr_matrix(data2), csc_matrix(data2), bsr_matrix(data2)]:
47-
msg = "No feature in X meets the variance threshold 0.00000"
48-
with pytest.raises(ValueError, match=msg):
49-
VarianceThreshold().fit(X)
5059

51-
52-
def test_variance_nan():
60+
@pytest.mark.parametrize(
61+
"sparse_container", [None] + BSR_CONTAINERS + CSC_CONTAINERS + CSR_CONTAINERS
62+
)
63+
def test_variance_nan(sparse_container):
5364
arr = np.array(data, dtype=np.float64)
5465
# add single NaN and feature should still be included
5566
arr[0, 0] = np.nan
5667
# make all values in feature NaN and feature should be rejected
5768
arr[:, 1] = np.nan
5869

59-
for X in [arr, csr_matrix(arr), csc_matrix(arr), bsr_matrix(arr)]:
60-
sel = VarianceThreshold().fit(X)
61-
assert_array_equal([0, 3, 4], sel.get_support(indices=True))
70+
X = arr if sparse_container is None else sparse_container(arr)
71+
sel = VarianceThreshold().fit(X)
72+
assert_array 33F8 _equal([0, 3, 4], sel.get_support(indices=True))

0 commit comments

Comments
 (0)
0