8000 FIX params validation in SelectFromModel with prefit=True (#23271) · kernc/scikit-learn@6bbb3cb · GitHub
[go: up one dir, main page]

Skip to content

Commit 6bbb3cb

Browse files
glemaitrethomasjpfanjeremiedbb
authored
FIX params validation in SelectFromModel with prefit=True (scikit-learn#23271)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 8f8f0f6 commit 6bbb3cb

File tree

3 files changed

+140
-30
lines changed

3 files changed

+140
-30
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@ random sampling procedures.
9292
the correct variance-scaling coefficient which may result in different model
9393
behavior.
9494

95+
- |Fix| :meth:`feature_selection.SelectFromModel.fit` and
96+
:meth:`feature_selection.SelectFromModel.partial_fit` can now be called with
97+
`prefit=True`. `estimators_` will be a deep copy of `estimator` when
98+
`prefit=True`. :pr:`23271` by :user:`Guillaume Lemaitre <glemaitre>`.
99+
95100
Changelog
96101
---------
97102

sklearn/feature_selection/_from_model.py

Lines changed: 64 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Authors: Gilles Louppe, Mathieu Blondel, Maheshakya Wijewardena
22
# License: BSD 3 clause
33

4+
from copy import deepcopy
5+
46
import numpy as np
57
import numbers
68

@@ -102,11 +104,10 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
102104
103105
prefit : bool, default=False
104106
Whether a prefit model is expected to be passed into the constructor
105-
directly or not. If True, ``transform`` must be called directly
106-
and SelectFromModel cannot be used with ``cross_val_score``,
107-
``GridSearchCV`` and similar utilities that clone the estimator.
108-
Otherwise train the model using ``fit`` and then ``transform`` to do
109-
feature selection.
107+
directly or not.
108+
If `True`, `estimator` must be a fitted estimator.
109+
If `False`, `estimator` is fitted and updated by calling
110+
`fit` and `partial_fit`, respectively.
110111
111112
norm_order : non-zero int, inf, -inf, default=1
112113
Order of the norm used to filter the vectors of coefficients below
@@ -120,10 +121,13 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
120121
allow.
121122
- If a callable, then it specifies how to calculate the maximum number of
122123
features allowed by using the output of `max_feaures(X)`.
124+
- If `None`, then all features are kept.
123125
124126
To only select based on ``max_features``, set ``threshold=-np.inf``.
125127
126128
.. versionadded:: 0.20
129+
.. versionchanged:: 1.1
130+
`max_features` accepts a callable.
127131
128132
importance_getter : str or callable, default='auto'
129133
If 'auto', uses the feature importance either through a ``coef_``
@@ -144,10 +148,13 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
144148
145149
Attributes
146150
----------
147-
estimator_ : an estimator
148-
The base estimator from which the transformer is built.
149-
This is stored only when a non-fitted estimator is passed to the
150-
``SelectFromModel``, i.e when prefit is False.
151+
estimator_ : estimator
152+
The base estimator from which the transformer is built. This attribute
153+
exist only when `fit` has been called.
154+
155+
- If `prefit=True`, it is a deep copy of `estimator`.
156+
- If `prefit=False`, it is a clone of `estimator` and fit on the data
157+
passed to `fit` or `partial_fit`.
151158
152159
n_features_in_ : int
153160
Number of features seen during :term:`fit`. Only defined if the
@@ -159,7 +166,7 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
159166
Maximum number of features calculated during :term:`fit`. Only defined
160167
if the ``max_features`` is not `None`.
161168
162-
- If `max_features` is an int, then `max_features_ = max_features`.
169+
- If `max_features` is an `int`, then `max_features_ = max_features`.
163170
- If `max_features` is a callable, then `max_features_ = max_features(X)`.
164171
165172
.. versionadded:: 1.1
@@ -237,17 +244,33 @@ def __init__(
237244
self.max_features = max_features
238245

239246
def _get_support_mask(self):
240-
# SelectFromModel can directly call on transform.
247+
estimator = getattr(self, "estimator_", self.estimator)
248+
max_features = getattr(self, "max_features_", self.max_features)
249+
241250
if self.prefit:
242-
estimator = self.estimator
243-
elif hasattr(self, "estimator_"):
244-
estimator = self.estimator_
245-
else:
251+
try:
252+
check_is_fitted(self.estimator)
253+
except NotFittedError as exc:
254+
raise NotFittedError(
255+
"When `prefit=True`, `estimator` is expected to be a fitted "
256+
"estimator."
257+
) from exc
258+
if callable(max_features):
259+
# This branch is executed when `transform` is called directly and thus
260+
# `max_features_` is not set and we fallback using `self.max_features`
261+
# that is not validated
262+
raise NotFittedError(
263+
"When `prefit=True` and `max_features` is a callable, call `fit` "
264+
"before calling `transform`."
265+
)
266+
elif max_features is not None and not isinstance(
267+
max_features, numbers.Integral
268+
):
246269
raise ValueError(
247-
"Either fit the model before transform or set"
248-
' "prefit=True" while passing the fitted'
249-
" estimator to the constructor."
270+
f"`max_features` must be an integer. Got `max_features={max_features}` "
271+
"instead."
250272
)
273+
251274
scores = _get_feature_importances(
252275
estimator=estimator,
253276
getter=self.importance_getter,
@@ -257,9 +280,7 @@ def _get_support_mask(self):
257280
threshold = _calculate_threshold(estimator, scores, self.threshold)
258281
if self.max_features is not None:
259282
mask = np.zeros_like(scores, dtype=bool)
260-
candidate_indices = np.argsort(-scores, kind="mergesort")[
261-
: self.max_features_
262-
]
283+
candidate_indices = np.argsort(-scores, kind="mergesort")[:max_features]
263284
mask[candidate_indices] = True
264285
else:
265286
mask = np.ones_like(scores, dtype=bool)
@@ -313,9 +334,17 @@ def fit(self, X, y=None, **fit_params):
313334
)
314335

315336
if self.prefit:
316-
raise NotFittedError("Since 'prefit=True', call transform directly")
317-
self.estimator_ = clone(self.estimator)
318-
self.estimator_.fit(X, y, **fit_params)
337+
try:
338+
check_is_fitted(self.estimator)
339+
except NotFittedError as exc:
340+
raise NotFittedError(
341+
"When `prefit=True`, `estimator` is expected to be a fitted "
342+
"estimator."
343+
) from exc
344+
self.estimator_ = deepcopy(self.estimator)
345+
else:
346+
self.estimator_ = clone(self.estimator)
347+
self.estimator_.fit(X, y, **fit_params)
319348

320349
if hasattr(self.estimator_, "feature_names_in_"):
321350
self.feature_names_in_ = self.estimator_.feature_names_in_
@@ -357,7 +386,17 @@ def partial_fit(self, X, y=None, **fit_params):
357386
Fitted estimator.
358387
"""
359388
if self.prefit:
360-
raise NotFittedError("Since 'prefit=True', call transform directly")
389+
if not hasattr(self, "estimator_"):
390+
try:
391+
check_is_fitted(self.estimator)
392+
except NotFittedError as exc:
393+
raise NotFittedError(
394+
"When `prefit=True`, `estimator` is expected to be a fitted "
395+
"estimator."
396+
) from exc
397+
self.estimator_ = deepcopy(self.estimator)
398+
return self
399+
361400
if not hasattr(self, "estimator_"):
362401
self.estimator_ = clone(self.estimator)
363402
self.estimator_.partial_fit(X, y, **fit_params)

sklearn/feature_selection/tests/test_from_model.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from sklearn import datasets
1414
from sklearn.cross_decomposition import CCA, PLSCanonical, PLSRegression
1515
from sklearn.datasets import make_friedman1
16+
from sklearn.exceptions import NotFittedError
1617
from sklearn.linear_model import LogisticRegression, SGDClassifier, Lasso
1718
from sklearn.svm import LinearSVC
1819
from sklearn.feature_selection import SelectFromModel
@@ -99,16 +100,20 @@ def test_max_features_error(max_features, err_type, err_msg):
99100
transformer.fit(data, y)
100101

101102

102-
@pytest.mark.parametrize("max_features", [0, 2, data.shape[1]])
103+
@pytest.mark.parametrize("max_features", [0, 2, data.shape[1], None])
103104
def test_inferred_max_features_integer(max_features):
104105
"""Check max_features_ and output shape for integer max_features."""
105106
clf = RandomForestClassifier(n_estimators=5, random_state=0)
106107
transformer = SelectFromModel(
107108
estimator=clf, max_features=max_features, threshold=-np.inf
108109
)
109110
X_trans = transformer.fit_transform(data, y)
110-
assert transformer.max_features_ == max_features
111-
assert X_trans.shape[1] == transformer.max_features_
111+
if max_features is not None:
112+
assert transformer.max_features_ == max_features
113+
assert X_trans.shape[1] == transformer.max_features_
114+
else:
115+
assert not hasattr(transformer, "max_features_")
116+
assert X_trans.shape[1] == data.shape[1]
112117

113118

114119
@pytest.mark.parametrize(
@@ -405,17 +410,78 @@ def test_prefit():
405410
clf.fit(data, y)
406411
model = SelectFromModel(clf, prefit=True)
407412
assert_array_almost_equal(model.transform(data), X_transform)
413+
model.fit(data, y)
414+
assert model.estimator_ is not clf
408415

409416
# Check that the model is rewritten if prefit=False and a fitted model is
410417
# passed
411418
model = SelectFromModel(clf, prefit=False)
412419
model.fit(data, y)
413420
assert_array_almost_equal(model.transform(data), X_transform)
414421

415-
# Check that prefit=True and calling fit raises a ValueError
422+
# Check that passing an unfitted estimator with `prefit=True` raises a
423+
# `ValueError`
424+
clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, random_state=0, tol=None)
416425
model = SelectFromModel(clf, prefit=True)
417-
with pytest.raises(ValueError):
426+
err_msg = "When `prefit=True`, `estimator` is expected to be a fitted estimator."
427+
with pytest.raises(NotFittedError, match=err_msg):
418428
model.fit(data, y)
429+
with pytest.raises(NotFittedError, match=err_msg):
430+
model.partial_fit(data, y)
431+
with pytest.raises(NotFittedError, match=err_msg):
432+
model.transform(data)
433+
434+
# Check that the internal parameters of prefitted model are not changed
435+
# when calling `fit` or `partial_fit` with `prefit=True`
436+
clf = SGDClassifier(alpha=0.1, max_iter=10, shuffle=True, tol=None).fit(data, y)
437+
model = SelectFromModel(clf, prefit=True)
438+
model.fit(data, y)
439+
assert_allclose(model.estimator_.coef_, clf.coef_)
440+
model.partial_fit(data, y)
441+
assert_allclose(model.estimator_.coef_, clf.coef_)
442+
443+
444+
def test_prefit_max_features():
445+
"""Check the interaction between `prefit` and `max_features`."""
446+
# case 1: an error should be raised at `transform` if `fit` was not called to
447+
# validate the attributes
448+
estimator = RandomForestClassifier(n_estimators=5, random_state=0)
449+
estimator.fit(data, y)
450+
model = SelectFromModel(estimator, prefit=True, max_features=lambda X: X.shape[1])
451+
452+
err_msg = (
453+
"When `prefit=True` and `max_features` is a callable, call `fit` "
454+
"before calling `transform`."
455+
)
456+
with pytest.raises(NotFittedError, match=err_msg):
457+
model.transform(data)
458+
459+
# case 2: `max_features` is not validated and different from an integer
460+
# FIXME: we cannot validate the upper bound of the attribute at transform
461+
# and we should force calling `fit` if we intend to force the attribute
462+
# to have such an upper bound.
463+
max_features = 2.5
464+
model.set_params(max_features=max_features)
465+
with pytest.raises(ValueError, match="`max_features` must be an integer"):
466+
model.transform(data)
467+
468+
469+
def test_prefit_get_feature_names_out():
470+
"""Check the interaction between prefit and the feature names."""
471+
clf = RandomForestClassifier(n_estimators=2, random_state=0)
472+
clf.fit(data, y)
473+
model = SelectFromModel(clf, prefit=True, max_features=1)
474+
475+
# FIXME: the error message should be improved. Raising a `NotFittedError`
476+
# would be better since it would force to validate all class attribute and
477+
# create all the necessary fitted attribute
478+
err_msg = "Unable to generate feature names without n_features_in_"
479+
with pytest.raises(ValueError, match=err_msg):
480+
model.get_feature_names_out()
481+
482+
model.fit(data, y)
483+
feature_names = model.get_feature_names_out()
484+
assert feature_names == ["x3"]
419485

420486

421487
def test_threshold_string():

0 commit comments

Comments
 (0)
0