MAINT validate parameter in VarianceThreshold (#23581) · scikit-learn/scikit-learn@b60832d · GitHub
[go: up one dir, main page]

Skip to content

Commit b60832d

Browse files
MAINT validate parameter in VarianceThreshold (#23581)
Co-authored-by: Sangam Swadi K <sangamswadik@users.noreply.github.com>
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 8a8d068 commit b60832d

File tree

3 files changed

+5
-12
lines changed

3 files changed

+5
-12
lines changed

sklearn/feature_selection/_variance_threshold.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,13 @@
11
# Author: Lars Buitinck
22
# License: 3-clause BSD
3+
from numbers import Real
34

45
import numpy as np
56
from ..base import BaseEstimator
67
from ._base import SelectorMixin
78
from ..utils.sparsefuncs import mean_variance_axis, min_max_axis
89
from ..utils.validation import check_is_fitted
10+
from ..utils._param_validation import Interval
911

1012

1113
class VarianceThreshold(SelectorMixin, BaseEstimator):
@@ -67,6 +69,8 @@ class VarianceThreshold(SelectorMixin, BaseEstimator):
6769
[1, 1]])
6870
"""
6971

72+
_parameter_constraints = {"threshold": [Interval(Real, 0, None, closed="left")]}
73+
7074
def __init__(self, threshold=0.0):
7175
self.threshold = threshold
7276

@@ -88,6 +92,7 @@ def fit(self, X, y=None):
8892
self : object
8993
Returns the instance itself.
9094
"""
95+
self._validate_params()
9196
X = self._validate_data(
9297
X,
9398
accept_sparse=("csr", "csc"),
@@ -110,8 +115,6 @@ def fit(self, X, y=None):
110115
# for constant features
111116
compare_arr = np.array([self.variances_, peak_to_peaks])
112117
self.variances_ = np.nanmin(compare_arr, axis=0)
113-
elif self.threshold < 0.0:
114-
raise ValueError(f"Threshold must be non-negative. Got: {self.threshold}")
115118

116119
if np.all(~np.isfinite(self.variances_) | (self.variances_ <= self.threshold)):
117120
msg = "No feature in X meets the variance threshold {0:.5f}"

sklearn/feature_selection/tests/test_variance_threshold.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -32,15 +32,6 @@ def test_variance_threshold():
3232
assert (len(data), 1) == X.shape
3333

3434

35-
@pytest.mark.parametrize("X", [data, csr_matrix(data)])
36-
def test_variance_negative(X):
37-
"""Test VarianceThreshold with negative variance."""
38-
var_threshold = VarianceThreshold(threshold=-1.0)
39-
msg = r"^Threshold must be non-negative. Got: -1.0$"
40-
with pytest.raises(ValueError, match=msg):
41-
var_threshold.fit(X)
42-
43-
4435
@pytest.mark.skipif(
4536
np.var(data2) == 0,
4637
reason=(

sklearn/tests/test_common.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -609,7 +609,6 @@ def test_estimators_do_not_raise_errors_in_init_or_set_params(Estimator):
609609
"TransformedTargetRegressor",
610610
"TruncatedSVD",
611611
"TweedieRegressor",
612-
"VarianceThreshold",
613612
"VotingClassifier",
614613
"VotingRegressor",
615614
]

0 commit comments

Comments
 (0)
0