RFC Implement Pipeline get feature names #12627
Closed
30 commits
ab2acbd work on get_feature_names for pipeline (amueller)
3bc674b fix SimpleImputer get_feature_names (amueller)
1c4a78f use hasattr(transform) to check whether to use final estimator in get… (amueller)
7881930 add some docstrings (amueller)
de63353 fix docstring (amueller)
8835f3b Merge branch 'master' into pipeline_get_feature_names (amueller)
2eba5de fix merge issues with master (amueller)
449ed23 fix merge issue (amueller)
a1fcf67 Merge branch 'master' into pipeline_get_feature_names (amueller)
b929341 don't do magic slicing in pipeline.get_feature_names (amueller)
2b613e5 fix merge issue (amueller)
ad66b86 Merge branch 'master' of https://github.com/scikit-learn/scikit-learn… (amueller)
5eb7603 trying to merge with input feature pr (amueller)
f4f832a Merge branch 'master' into pipeline_get_feature_names (amueller)
3a9054c remove tests that don't apply (amueller)
9c4420d Merge branch 'pipeline_get_feature_names' of github.com:amueller/scik… (amueller)
76f5b54 fix onetoone mixing feature names (amueller)
52f38e1 remove more tests (amueller)
cdda1fb fix test for better expected outputs (amueller)
5f4abbc fix priorities in catch-all get_feature_names (amueller)
4305a28 flake8 (amueller)
c387b5b remove redundant code (amueller)
2fefb67 fix error message (amueller)
a6832c3 fix mixin order (amueller)
0f45b22 small refactor with helper function (amueller)
4717a73 linting for new options (amueller)
a658ba7 add feature names to lineardiscriminantanalysis and birch (amueller)
e9e45af add get_feature_names in a couple more places (amueller)
5acaced fix up docs (amueller)
0353f69 make example actually work (amueller)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
from .utils import _IS_32BIT | ||
from .utils.validation import check_X_y | ||
from .utils.validation import check_array | ||
from .utils._feature_names import _make_feature_names | ||
from .utils._estimator_html_repr import estimator_html_repr | ||
from .utils.validation import _deprecate_positional_args | ||
|
||
|
@@ -689,6 +690,45 @@ def fit_transform(self, X, y=None, **fit_params): | |
# fit method of arity 2 (supervised transformation) | ||
return self.fit(X, y, **fit_params).transform(X) | ||
|
||
def get_feature_names(self, input_features=None): | ||
"""Get output feature names. | ||
|
||
Parameters | ||
---------- | ||
input_features : list of string or None | ||
String names of the input features. | ||
|
||
Returns | ||
------- | ||
output_feature_names : list of string | ||
Feature names for transformer output. | ||
""" | ||
# generate feature names from class name by default | ||
# would be much less guessing if we stored the number | ||
# of output features. | ||
# Ideally this would be done in each class. | ||
if hasattr(self, 'n_clusters'): | ||
# this is before n_components_ | ||
# because n_components_ means something else | ||
# in agglomerative clustering | ||
n_features = self.n_clusters | ||
elif hasattr(self, '_max_components'): | ||
Review comment: Whoops, this can be removed; it's in the class now. |
||
# special case for LinearDiscriminantAnalysis | ||
n_components = self.n_components or np.inf | ||
n_features = min(self._max_components, n_components) | ||
elif hasattr(self, 'n_components_'): | ||
# n_components could be auto or None | ||
# this is more likely to be an int | ||
n_features = self.n_components_ | ||
elif hasattr(self, 'components_'): | ||
n_features = self.components_.shape[0] | ||
elif hasattr(self, 'n_components') and self.n_components is not None: | ||
n_features = self.n_components | ||
else: | ||
return None | ||
return _make_feature_names(n_features=n_features, | ||
prefix=type(self).__name__.lower()) | ||
|
||
|
||
class DensityMixin: | ||
"""Mixin class for all density estimators in scikit-learn.""" | ||
|
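The fallback in the hunk above guesses the number of output features from whichever fitted attribute it finds first, then builds names from the lowercased class name. A minimal sketch of that priority order follows. It assumes, since the helper's body is not shown in this diff, that `_make_feature_names` produces names of the form `<prefix><index>`; `make_feature_names`, `guess_n_output_features`, `default_feature_names`, and the toy classes are hypothetical stand-ins, and the `LinearDiscriminantAnalysis` special case is left out.

```python
from types import SimpleNamespace


def make_feature_names(n_features, prefix):
    # Hypothetical stand-in for the PR's _make_feature_names helper;
    # the "<prefix><index>" naming scheme is an assumption.
    return [f"{prefix}{i}" for i in range(n_features)]


def guess_n_output_features(est):
    # Same priority order as the diff: n_clusters is checked first because
    # n_components_ means something else in agglomerative clustering.
    if hasattr(est, "n_clusters"):
        return est.n_clusters
    if hasattr(est, "n_components_"):
        return est.n_components_
    if hasattr(est, "components_"):
        return est.components_.shape[0]
    if getattr(est, "n_components", None) is not None:
        return est.n_components
    return None


def default_feature_names(est):
    n_features = guess_n_output_features(est)
    if n_features is None:
        # Nothing to guess from, so no names are produced.
        return None
    return make_feature_names(n_features, prefix=type(est).__name__.lower())


class ToyClusterer:
    def __init__(self):
        self.n_clusters = 3
        self.n_components_ = 1  # ignored: n_clusters takes priority


class ToyDecomposition:
    def __init__(self):
        self.n_components = None  # e.g. left unset by the user
        self.components_ = SimpleNamespace(shape=(2, 5))


class ToyScaler:
    pass


print(default_feature_names(ToyClusterer()))
print(default_feature_names(ToyDecomposition()))
print(default_feature_names(ToyScaler()))
```

The toy classes only exercise the branching; real estimators would expose these attributes after `fit`.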
@@ -737,6 +777,34 @@ def fit_predict(self, X, y=None): | |
return self.fit(X).predict(X) | ||
|
||
|
||
class OneToOneMixin(object): | ||
"""Provides get_feature_names for simple transformers | ||
|
||
Assumes there's a 1-to-1 correspondence between input features | ||
and output features. | ||
""" | ||
|
||
def get_feature_names(self, input_features=None): | ||
"""Get feature names for transformation. | ||
|
||
Returns input_features as this transformation | ||
doesn't add or drop features. | ||
|
||
Parameters | ||
---------- | ||
input_features : array-like of string | ||
Input feature names. | ||
|
||
Returns | ||
------- | ||
feature_names : array-like of string | ||
Transformed feature names | ||
""" | ||
|
||
return _make_feature_names(self.n_features_in_, | ||
input_features=input_features) | ||
|
||
|
||
class MetaEstimatorMixin: | ||
_required_parameters = ["estimator"] | ||
"""Mixin class for all meta estimators in scikit-learn.""" | ||
|
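For transformers that neither add nor drop columns, `OneToOneMixin` passes the input names straight through. The sketch below illustrates that behaviour under stated assumptions: `make_feature_names` is a hypothetical stand-in for `_make_feature_names`, and both the `x0, x1, ...` default and the length check are guesses about what the helper might do when no input names are given, not something this diff confirms. `IdentityScaler` is an illustrative class, not part of scikit-learn.

```python
def make_feature_names(n_features, prefix="x", input_features=None):
    # Hypothetical stand-in for _make_feature_names; the default "x" prefix
    # and the length check are assumptions, not taken from the PR.
    if input_features is not None:
        if len(input_features) != n_features:
            raise ValueError(
                f"Expected {n_features} input features, "
                f"got {len(input_features)}."
            )
        return list(input_features)
    return [f"{prefix}{i}" for i in range(n_features)]


class OneToOneMixin:
    """Output feature names equal input feature names."""

    def get_feature_names(self, input_features=None):
        return make_feature_names(self.n_features_in_,
                                  input_features=input_features)


class IdentityScaler(OneToOneMixin):
    """Toy transformer that records the input width and changes nothing."""

    def fit(self, X, y=None):
        self.n_features_in_ = len(X[0])
        return self

    def transform(self, X):
        return X


scaler = IdentityScaler().fit([[1.0, 2.0], [3.0, 4.0]])
print(scaler.get_feature_names(["age", "income"]))
print(scaler.get_feature_names())
```

Relying on `n_features_in_` means the mixin only works on a fitted transformer, which matches how the diff reads that attribute.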
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -371,8 +371,12 @@ def get_feature_names(self): | |
raise AttributeError("Transformer %s (type %s) does not " | ||
"provide get_feature_names." | ||
% (str(name), type(trans).__name__)) | ||
try: | ||
Review comment: This is duck typing to support both transformative and non-transformative get_feature_names. |
||
more_names = trans.get_feature_names(input_features=column) | ||
except TypeError: | ||
more_names = trans.get_feature_names() | ||
feature_names.extend([name + "__" + f for f in | ||
trans.get_feature_names()]) | ||
more_names]) | ||
return feature_names | ||
|
||
def _update_fitted_transformers(self, transformers): | ||
|
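The `try`/`except TypeError` in the hunk above lets the column transformer call both styles of `get_feature_names`: the new one that accepts `input_features` and the older one that takes no arguments. A small self-contained sketch of the pattern follows; `names_from`, `prefixed_names`, `NewStyle`, and `OldStyle` are hypothetical names used only for illustration.

```python
def names_from(trans, columns):
    # Try the signature that accepts input names first, then fall back
    # to the zero-argument form if the keyword is rejected.
    try:
        return trans.get_feature_names(input_features=columns)
    except TypeError:
        return trans.get_feature_names()


def prefixed_names(name, trans, columns):
    # Same "<transformer name>__<feature>" prefixing as in the diff.
    return [name + "__" + f for f in names_from(trans, columns)]


class NewStyle:
    """Derives output names from the input names it is given."""

    def get_feature_names(self, input_features=None):
        return [f"{col}_scaled" for col in input_features]


class OldStyle:
    """Knows its own output names (for example, a vocabulary)."""

    def get_feature_names(self):
        return ["tok_a", "tok_b"]


print(prefixed_names("num", NewStyle(), ["age", "income"]))
print(prefixed_names("txt", OldStyle(), ["document"]))
```

One trade-off of this approach is that a `TypeError` raised inside a new-style `get_feature_names` would also trigger the fallback call, so a genuine bug could surface as a confusing second error; inspecting the signature would avoid that at the cost of more code.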
We can push this down if people think having it here is ugly.