8000 TST enable to run common test on stacking and voting estimators (#18045) · scikit-learn/scikit-learn@2199de6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2199de6

Browse files
authored
TST enable to run common test on stacking and voting estimators (#18045)
* TST enable to run common test on stacking and voting estimators * revert config changes * revert unecessary change in _testing * TST remove check_estimator from local files * PEP8
1 parent 18a4841 commit 2199de6

File tree

3 files changed

+19
-44
lines changed

3 files changed

+19
-44
lines changed

sklearn/ensemble/tests/test_stacking.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@
2727
from sklearn.svm import LinearSVC
2828
from sklearn.svm import LinearSVR
2929
from sklearn.svm import SVC
30-
from sklearn.tree import DecisionTreeClassifier
31-
from sklearn.tree import DecisionTreeRegressor
3230
from sklearn.ensemble import RandomForestClassifier
3331
from sklearn.ensemble import RandomForestRegressor
3432
from sklearn.preprocessing import scale
@@ -44,8 +42,6 @@
4442
from sklearn.utils._testing import assert_allclose
4543
from sklearn.utils._testing import assert_allclose_dense_sparse
4644
from sklearn.utils._testing import ignore_warnings
47-
from sklearn.utils.estimator_checks import check_estimator
48-
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
4945

5046
X_diabetes, y_diabetes = load_diabetes(return_X_y=True)
5147
X_iris, y_iris = load_iris(return_X_y=True)
@@ -368,24 +364,6 @@ def test_stacking_randomness(estimator, X, y):
368364
)
369365

370366

371-
# These warnings are raised due to _BaseComposition
372-
@pytest.mark.filterwarnings("ignore:TypeError occurred during set_params")
373-
@pytest.mark.filterwarnings("ignore:Estimator's parameters changed after")
374-
@pytest.mark.parametrize(
375-
"estimator",
376-
[StackingClassifier(
377-
estimators=[('lr', LogisticRegression(random_state=0)),
378-
('tree', DecisionTreeClassifier(random_state=0))]),
379-
StackingRegressor(
380-
estimators=[('lr', LinearRegression()),
381-
('tree', DecisionTreeRegressor(random_state=0))])],
382-
ids=['StackingClassifier', 'StackingRegressor']
383-
)
384-
def test_check_estimators_stacking_estimator(estimator):
385-
check_estimator(estimator)
386-
check_no_attributes_set_in_init(estimator.__class__.__name__, estimator)
387-
388-
389367
def test_stacking_classifier_stratify_default():
390368
# check that we stratify the classes for the default CV
391369
clf = StackingClassifier(

sklearn/ensemble/tests/test_voting.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from sklearn.utils._testing import assert_almost_equal, assert_array_equal
88
from sklearn.utils._testing import assert_array_almost_equal
99
from sklearn.utils._testing import assert_raise_message
10-
from sklearn.utils.estimator_checks import check_estimator
11-
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
1210
from sklearn.exceptions import NotFittedError
1311
from sklearn.linear_model import LinearRegression
1412
from sklearn.linear_model import LogisticRegression
@@ -490,23 +488,6 @@ def test_none_estimator_with_weights(X, y, voter):
490488
assert y_pred.shape == y.shape
491489

492490

493-
@pytest.mark.parametrize(
494-
"estimator",
495-
[VotingRegressor(
496-
estimators=[('lr', LinearRegression()),
497-
('tree', DecisionTreeRegressor(random_state=0))]),
498-
VotingClassifier(
499-
estimators=[('lr', LogisticRegression(random_state=0)),
500-
('tree', DecisionTreeClassifier(random_state=0))])],
501-
ids=['VotingRegressor', 'VotingClassifier']
502-
)
503-
def test_check_estimators_voting_estimator(estimator):
504-
# FIXME: to be removed when meta-estimators can specified themselves
505-
# their testing parameters (for required parameters).
506-
check_estimator(estimator)
507-
check_no_attributes_set_in_init(estimator.__class__.__name__, estimator)
508-
509-
510491
@pytest.mark.parametrize(
511492
"est",
512493
[VotingRegressor(

sklearn/utils/estimator_checks.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from ._testing import create_memmap_backed_data
2626
from ._testing import raises
2727
from . import is_scalar_nan
28+
2829
from ..discriminant_analysis import LinearDiscriminantAnalysis
30+
from ..linear_model import LogisticRegression
2931
from ..linear_model import Ridge
3032

3133
from ..base import (
@@ -344,10 +346,24 @@ def _construct_instance(Estimator):
344346
estimator = Estimator(Ridge())
345347
else:
346348
estimator = Estimator(LinearDiscriminantAnalysis())
349+
elif required_parameters in (['estimators'],):
350+
# Heterogeneous ensemble classes (i.e. stacking, voting)
351+
if issubclass(Estimator, RegressorMixin):
352+
estimator = Estimator(estimators=[
353+
("est1", Ridge(alpha=0.1)),
354+
("est2", Ridge(alpha=1))
355+
])
356+
else:
357+
estimator = Estimator(estimators=[
358+
("est1", LogisticRegression(C=0.1)),
359+
("est2", LogisticRegression(C=1))
360+
])
347361
else:
348-
raise SkipTest("Can't instantiate estimator {} which requires "
349-
"parameters {}".format(Estimator.__name__,
350-
required_parameters))
362+
msg = (f"Can't instantiate estimator {Estimator.__name__} "
363+
f"parameters {required_parameters}")
364+
# raise additional warning to be shown by pytest
365+
warnings.warn(msg, SkipTestWarning)
366+
raise SkipTest(msg)
351367
else:
352368
estimator = Estimator()
353369
return estimator

0 commit comments

Comments
 (0)
0