1
1
# Authors: Gilles Louppe, Mathieu Blondel, Maheshakya Wijewardena
2
2
# License: BSD 3 clause
3
3
4
+ from copy import deepcopy
5
+
4
6
import numpy as np
5
7
import numbers
6
8
@@ -102,11 +104,10 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
102
104
103
105
prefit : bool, default=False
104
106
Whether a prefit model is expected to be passed into the constructor
105
- directly or not. If True, ``transform`` must be called directly
106
- and SelectFromModel cannot be used with ``cross_val_score``,
107
- ``GridSearchCV`` and similar utilities that clone the estimator.
108
- Otherwise train the model using ``fit`` and then ``transform`` to do
109
- feature selection.
107
+ directly or not.
108
+ If `True`, `estimator` must be a fitted estimator.
109
+ If `False`, `estimator` is fitted and updated by calling
110
+ `fit` and `partial_fit`, respectively.
110
111
111
112
norm_order : non-zero int, inf, -inf, default=1
112
113
Order of the norm used to filter the vectors of coefficients below
@@ -120,10 +121,13 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
120
121
allow.
121
122
- If a callable, then it specifies how to calculate the maximum number of
122
123
features allowed by using the output of `max_feaures(X)`.
124
+ - If `None`, then all features are kept.
123
125
124
126
To only select based on ``max_features``, set ``threshold=-np.inf``.
125
127
126
128
.. versionadded:: 0.20
129
+ .. versionchanged:: 1.1
130
+ `max_features` accepts a callable.
127
131
128
132
importance_getter : str or callable, default='auto'
129
133
If 'auto', uses the feature importance either through a ``coef_``
@@ -144,10 +148,13 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
144
148
145
149
Attributes
146
150
----------
147
- estimator_ : an estimator
148
- The base estimator from which the transformer is built.
149
- This is stored only when a non-fitted estimator is passed to the
150
- ``SelectFromModel``, i.e when prefit is False.
151
+ estimator_ : estimator
152
+ The base estimator from which the transformer is built. This attribute
153
+ exist only when `fit` has been called.
154
+
155
+ - If `prefit=True`, it is a deep copy of `estimator`.
156
+ - If `prefit=False`, it is a clone of `estimator` and fit on the data
157
+ passed to `fit` or `partial_fit`.
151
158
152
159
n_features_in_ : int
153
160
Number of features seen during :term:`fit`. Only defined if the
@@ -159,7 +166,7 @@ class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
159
166
Maximum number of features calculated during :term:`fit`. Only defined
160
167
if the ``max_features`` is not `None`.
161
168
162
- - If `max_features` is an int, then `max_features_ = max_features`.
169
+ - If `max_features` is an ` int` , then `max_features_ = max_features`.
163
170
- If `max_features` is a callable, then `max_features_ = max_features(X)`.
164
171
165
172
.. versionadded:: 1.1
@@ -237,17 +244,33 @@ def __init__(
237
244
self .max_features = max_features
238
245
239
246
def _get_support_mask (self ):
240
- # SelectFromModel can directly call on transform.
247
+ estimator = getattr (self , "estimator_" , self .estimator )
248
+ max_features = getattr (self , "max_features_" , self .max_features )
249
+
241
250
if self .prefit :
242
- estimator = self .estimator
243
- elif hasattr (self , "estimator_" ):
244
- estimator = self .estimator_
245
- else :
251
+ try :
252
+ check_is_fitted (self .estimator )
253
+ except NotFittedError as exc :
254
+ raise NotFittedError (
255
+ "When `prefit=True`, `estimator` is expected to be a fitted "
256
+ "estimator."
257
+ ) from exc
258
+ if callable (max_features ):
259
+ # This branch is executed when `transform` is called directly and thus
260
+ # `max_features_` is not set and we fallback using `self.max_features`
261
+ # that is not validated
262
+ raise NotFittedError (
263
+ "When `prefit=True` and `max_features` is a callable, call `fit` "
264
+ "before calling `transform`."
265
+ )
266
+ elif max_features is not None and not isinstance (
267
+ max_features , numbers .Integral
268
+ ):
246
269
raise ValueError (
247
- "Either fit the model before transform or set"
248
- ' "prefit=True" while passing the fitted'
249
- " estimator to the constructor."
270
+ f"`max_features` must be an integer. Got `max_features={ max_features } ` "
271
+ "instead."
250
272
)
273
+
251
274
scores = _get_feature_importances (
252
275
estimator = estimator ,
253
276
getter = self .importance_getter ,
@@ -257,9 +280,7 @@ def _get_support_mask(self):
257
280
threshold = _calculate_threshold (estimator , scores , self .threshold )
258
281
if self .max_features is not None :
259
282
mask = np .zeros_like (scores , dtype = bool )
260
- candidate_indices = np .argsort (- scores , kind = "mergesort" )[
261
- : self .max_features_
262
- ]
283
+ candidate_indices = np .argsort (- scores , kind = "mergesort" )[:max_features ]
263
284
mask [candidate_indices ] = True
264
285
else :
265
286
mask = np .ones_like (scores , dtype = bool )
@@ -313,9 +334,17 @@ def fit(self, X, y=None, **fit_params):
313
334
)
314
335
315
336
if self .prefit :
316
- raise NotFittedError ("Since 'prefit=True', call transform directly" )
317
- self .estimator_ = clone (self .estimator )
318
- self .estimator_ .fit (X , y , ** fit_params )
337
+ try :
338
+ check_is_fitted (self .estimator )
339
+ except NotFittedError as exc :
340
+ raise NotFittedError (
341
+ "When `prefit=True`, `estimator` is expected to be a fitted "
342
+ "estimator."
343
+ ) from exc
344
+ self .estimator_ = deepcopy (self .estimator )
345
+ else :
346
+ self .estimator_ = clone (self .estimator )
347
+ self .estimator_ .fit (X , y , ** fit_params )
319
348
320
349
if hasattr (self .estimator_ , "feature_names_in_" ):
321
350
self .feature_names_in_ = self .estimator_ .feature_names_in_
@@ -357,7 +386,17 @@ def partial_fit(self, X, y=None, **fit_params):
357
386
Fitted estimator.
358
387
"""
359
388
if self .prefit :
360
- raise NotFittedError ("Since 'prefit=True', call transform directly" )
389
+ if not hasattr (self , "estimator_" ):
390
+ try :
391
+ check_is_fitted (self .estimator )
392
+ except NotFittedError as exc :
393
+ raise NotFittedError (
394
+ "When `prefit=True`, `estimator` is expected to be a fitted "
395
+ "estimator."
396
+ ) from exc
397
+ self .estimator_ = deepcopy (self .estimator )
398
+ return self
399
+
361
400
if not hasattr (self , "estimator_" ):
362
401
self .estimator_ = clone (self .estimator )
363
402
self .estimator_ .partial_fit (X , y , ** fit_params )
0 commit comments