[MRG+1] Select k-best features in SelectFromModel by nsheth12 · Pull Request #9616 · scikit-learn/scikit-learn

Merged · 20 commits · Jul 16, 2018

Conversation

nsheth12 (Contributor):

Reference Issue

Continuation of work from PR #6717.

What does this implement/fix? Explain your changes.

Will merge in master (this branch is a year old) and make the changes discussed in the previous PR review to get it ready to merge.

@nsheth12 (Contributor Author):

The AppVeyor build continues to fail. It fails the _check_max_features() check on Windows with Python 2.7.8 and 64-bit architecture. For some reason, it's receiving "10L" as the max_features parameter. I'm not able to reproduce the issue locally. Any ideas as to what's going on, how to reproduce it, or how to fix it? Below is a screenshot of the error traceback in the AppVeyor console:

[Screenshot: AppVeyor console error traceback]
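A likely explanation (not confirmed in the thread): on 64-bit Windows, Python 2's native int is only 32 bits, so integer values coming out of numpy can be coerced to Python long, which reprs with an "L" suffix; an isinstance(x, int) check then fails. Checking against numbers.Integral covers both types:

    # Python 2 only: 10L is a `long`, not an `int`
    >>> isinstance(10L, int)
    False
    >>> import numbers
    >>> isinstance(10L, numbers.Integral)
    True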

@jnothman (Member) left a comment:

Thanks for taking this up!

        if 0 <= self.max_features <= X.shape[1]:
            return
        elif self.max_features == 'all':
            return
Member:

less indentation

        if isinstance(self.max_features, int):
            if 0 <= self.max_features <= X.shape[1]:
                return
        elif self.max_features == 'all':
Member:

what's the difference between None and 'all'?

@@ -108,6 +108,9 @@ class SelectFromModel(BaseEstimator, SelectorMixin, MetaEstimatorMixin):
        Otherwise train the model using ``fit`` and then ``transform`` to do
        feature selection.

    max_features : int, between 0 and number of features, optional.
        Select at most this many features that score above the threshold.
Member:

Could you please add a note that to use only `max_features`, with no threshold, `threshold=-np.inf` can be used? But perhaps we should allow users to disable the threshold with a string `'-inf'`? (I'd rather not `'none'`, which would get confused with the current `None`, meaning automatic threshold determination. Although we could in turn consider deprecating the use of `threshold=None` and renaming it to `threshold='auto'`.)
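For illustration, the usage pattern such a note would document (a sketch; est, X, and y are placeholders):

    import numpy as np
    from sklearn.feature_selection import SelectFromModel

    # Keep only the 5 highest-scoring features; -np.inf disables the threshold.
    transformer = SelectFromModel(est, max_features=5, threshold=-np.inf)
    X_new = transformer.fit_transform(X, y)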

    transformer = SelectFromModel(estimator=est,
                                  max_features=invalid_max_n_feature,
                                  threshold=-np.inf)
    assert_raises(ValueError, transformer.fit, X, y)
Member:

It's generally better to check that the right error is being raised, especially for something as generic as a ValueError. Use assert_raises_regexp or assert_raise_message.
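A sketch of the stricter check (the expected message fragment is illustrative):

    from sklearn.utils.testing import assert_raise_message

    assert_raise_message(ValueError, "max_features should be",
                         transformer.fit, X, y)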

            raise ValueError(
                'Either fit SelectFromModel before transform or set "prefit='
                'True" and pass a fitted estimator to the constructor.')
            raise ValueError('Either fit the model before transform or set'
Member:

It probably wasn't your doing, but generally you should avoid touching code not related to the change.

    assert_equal(X_new.shape[1], n_features)


def check_threshold_and_max_features(est, X, y):
Member:

I'd rather this be a separate test_threshold_and_max_features

Member:

I think this test currently tests threshold too much, when it is already covered above. A good set of tests should look like a proof by induction. We start by checking the basic features, and then that their combination makes sense. So I think this test would work well if we assume threshold and max_features work alone, and only confirm that their combination produces the features corresponding to the set intersection of their selections. I.e. this test should not bother with coef_ or with shape.
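A sketch of the reduced test described above (values illustrative; assumes est is deterministic and compares support masks directly):

    def test_threshold_and_max_features_intersection():
        # Selection under each criterion alone.
        mask_k = SelectFromModel(estimator=est, max_features=2,
                                 threshold=-np.inf).fit(X, y)._get_support_mask()
        mask_t = SelectFromModel(estimator=est,
                                 threshold=0.04).fit(X, y)._get_support_mask()
        # The combination should select exactly the intersection.
        mask_both = SelectFromModel(estimator=est, max_features=2,
                                    threshold=0.04).fit(X, y)._get_support_mask()
        assert_array_equal(mask_both, mask_k & mask_t)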


    # Test max_features against actual model.
    transformer1 = SelectFromModel(estimator=Lasso(alpha=0.025,
                                   random_state=42))
Member:

This is not appropriate indentation. It makes it look like the random_state belongs to SelectFromModel

    assert_array_equal(transformer1.estimator_.coef_,
                       transformer2.estimator_.coef_)

    # Test if max_features can break ties among feature importances
Member:

I think this should be a separate test function.

                                  threshold=-np.inf)
    X_new = transformer.fit_transform(X, y)
    selected_feature_indices = np.where(transformer._get_support_mask())[0]
    assert_array_equal(selected_feature_indices, np.arange(n_features))
Member:

I'm okay with this approach, but wonder if we'd be better off taking max_features literally and returning none of the tying features at the cutoff (to avoid users being surprised by the tie-breaking; although we do break ties like this in SelectKBest and SelectPercentile, and perhaps we should remain consistent). WDYT? Perhaps it rarely matters.

Contributor Author:

I personally think it is better to give users exactly the number of features they ask for. From my experience as a user, I don't care so much whether I get feature X or feature Y when both are tied in importance; I care that when I ask for Z features, I get Z features and not fewer. Consistency with SelectKBest and SelectPercentile is another argument in favor of keeping this as is. However, this is just my 2 cents, and I'll defer to you on the final decision.
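A toy illustration of the behavior under discussion: a stable mergesort argsort breaks ties by index and returns exactly max_features features.

    import numpy as np

    scores = np.array([0.9, 0.5, 0.5, 0.5, 0.1])  # three-way tie at the cutoff
    max_features = 3
    mask = np.zeros_like(scores, dtype=bool)
    mask[np.argsort(-scores, kind='mergesort')[:max_features]] = True
    print(mask)  # [ True  True  True False False] -- lower indices win ties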

    assert_array_equal(X_new3, X[:, selected_indices[0]])

    """
    If threshold and max_features are not provided, all features are
Member:

I don't think this comment makes sense. If threshold and max_features are not provided, the default threshold is used.

@jnothman (Member) commented Aug 28, 2017 via email

@jnothman (Member) left a comment:

Thanks

        self.norm_order = norm_order

    def _check_max_features(self, X, max_features):
        if self.max_features is None or self.max_features == 'all':
Member:

is there a reason to have both 'all' and None available? Can't we just allow the default value, either 'all' or None, to yield this behaviour?

Contributor Author:

Agreed. Removed support for "all".


    def _check_params(self, X, y):
        X, y = check_X_y(X, y)
        self._check_max_features(X, self.max_features)
Member:

We tend to avoid such nesting and would rather have the check_max_features logic inline here.

        n_features_to_select = self.max_features
        if self.max_features == 'all':
            n_features_to_select = scores.size
        candidate_indices = np.argsort(-scores,
Member:

This sort is unnecessary in the default max_features case, and so seems to be wasted computation.

Note that an alternative way to implement this, in O(n) time, is to just set threshold=max(threshold, np.percentile(scores, 100 * max_features / n_features)), and then handle ties explicitly if too many features are selected. I'm happy with mergesort for readability and comparison to SelectKBest

Contributor Author:

For now, I just added a check to avoid sorting in the default case. I think the tiebreaking code required for the percentile approach would decrease readability significantly (as you mention). Of course, if performance becomes an issue, I can always go back and change it.
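For reference, a sketch of the linear-time alternative jnothman outlines, using np.partition to find the cutoff and trimming ties explicitly. This is a hypothetical helper, not the merged implementation, and it assumes 0 < max_features <= scores.size:

    import numpy as np

    def top_k_mask(scores, threshold, max_features):
        n = scores.size
        # k-th largest score in O(n); everything >= it is a candidate.
        cutoff = np.partition(scores, n - max_features)[n - max_features]
        mask = scores >= max(threshold, cutoff)
        # Ties at the cutoff can let more than max_features through;
        # drop the surplus tied features explicitly.
        surplus = mask.sum() - max_features
        if surplus > 0:
            tied = np.flatnonzero(mask & (scores == cutoff))
            mask[tied[:surplus]] = False
        return mask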



def check_diff_models_threshold_and_max_features(est, X, y):
    """
Member:

I don't think this comment is clear. The default threshold is used, even if max_features is provided. I don't think it's the right place for the comment, either. It can just be removed.

                               n_repeated=0, shuffle=False, random_state=0)

    check_diff_models_threshold_and_max_features(
        RandomForestClassifier(n_estimators=50, random_state=0), X, y)
Member:

Again, this doesn't seem to be the place to test that SelectFromModel works with different kinds of models. We only need to test the interaction of threshold and max_features assuming that independently, they both work correctly.

@jnothman (Member) left a comment:

LGTM

                np.argsort(-scores, kind='mergesort')[:self.max_features]
            mask[candidate_indices] = True
        else:
            mask = np.logical_not(mask)
Member:

Nitpick: I'd rather see this as ones_like, with zeros_like in the if case. But whatever.

Member:

I approve this change :)

    assert_equal(X_new3.shape[1], min(X_new1.shape[1], X_new2.shape[1]))
    selected_indices = \
        transformer3.transform(np.arange(X.shape[1])[np.newaxis, :])
    assert_array_equal(X_new3, X[:, selected_indices[0]])
Member:

Were this to fail, the error would not be very clear. Much clearer if we were just comparing ranges. But it's okay.

@jnothman changed the title from "Select k-best features in SelectFromModel" to "[MRG+1] Select k-best features in SelectFromModel" on Aug 29, 2017
@nsheth12 (Contributor Author):

Is there anything else I need to do for this to be reviewed for moving to MRG+2?

@jnothman (Member):

The review will come, eventually. Feel free to take up another issue in the meantime.

Also, please add an entry to doc/whats_new/v0.20.rst citing @qmaruf and yourself.

@nsheth12 (Contributor Author):

Just wanted to check in: when will the second review for this PR happen?

@amueller (Member):

@nsheth12 when someone finds time ;) Sorry, a lot of us are pretty busy.

@amueller (Member):

Can you please resolve the conflict?

@nsheth12 (Contributor Author) commented Dec 1, 2017:

Resolved the conflict - sorry for the delay. Is there anything else I need to do?

@jnothman (Member):

Wait for a reviewer :\

@jnothman added this to the 0.20 milestone on Jun 17, 2018
@glemaitre (Member) left a comment:

@nsheth12 Can you address these minor issues?

        self.norm_order = norm_order

    def _check_params(self, X, y):
        X, y = check_X_y(X, y)
Member:

Is there any reason not to accept sparse matrices? I would think that the underlying estimator should take care of it.

Member:

you're right: we should not have any criteria on X or y as long as X has shape on its second axis, and can be indexed on it. It's a bit upsetting that we don't have a test for that!
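A sketch of such a test, assuming the input validation is relaxed as suggested (RandomForestClassifier itself accepts CSR input):

    from scipy.sparse import csr_matrix

    def test_fit_accepts_sparse_input():
        X_sparse = csr_matrix(X)
        transformer = SelectFromModel(
            RandomForestClassifier(n_estimators=10, random_state=0),
            max_features=2, threshold=-np.inf)
        X_new = transformer.fit_transform(X_sparse, y)
        assert X_new.shape[1] == 2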

@@ -108,6 +109,11 @@ class SelectFromModel(BaseEstimator, SelectorMixin, MetaEstimatorMixin):
        Otherwise train the model using ``fit`` and then ``transform`` to do
        feature selection.

    max_features : int, between 0 and number of features, optional.
Member:

Do not mention "between 0 and number of features" in the first line. Also remove the final full stop.

    max_features : int, between 0 and number of features, optional.
        Select at most this many features that score above the threshold.
        To disable the threshold, and only select based on max_features,
        set threshold = -np.inf.
Member:

put some backticks and no space around the equals sign

        X, y = check_X_y(X, y)

        if self.max_features is None:
            return
Member:

We should also return X and y if we check them.

        self.norm_order = norm_order

    def _check_params(self, X, y):
Member:

I would rename this function _check_inputs

@@ -40,6 +42,117 @@ def test_input_estimator_unchanged():
    assert_true(transformer.estimator is est)


def check_invalid_max_features(est, X, y):
Member:

We can parametrize the test using pytest from now on:

@pytest.mark.parametrize("max_features",
                         [-1, X.shape[1] + 1, 'gobbledigook', 'all'])
def check_invalid_max_features(est, X, y, max_features):
    transformer = SelectFromModel(estimator=est,
                                  max_features=max_features,
                                  threshold=-np.inf)
    with pytest.raises(ValueError, match=err_msg):
        transformer.fit(X, y)

Member:

Also, it is weird that you always get the same error, "max_features should be >=0"; it is not meaningful for strings. We need two if conditions, one for the type and one for the value.
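One way to write the two-branch check (error messages illustrative):

    import numbers

    if not isinstance(max_features, numbers.Integral):
        raise TypeError("max_features should be an integer; got %r"
                        % max_features)
    if not 0 <= max_features <= X.shape[1]:
        raise ValueError("max_features should be 0 <= max_features <= "
                         "n_features; got %r" % max_features)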


def check_valid_max_features(est, X, y):
    max_features = X.shape[1]
    for valid_max_n_feature in [0, max_features, 5]:
Member:

parametrize

                                      max_features=valid_max_n_feature,
                                      threshold=-np.inf)
        X_new = transformer.fit_transform(X, y)
        assert_equal(X_new.shape[1], valid_max_n_feature)
Member:

call it max_features

                       transformer2.estimator_.coef_)


def test_max_features_tiebreak():
Member:

At first glance, we can also parametrize this test.

    transformer3 = SelectFromModel(estimator=est, max_features=3,
                                   threshold=0.04)
    X_new3 = transformer3.fit_transform(X, y)
    assert_equal(X_new3.shape[1], min(X_new1.shape[1], X_new2.shape[1]))
Member:

use bare assert

@glemaitre (Member) commented Jun 25, 2018:

@jorisvandenbossche I made the changes that I requested. Can you have a look at the PR, as an extra pair of eyes, before merging?

@sklearn-lgtm:

This pull request introduces 1 alert when merging 2bdfb48 into eec7649 - view on LGTM.com

new alerts:

  • 1 for Unused import

Comment posted by LGTM.com

@@ -123,10 +129,12 @@ class SelectFromModel(BaseEstimator, SelectorMixin, MetaEstimatorMixin):
    threshold_ : float
        The threshold value used for feature selection.
    """
    def __init__(self, estimator, threshold=None, prefit=False, norm_order=1):
    def __init__(self, estimator, threshold=None, prefit=False,
                 max_features=None, norm_order=1):
Member:

to avoid API breakage max_features should be added at the end of the signature
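That is, the signature would become something like (this matches the argument order released in 0.20):

    def __init__(self, estimator, threshold=None, prefit=False,
                 norm_order=1, max_features=None):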

@glemaitre (Member):

@agramfort I made the change

@hermidalc (Contributor):

@nsheth12 @jnothman @glemaitre @amueller sorry that I have seen this so late and after the merge. I've had customized code to do this exact functionality for a long time. My question to everyone is why is the implementation here so complex? For consistency why did you not use k as a parameter? Was it because you wanted to combine threshold and k best to determine the number of features?

Here is the diff between my code and 0.19.1; it's very simple. I ignore any threshold if k is specified, which is the behavior I wanted, since I want it to be consistent and comparable to scoring functions.

>     k : int or "all", optional, default None
>         Number of top features to select.
>         The "all" option bypasses selection, for use in a parameter search.
>         If k is specified threshold is ignored.
> 
126c131
<     def __init__(self, estimator, threshold=None, prefit=False, norm_order=1):
---
>     def __init__(self, estimator, threshold=None, k=None, prefit=False, norm_order=1):
128a134
>         self.k = k
131a138,143
>     def _check_params(self, X, y):
>         if self.k is not None and not (self.k == "all" or 0 <= self.k <= X.shape[1]):
>             raise ValueError("k should be >=0, <= n_features; got %r."
>                              "Use k='all' to return all features."
>                              % self.k)
> 
142,144c154,167
<         scores = _get_feature_importances(estimator, self.norm_order)
<         threshold = _calculate_threshold(estimator, scores, self.threshold)
<         return scores >= threshold
---
>         self.scores_ = _get_feature_importances(estimator, self.norm_order)
>         if self.k is None:
>             threshold = _calculate_threshold(estimator, self.scores_, self.threshold)
>             return self.scores_ >= threshold
>         elif self.k == 'all':
>             return np.ones(self.scores_.shape, dtype=bool)
>         elif self.k == 0:
>             return np.zeros(self.scores_.shape, dtype=bool)
>         else:
>             mask = np.zeros(self.scores_.shape, dtype=bool)
>             # Request a stable sort. Mergesort takes more memory (~40MB per
>             # megafeature on x86-64).
>             mask[np.argsort(self.scores_, kind="mergesort")[-self.k:]] = True
>             return mask

@glemaitre (Member):

> My question to everyone is why is the implementation here so complex

I don't see why our solution is more complex. These are exactly the same steps, but we also allow using both max_features and threshold if desired.

        if self.max_features is not None:
            mask = np.zeros_like(scores, dtype=bool)
            candidate_indices = \
                np.argsort(-scores, kind='mergesort')[:self.max_features]
            mask[candidate_indices] = True
        else:
            mask = np.ones_like(scores, dtype=bool)
        mask[scores < threshold] = False
        return mask

> For consistency why did you not use k as a parameter

It is true that k could have been an option. IMO, max_features is more explicit, considering the documentation of SelectFromModel alone.

@jnothman (Member):

max_features can also be extended in the future to handle fractions of the number of features (like SelectPercentile), while k cannot as intuitively.
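Hypothetically, such an extension might interpret a float as a fraction of the input features (not part of this PR; sketch only):

    if isinstance(max_features, float):
        # e.g. max_features=0.5 keeps half of the features
        max_features = int(max_features * X.shape[1])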
