8000 TST Add a test for meta-estimators with non tabular data (#19755) · scikit-learn/scikit-learn@1ce1715 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1ce1715

Browse files
jeremiedbb and ogrisel authored
TST Add a test for meta-estimators with non tabular data (#19755)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent dff37c4 commit 1ce1715

File tree

1 file changed

+117
-2
lines changed

1 file changed

+117
-2
lines changed

sklearn/tests/test_metaestimators.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
"""Common tests for metaestimators"""
22
import functools
3+
from inspect import signature
34

45
import numpy as np
56
import pytest
67

78
from sklearn.base import BaseEstimator
9+
from sklearn.base import is_regressor
810
from sklearn.datasets import make_classification
9-
11+
from sklearn.utils import all_estimators
12+
from sklearn.utils.estimator_checks import _enforce_estimator_tags_x
13+
from sklearn.utils.estimator_checks import _enforce_estimator_tags_y
1014
from sklearn.utils.validation import check_is_fitted
11-
from sklearn.pipeline import Pipeline
15+
from sklearn.utils._testing import set_random_state
16+
from sklearn.pipeline import Pipeline, make_pipeline
1217
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
18+
from sklearn.feature_extraction.text import TfidfVectorizer
1319
from sklearn.feature_selection import RFE, RFECV
1420
from sklearn.ensemble import BaggingClassifier
1521
from sklearn.exceptions import NotFittedError
1622
from sklearn.semi_supervised import SelfTrainingClassifier
23+
from sklearn.linear_model import Ridge, LogisticRegression
1724

1825

1926
class DelegatorData:
@@ -151,3 +158,111 @@ def score(self, X, y, *args, **kwargs):
151158
assert not hasattr(delegator, method), (
152159
"%s has method %r when its delegate does not"
153160
% (delegator_data.name, method))
161+
162+
163+
def _generate_meta_estimator_instances_with_pipeline():
    """Generate instances of meta-estimators fed with a pipeline.

    An estimator is considered a meta-estimator when its constructor accepts
    one of the parameters "estimator", "base_estimator" or "estimators".
    Every inner estimator is a pipeline starting with a ``TfidfVectorizer``
    followed by ``Ridge`` (when ``is_regressor`` is true for the
    meta-estimator class) or ``LogisticRegression`` (otherwise).

    Yields
    ------
    estimator : object
        An unfitted meta-estimator instance wrapping the pipeline(s).
    """
    for _, Estimator in sorted(all_estimators()):
        sig = set(signature(Estimator).parameters)

        if "estimator" in sig or "base_estimator" in sig:
            if is_regressor(Estimator):
                estimator = make_pipeline(TfidfVectorizer(), Ridge())
                param_grid = {"ridge__alpha": [0.1, 1.0]}
            else:
                estimator = make_pipeline(TfidfVectorizer(),
                                          LogisticRegression())
                param_grid = {"logisticregression__C": [0.1, 1.0]}

            if "param_grid" in sig or "param_distributions" in sig:
                # SearchCV estimators: the grid is passed positionally, and
                # n_iter is only given to constructors that accept it.
                extra_params = {"n_iter": 2} if "n_iter" in sig else {}
                yield Estimator(estimator, param_grid, **extra_params)
            else:
                yield Estimator(estimator)

        elif "estimators" in sig:
            # stacking, voting: these take a list of (name, estimator) pairs
            if is_regressor(Estimator):
                estimator = [
                    ("est1", make_pipeline(TfidfVectorizer(),
                                           Ridge(alpha=0.1))),
                    ("est2", make_pipeline(TfidfVectorizer(),
                                           Ridge(alpha=1))),
                ]
            else:
                estimator = [
                    ("est1", make_pipeline(TfidfVectorizer(),
                                           LogisticRegression(C=0.1))),
                    ("est2", make_pipeline(TfidfVectorizer(),
                                           LogisticRegression(C=1))),
                ]
            yield Estimator(estimator)

        # Estimators accepting none of the parameters above are not
        # meta-estimators and are skipped.
208+
209+
210+
# TODO: remove data validation for the following estimators
# They should be able to work on any data and delegate data validation to
# their inner estimator(s).
DATA_VALIDATION_META_ESTIMATORS_TO_IGNORE = [
    "AdaBoostClassifier",
    "AdaBoostRegressor",
    "BaggingClassifier",
    "BaggingRegressor",
    "ClassifierChain",
    "IterativeImputer",
    "MultiOutputClassifier",
    "MultiOutputRegressor",
    "OneVsOneClassifier",
    "OutputCodeClassifier",
    "RANSACRegressor",
    "RFE",
    "RFECV",
    "RegressorChain",
    "SelfTrainingClassifier",
    "SequentialFeatureSelector",  # not applicable (2D data mandatory)
]
231+
232+
# Meta-estimator instances to run the data-validation delegation test on:
# everything produced by the generator except the classes listed by name in
# DATA_VALIDATION_META_ESTIMATORS_TO_IGNORE.
DATA_VALIDATION_META_ESTIMATORS = [
    est for est in _generate_meta_estimator_instances_with_pipeline() if
    est.__class__.__name__ not in DATA_VALIDATION_META_ESTIMATORS_TO_IGNORE
]
236+
237+
238+
def _get_meta_estimator_id(estimator):
    """Return the class name of ``estimator``, used as the pytest case id."""
    return estimator.__class__.__name__
240+
241+
242+
@pytest.mark.parametrize(
    "estimator", DATA_VALIDATION_META_ESTIMATORS, ids=_get_meta_estimator_id
)
def test_meta_estimators_delegate_data_validation(estimator):
    """Check that meta-estimators delegate data validation to the inner
    estimator(s).

    The meta-estimator is fitted on a 1D object array of strings, which only
    the first step of the inner pipeline is expected to handle.
    """
    rng = np.random.RandomState(0)
    set_random_state(estimator)

    n_samples = 30
    X = rng.choice(np.array(["aa", "bb", "cc"], dtype=object), size=n_samples)

    if is_regressor(estimator):
        y = rng.normal(size=n_samples)
    else:
        y = rng.randint(3, size=n_samples)

    X = _enforce_estimator_tags_x(estimator, X)
    y = _enforce_estimator_tags_y(estimator, y)

    # Calling fit should not raise any data validation exception since X is a
    # valid input datastructure for the first step of the pipeline passed as
    # base estimator to the meta estimator.
    estimator.fit(X, y)

    # n_features_in_ should not be defined since data is not tabular data.
    assert not hasattr(estimator, "n_features_in_")

0 commit comments

Comments
 (0)
0