10000 ENH Adds feature_names_in_ to TransformedTargetRegressor (#20868) · scikit-learn/scikit-learn@c5d9939 · GitHub
[go: up one dir, main page]

Skip to content

Commit c5d9939

Browse files
authored
ENH Adds feature_names_in_ to TransformedTargetRegressor (#20868)
1 parent 986d8f2 commit c5d9939

File tree

2 files changed

+21
-22
lines changed

2 files changed

+21
-22
lines changed

sklearn/compose/_target.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ..base import BaseEstimator, RegressorMixin, clone
1010
from ..utils.validation import check_is_fitted
11+
from ..utils._tags import _safe_tags
1112
from ..utils import check_array, _safe_indexing
1213
from ..preprocessing import FunctionTransformer
1314
from ..exceptions import NotFittedError
@@ -88,6 +89,12 @@ class TransformedTargetRegressor(RegressorMixin, BaseEstimator):
8889
8990
.. versionadded:: 0.24
9091
92+
feature_names_in_ : ndarray of shape (`n_features_in_`,)
93+
Names of features seen during :term:`fit`. Defined only when `X`
94+
has feature names that are all strings.
95+
96+
.. versionadded:: 1.0
97+
9198
Examples
9299
--------
93100
>>> import numpy as np
@@ -233,6 +240,9 @@ def fit(self, X, y, **fit_params):
233240

234241
self.regressor_.fit(X, y_trans, **fit_params)
235242

243+
if hasattr(self.regressor_, "feature_names_in_"):
244+
self.feature_names_in_ = self.regressor_.feature_names_in_
245+
236246
return self
237247

238248
def predict(self, X, **predict_params):
@@ -272,7 +282,16 @@ def predict(self, X, **predict_params):
272282
return pred_trans
273283

274284
def _more_tags(self):
275-
return {"poor_score": True, "no_validation": True}
285+
regressor = self.regressor
286+
if regressor is None:
287+
from ..linear_model import LinearRegression
288+
289+
regressor = LinearRegression()
290+
291+
return {
292+
"poor_score": True,
293+
"multioutput": _safe_tags(regressor, key="multioutput"),
294+
}
276295

277296
@property
278297
def n_features_in_(self):

sklearn/tests/test_common.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -316,17 +316,6 @@ def test_check_n_features_in_after_fitting(estimator):
316316
check_n_features_in_after_fitting(estimator.__class__.__name__, estimator)
317317

318318

319-
# TODO: When more modules get added, we can remove it from this list to make
320-
# sure it gets tested. After we finish each module we can move the checks
321-
# into check_estimator.
322-
# NOTE: When running `check_dataframe_column_names_consistency` on a meta-estimator that
323-
# delegates validation to a base estimator, the check is testing that the base estimator
324-
# is checking for column name consistency.
325-
326-
COLUMN_NAME_MODULES_TO_IGNORE = {
327-
"compose",
328-
}
329-
330319
_estimators_to_test = list(
331320
chain(
332321
_tested_estimators(),
@@ -336,16 +325,7 @@ def test_check_n_features_in_after_fitting(estimator):
336325
)
337326

338327

339-
column_name_estimators = [
340-
est
341-
for est in _estimators_to_test
342-
if est.__module__.split(".")[1] not in COLUMN_NAME_MODULES_TO_IGNORE
343-
]
344-
345-
346-
@pytest.mark.parametrize(
347-
"estimator", column_name_estimators, ids=_get_check_estimator_ids
348-
)
328+
@pytest.mark.parametrize("estimator", _estimators_to_test, ids=_get_check_estimator_ids)
349329
def test_pandas_column_name_consistency(estimator):
350330
_set_checking_parameters(estimator)
351331
with ignore_warnings(category=(FutureWarning)):

0 commit comments

Comments
 (0)
0