[WIP] Column selector functions for ColumnTransformer by partmor · Pull Request #11301 · scikit-learn/scikit-learn · GitHub

[WIP] Column selector functions for ColumnTransformer #11301


Closed

Conversation

partmor
Contributor
@partmor partmor commented Jun 16, 2018

Reference Issues

Closes #11190

What does this implement?

This PR introduces functions to generate column selector callables that can be passed to make_column_transformer and ColumnTransformer in place of the actual column selections.

For now, I have implemented a selector for dtypes. I am working on a name selector.
(As discussed in #11190)

Other comments:

This PR is still incomplete. I have created it early in order to receive feedback.
The example I have included is just a quick refactoring of #11197. I will take more care over the formatting once we are happy with the functions.
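For illustration, here is a minimal sketch of what passing a callable in place of a column selection might look like. It assumes a scikit-learn release in which ColumnTransformer accepts callables (the behaviour proposed here). The helper `numeric_columns` and the toy DataFrame are hypothetical and not part of this PR.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import StandardScaler


def numeric_columns(X):
    """Return the names of the numeric columns of a DataFrame (hypothetical helper)."""
    return X.select_dtypes(include=[np.number]).columns.tolist()


# Toy data, invented for this sketch.
df = pd.DataFrame({
    "age": [22, 38, 26],
    "fare": [7.25, 71.28, 7.92],
    "sex": ["male", "female", "female"],
})

# The callable goes where a list of columns would normally go;
# it is called with X at fit time to resolve the selection.
ct = ColumnTransformer([("num", StandardScaler(), numeric_columns)])
scaled = ct.fit_transform(df)
```

Because `numeric_columns` is a module-level function rather than a lambda, the fitted transformer should remain picklable.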

@partmor
Contributor Author
partmor commented Jun 16, 2018

As @amueller highlighted earlier, this would not be robust to unexpected dtypes such as float16. I will think about a solution; one quick idea would be to hardcode [np.float16,..] and the rest of the variants, and apply them by default when float is passed?

Member
@jnothman jnothman left a comment

Please document and test passing a callable.

I'm really not sure that the factory needs to be in the same PR. Let's go with it for now, but anticipate that we might pull it into a separate contribution. Focus on getting the generic, relatively uncontroversial interface enhancement right first.

@@ -522,7 +525,10 @@ def _get_column(X, key):
if column_names:
if hasattr(X, 'loc'):
# pandas dataframes
return X.loc[:, key]
if not callable(key):
Member

Can't we do this before the ifs, setting key = key(X)?

Contributor Author

Yes, will do
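A rough sketch of the suggestion above, resolving the callable before any of the type checks. `get_column_sketch` is a simplified, hypothetical stand-in and not the actual `_get_column` from scikit-learn.

```python
import numpy as np


def get_column_sketch(X, key):
    """Simplified stand-in for _get_column; not the real implementation."""
    # Resolve a callable once, up front, so the branches below only
    # ever see a concrete selection (labels, positions, or a mask).
    if callable(key):
        key = key(X)
    if hasattr(X, "loc"):
        # pandas DataFrames: label or boolean-mask selection
        return X.loc[:, key]
    # array-likes: positional or boolean-mask selection
    return X[:, key]


X = np.arange(6).reshape(2, 3)
# The callable receives X and returns the first and last column positions.
first_and_last = get_column_sketch(X, lambda data: [0, data.shape[1] - 1])
```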

@@ -0,0 +1,69 @@
"""
=======================
Select Column Functions
Member

I think the existing examples would benefit from this

Contributor Author

I will drop this new example then, and enhance the already existing one for ColumnTransformer.

@@ -597,6 +608,50 @@ def _get_transformer_list(estimators):
return transformer_list


def select_types(dtypes):
"""Generate a column selector callable (type-based) to be passed to
Member

Pep257: a short summary should be alone on the first line

Contributor Author

Got it.

@@ -597,6 +608,50 @@ def _get_transformer_list(estimators):
return transformer_list


def select_types(dtypes):
Member

Please add to doc/modules/classes.rst


Parameters
----------
dtypes : list of column dtypes to be selected from the dataset.
Member

Numpydoc: type on this line, semantics on the next
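As an illustration of the two docstring conventions raised in this review (PEP 257 summary line, numpydoc parameter layout), the docstring might be laid out roughly as follows. The wording and the placeholder body are assumptions made for this sketch, not the text that was eventually committed.

```python
def select_types(dtypes):
    """Generate a type-based column selector callable.

    The returned callable is meant to be passed to
    make_column_transformer or ColumnTransformer in place of the
    actual column selection.

    Parameters
    ----------
    dtypes : list of dtype
        Column dtypes to be selected from the dataset.

    Returns
    -------
    selector : callable
        Callable that maps a dataset X to a column selection.
    """
    # Placeholder body: this sketch only demonstrates docstring layout.
    raise NotImplementedError("docstring layout example only")
```

The one-line summary stands alone on the first line, and each parameter has its type on the name line with the description indented beneath it.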

Member

I think we should make the same factory also able to select columns by name. Please add a param to do so and rename the function.

    """
    def apply_dtype_mask(X, dtype):
        if hasattr(X, 'dtypes'):
            return X.dtypes == dtype
Member

I think we want to use issubdtype

Contributor Author

Here I am returning a boolean mask, issubdtype can't help me here.

Member

Well you can use np.asarray([issubdtype(xtype, dtypes) for xtype in X.dtypes], dtype=bool)

Contributor Author
@partmor partmor Jun 18, 2018

Got it. Not to be insistent, but do we prefer np.issubdtype over a direct == for consistency within the module?
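To make the difference between the two approaches concrete, here is a small sketch comparing an exact `==` comparison with the `np.issubdtype` list comprehension suggested above. The list of column dtypes is invented for the example.

```python
import numpy as np

# Column dtypes as they might appear in X.dtypes (invented for this sketch).
column_dtypes = [np.dtype("float16"), np.dtype("float64"), np.dtype("int32")]

# Exact comparison only matches the one concrete dtype.
mask_eq = np.asarray([t == np.float64 for t in column_dtypes], dtype=bool)

# issubdtype matches every dtype under the abstract np.floating type,
# so an unexpected float16 column is still selected.
mask_sub = np.asarray(
    [np.issubdtype(t, np.floating) for t in column_dtypes], dtype=bool
)
```

Here `mask_eq` selects only the float64 column, while `mask_sub` also picks up float16, which is the robustness concern raised earlier in the thread.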

        masks = [apply_dtype_mask(X, t) for t in dtypes]
        return masks

    return lambda X: np.any(get_dtype_masks(X), axis=0)
Member

We cannot pickle closures. Please implement this as a class with __call__ instead

Contributor Author
@partmor partmor Jun 17, 2018

Yep; using a class with __call__ will also make the code a lot cleaner.
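A minimal sketch of the class-based approach, contrasted with a closure. `DtypeSelector`, `make_closure_selector`, and the stand-in object `X` are hypothetical names used only for illustration; none of them is scikit-learn API.

```python
import pickle
from types import SimpleNamespace

import numpy as np


class DtypeSelector:
    """Hypothetical picklable selector; not a scikit-learn class."""

    def __init__(self, dtype):
        self.dtype = dtype

    def __call__(self, X):
        # Boolean mask with one entry per column of X.
        return np.asarray(
            [np.issubdtype(t, self.dtype) for t in X.dtypes], dtype=bool
        )


def make_closure_selector(dtype):
    # Closure-based variant, comparable to returning a lambda.
    def selector(X):
        return np.asarray(
            [np.issubdtype(t, dtype) for t in X.dtypes], dtype=bool
        )
    return selector


# The class instance survives a pickle round trip.
restored = pickle.loads(pickle.dumps(DtypeSelector(np.floating)))

# The closure does not; the exact exception type varies by Python version.
try:
    pickle.dumps(make_closure_selector(np.floating))
    closure_pickled = True
except (pickle.PicklingError, AttributeError):
    closure_pickled = False

# Stand-in for a DataFrame: any object exposing a `dtypes` attribute.
X = SimpleNamespace(dtypes=[np.dtype("float16"), np.dtype("int64")])
mask = restored(X)
```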

@jnothman
Member
jnothman commented Jun 18, 2018 via email

@partmor
Contributor Author
partmor commented Jun 18, 2018

@jnothman would we be expecting users to pass as dtypes: np.floating, np.integer, np.number, np.object_ and so on? It worries me especially in the case of categorical columns (pandas dtype object).

Doing np.issubdtype(categorical_column_dtype, 'object') returns True, but with a FutureWarning (which eventually makes my code crash when used inside the functions). np.issubdtype(categorical_column_dtype, np.object_) does the job without warnings. The same applies to np.issubdtype(np.float16, float) vs np.issubdtype(np.float16, np.floating) and so on.

Would we maybe end up doing some user input preprocessing? (e.g. if user passes float, take np.floating and so on...)
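One way the input preprocessing floated above could look, sketched under the assumption that Python builtins should map to NumPy's abstract types. The mapping and the `normalize_dtype` helper are hypothetical; nothing in this thread settles on them.

```python
import numpy as np

# Hypothetical mapping from Python builtins to abstract NumPy types.
_BUILTIN_TO_ABSTRACT = {
    float: np.floating,
    int: np.integer,
    complex: np.complexfloating,
    object: np.object_,
}


def normalize_dtype(spec):
    """Translate builtin types; pass anything else through unchanged."""
    if isinstance(spec, type):
        return _BUILTIN_TO_ABSTRACT.get(spec, spec)
    return spec


# float is widened to np.floating, so float16 columns would match.
matches_float16 = np.issubdtype(np.float16, normalize_dtype(float))
# object is narrowed to np.object_ rather than the too-broad np.generic.
matches_object = np.issubdtype(np.dtype("O"), normalize_dtype(object))
```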

@jnothman
Member

Hmmm... this is tricky. Pandas dtypes (CategoricalDtype at least) aren't dtypes and raise an exception with np.issubdtype.

object should be treated as object_ because generic is too broad.

But I'm not sure that float should be treated as float64, which it is in my numpy.

We could consider something like:
select_columns(categorical=T/F, object=T/F, numeric=T/F, float=T/F, integer=T/F, fixedwidth_string=T/F, datetime=T/F, timedelta=T/F, ...) and raise errors on redundant specifications. But we should probably talk to pandas folks before reinventing wheels. @jorisvandenbossche, any ideas?

@jorisvandenbossche
Member

I haven't looked at the implementation / discussion yet, but the relevant piece of pandas functionality to look at is the DataFrame.select_dtypes method.

@jnothman
Member
jnothman commented Jun 18, 2018

Thanks. Perhaps we should assume this helper is only for DataFrames, and use that directly, i.e. use X.iloc[:1].select_dtypes(include=include, exclude=exclude).columns as our mask.
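A short sketch of the expression above wrapped in a helper, assuming the input is always a pandas DataFrame. The function name `dtype_columns` and the toy data are invented for this example.

```python
import numpy as np
import pandas as pd


def dtype_columns(X, include=None, exclude=None):
    """Return the labels of the columns matching the dtype filters.

    Hypothetical helper. Only the first row is sliced because the
    dtypes, not the values, determine the result.
    """
    return X.iloc[:1].select_dtypes(include=include, exclude=exclude).columns


# Toy data, invented for this sketch.
df = pd.DataFrame({
    "age": [31, 45],
    "fare": np.array([7.25, 71.3], dtype=np.float32),
    "sex": ["m", "f"],
})

numeric_cols = list(dtype_columns(df, include=[np.number]))
other_cols = list(dtype_columns(df, exclude=[np.number]))
```

Note that this returns column labels rather than a boolean mask; either form can be used for selection with `.loc`.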

@jorisvandenbossche
Member

In the short term, maybe we should start with a PR containing only the actual change to ColumnTransformer to accept functions, and leave defining such functions as sklearn API for later (in light of getting something in for the release)?

@jnothman
Member
jnothman commented Jun 28, 2018 via email

@amueller
Member

sounds good to me.

@amueller
Member
amueller commented Jun 28, 2018

I think the main issue might be that lambdas don't pickle, right? So users will need to actually define a function in some python file (not sure if defining a function in a notebook is sufficient?)

@jorisvandenbossche
Member

@partmor I opened #11592 with only the part to add the actual functionality (we are sprinting with the core devs, and would like to try to get this into the release). We can afterwards further use this PR to add the factory functions (and will make sure you get proper credit on the other PR).

@partmor
Contributor Author
partmor commented Jul 17, 2018

@jorisvandenbossche thank you very much for the ping! I will develop with an eye on that PR, sorry for the delays .. :(

Labels
Superseded: PR has been replaced by a newer PR
Development

Successfully merging this pull request may close these issues.

ColumnTransformer should be able to use a function to select columns
4 participants