[MRG + 1] Fix the cross_val_predict function for method='predict_proba' by dalmia · Pull Request #7889 · scikit-learn/scikit-learn · GitHub

[MRG + 1] Fix the cross_val_predict function for method='predict_proba' #7889


Merged
merged 20 commits into from
Jan 7, 2017

Conversation

dalmia
Copy link
Contributor
@dalmia dalmia commented Nov 16, 2016

Reference Issue

Fixes #7863

What does this implement/fix? Explain your changes.

cross_val_predict() did not specify the order of the classes when method='predict_proba' was passed. Firstly, the change documents that the returned columns are sorted by class label. Secondly, for estimators that have a classes_ attribute, it ensures that the prediction columns are returned in sorted order.

Any other comments?

I haven't tested the code yet; I'll complete that within a day or so.
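As a concrete illustration of the behaviour this PR pins down (a minimal sketch, assuming the current scikit-learn API): the columns of the returned probabilities line up with the sorted class labels, i.e. np.unique(y).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

X, y = load_iris(return_X_y=True)
proba = cross_val_predict(LogisticRegression(max_iter=1000), X, y,
                          cv=3, method='predict_proba')
# One probability column per class, ordered as np.unique(y)
print(proba.shape)    # (150, 3)
print(np.unique(y))   # [0 1 2]
```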

@dalmia dalmia changed the title [WIP] Fix the cross_val_predict function for method='predict_proba' [MRG] Fix the cross_val_predict function for method='predict_proba' Nov 16, 2016
@amueller
Copy link
Member

classes_ is always sorted, so this sorting never does something.

@dalmia
Copy link
Contributor Author
dalmia commented Nov 16, 2016

@amueller I had the same doubt before I started working on the issue. However, the discussion in the issue thread mentioned that the classes are not always guaranteed to be returned in sorted order. Also, for classifiers like RandomForestClassifier, the docstring doesn't mention that the classes_ attribute returns the classes in sorted order. @jnothman Please share your views on this as well.

@jnothman
Copy link
Member

classes_ is always sorted, so this sorting never does something.

Do we require that to be true for external classifiers?

@amueller
Copy link
Member

It's undocumented and untested, so I guess it's not required. We could require it. I'm not sure where in scikit-learn we might rely on it.

@jnothman
Copy link
Member

actually, it is tested:
https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/utils/estimator_checks.py#L1192

So I suppose we should worry about absent but not out-of-order classes...
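A quick check of the claim for a classifier whose docstring doesn't spell it out (a small sketch; RandomForestClassifier is just one example): classes_ comes out equal to np.unique(y) even when the labels appear unsorted in y.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

X = np.array([[0.], [1.], [2.], [3.]])
y = np.array([6, 1, -4, 1])          # deliberately unsorted labels
clf = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, y)
print(clf.classes_)                  # sorted: [-4  1  6], i.e. np.unique(y)
```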


@dalmia
Copy link
Contributor Author
dalmia commented Nov 17, 2016

With this being tested, it's true that we don't need to worry about classes being out of order. To check for absent classes, though, would checking that the number of unique elements in y in _fit_and_predict equals the number of columns of predictions work?

@jnothman
Copy link
Member

Not necessarily, no. They may have differing subsets of classes, in a rare case.
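A numpy-only sketch of the rare case being discussed: with unshuffled KFold, different training splits can see different subsets of the classes, so a per-fold check against the full data's np.unique(y) can fail.

```python
import numpy as np
from sklearn.model_selection import KFold

y = np.array([0, 0, 0, 1, 1, 1, 2])   # class 2 has only one sample
seen = []
for train, test in KFold(n_splits=2).split(y):
    seen.append(set(np.unique(y[train])))
print(seen)   # the two training splits see different class subsets
```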


@amueller
Copy link
Member

I guess _fit_and_score needs to return classes and then we need to ensure compatibility / possibly pad the probabilities. There was a PR for making _fit_and_score return a dict, right? @raghavrv?
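The padding idea can be sketched in a few lines of numpy (the helper name pad_probabilities is hypothetical, not code from this PR): place each fold's probability columns at the positions of its classes within the full sorted class set, leaving zeros for absent classes.

```python
import numpy as np

def pad_probabilities(probs, fold_classes, all_classes):
    """Pad one fold's predict_proba output out to the full class set.

    fold_classes and all_classes must be sorted; classes absent from
    the fold get probability 0 in the padded result."""
    padded = np.zeros((probs.shape[0], len(all_classes)))
    padded[:, np.searchsorted(all_classes, fold_classes)] = probs
    return padded

all_classes = np.array([0, 1, 2])
fold_probs = np.array([[0.7, 0.3],      # this fold only saw classes 0 and 2
                       [0.2, 0.8]])
padded = pad_probabilities(fold_probs, np.array([0, 2]), all_classes)
print(padded)   # the column for class 1 is all zeros
```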

@amueller
Copy link
Member

This reminds me of a long-ago PR that added classes as an __init__ argument for all estimators. Here we know all the classes in advance.... I'm wondering whether we couldn't do something more general and use a meta-estimator to make sure a classifier produces predict_proba and decision_function of the right shape.

@jnothman
Copy link
Member

Yes, I was trying to think of a more general solution... I don't mind a meta-estimator, and it could be similarly applied to calibration, where we've just made similar fixes; but we have to be careful about backwards compatibility of user interaction with the model.


@amueller
Copy link
Member

I was wondering whether we can create a meta-estimator without adding parameter indirection via estimator__. set_params and get_params can be transparent, but __init__ can not be, because **kwargs are not allowed (someone recently complained about that).

We could get around that using meta-classes, though I'm not sure if that'll look nice. I'm trying to come up with a better solution....

@jnothman
Copy link
Member

meta-estimator or not, this needs tests for:

  • predict_proba, decision_function, predict_log_proba
  • binary, multiclass, multilabel and multioutput problems (use a DecisionTreeClassifier, for instance)

@jnothman jnothman changed the title [MRG] Fix the cross_val_predict function for method='predict_proba' [WIP] Fix the cross_val_predict function for method='predict_proba' Nov 18, 2016
@jnothman
Copy link
Member

I've marked it WIP: untested changes are not ready for merge.

@dalmia
Copy link
Contributor Author
dalmia commented Nov 18, 2016

@jnothman @amueller Actually, I am not familiar with what a meta-estimator is, so I'm unable to follow the proposed solution. Could you please give me a short explanation of what is being discussed and how a meta-estimator differs from a regular estimator, so that I can participate constructively as well?
Thanks a lot.

@jnothman
Copy link
Member

A meta-estimator is just an estimator that wraps an estimator.

Regardless of whether we use a meta-estimator in the solution, the tests are the same.

@amueller, I've said before, you can make metaestimators without parameter indirection using polymorphic clone as discussed in #5080

@amueller
Copy link
Member

@jnothman You mean with polymorphic clone we could allow **kwargs in __init__? Not sure that's my preferred solution.

@amueller
Copy link
Member
import numpy as np
from sklearn.preprocessing import LabelEncoder

class ClassesMeta(type):
    # Intercept the "classes" keyword so it never reaches the wrapped
    # estimator's __init__ (which doesn't accept it).
    def __call__(cls, *args, **kwargs):
        classes = kwargs.pop("classes", None)
        obj = type.__call__(cls, *args, **kwargs)
        obj.classes = classes
        return obj

class ClassesMixin(object):
    def get_params(self, deep=True):
        params = super(ClassesMixin, self).get_params(deep=deep)
        params['classes'] = self.classes
        return params

    def set_params(self, **params):
        self.classes = params.pop('classes', None)
        return super(ClassesMixin, self).set_params(**params)

    def fit(self, X, y):
        # Encode y against the user-supplied class list so the wrapped
        # estimator is fit on integer labels 0..n_classes-1.
        self._private_le = LabelEncoder().fit(self.classes)
        return super(ClassesMixin, self).fit(X, self._private_le.transform(y))

    def predict(self, X):
        return self._private_le.inverse_transform(
            super(ClassesMixin, self).predict(X))

    def predict_proba(self, X):
        probs = super(ClassesMixin, self).predict_proba(X)
        padded_probs = np.zeros((probs.shape[0], len(self.classes)))
        # After encoding, the wrapped estimator's classes_ are integer
        # indices into the sorted full class list, so they map directly
        # to columns of the padded output; absent classes stay at 0.
        padded_probs[:, self.classes_.astype(int)] = probs
        return padded_probs

def add_classes_wrapper(cls):
    return ClassesMeta(cls.__name__ + "WithClasses", (ClassesMixin, cls), {})

And then

LogisticRegressionWithClasses = add_classes_wrapper(LogisticRegression)
asdf = LogisticRegressionWithClasses(classes=[0, 1, 2, 4])

from sklearn.datasets import load_iris
iris = load_iris()
asdf.fit(iris.data, iris.target)
asdf.predict_proba(iris.data).shape

(150, 4)

Not the solution we want to use here probably though ;) Just wanted to see how bad it would be.
I'd say not that bad, actually.... It's a bit awkward that self.classes is not sorted though.

@dalmia
Copy link
Contributor Author
dalmia commented Nov 29, 2016

@amueller I am trying to work on this but feel really confused about how to start. Should I create a meta-estimator, or run the tests mentioned by @jnothman separately? Please help me get started.

@amueller
Copy link
Member

You should check out my suggestion here:
#7889 (comment)

And ignore everything afterwards.

@dalmia
Copy link
Contributor Author
dalmia commented Dec 4, 2016

@amueller But it seems that the cross_val_predict is independent of the _fit_and_scores function, so could you please explain your suggestion? Please correct me if I am mistaken.
Thanks.

@dalmia
Copy link
Contributor Author
dalmia commented Dec 5, 2016

@amueller Did you mean _fit_and_predict? (Since _fit_and_scores is used by cross_val_score)

@dalmia
Copy link
Contributor Author
dalmia commented Dec 5, 2016

I made _fit_and_predict return classes. But the problem is that different cross validations might return slightly different sets of classes on a rare occasion and I'm trying to find a good solution for it. Please give your views on the issue @jnothman @amueller.

@jnothman
Copy link
Member
jnothman commented Dec 5, 2016

You would really benefit from writing a test case that currently fails, but you would like to succeed with your patch. That, especially if you then propose a solution, would much better help us decide the appropriate direction.

@dalmia
Copy link
Contributor Author
dalmia commented Dec 29, 2016

@raghavrv Thanks for letting me know. I'll try to fix it.

@dalmia
Copy link
Contributor Author
dalmia commented Jan 4, 2017

Is this ready for merge?

@@ -474,6 +481,14 @@ def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params,
estimator.fit(X_train, y_train, **fit_params)
func = getattr(estimator, method)
predictions = func(X_test)
if method in ['decision_function', 'predict_proba', 'predict_log_proba']:
true_classes = np.unique(y)
Copy link
Member


Why not have _, n_classes = np.unique(y, return_counts=True) as you don't use the "true_classes"?

Copy link
Contributor Author


As @jnothman mentioned, return_counts is not supported on all NumPy versions, and it's true that only the number of classes is being used, so I have thought of another workaround. Will make the change.
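Since return_counts only appeared in NumPy 1.9, one version-safe way to get just the class count (a trivial sketch of the kind of workaround being discussed) is len(np.unique(y)):

```python
import numpy as np

y = np.array([1, 1, -4, 6])
n_classes = len(np.unique(y))   # no return_counts needed
print(n_classes)                # 3
```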

assert_array_almost_equal(expected_predictions, predictions)

# Testing unordered labels
y = [1, 1, -4, 6]
Copy link
Member


Do you want to explicitly test for a use-case where the classes_ is guaranteed to not be sorted but cross_val_predict returns output in sorted order?

Copy link
Contributor Author


As mentioned here, classes_ should indeed be sorted. Please have a look.

Copy link
Member


Fair enough! Thanks for checking :)

Copy link
Member


But wait could we have a mock estimator that does not have the classes_ sorted? @jnothman Is it worth having such a test?

Copy link
Member


Never mind this has been discussed before. Sorry for not checking... +1 for merge with the whatsnew entry...

@jnothman
Copy link
Member
jnothman commented Jan 5, 2017 via email

@raghavrv raghavrv changed the title [MRG+1] Fix the cross_val_predict function for method='predict_proba' [MRG + 1] Fix the cross_val_predict function for method='predict_proba' Jan 5, 2017
@raghavrv
Copy link
Member
raghavrv commented Jan 5, 2017

And a whatsnew entry please...

@dalmia
Copy link
Contributor Author
dalmia commented Jan 7, 2017

I am also getting this error when building the docs locally:

Exception occurred:
  File "/home/ubuntu/scikit-learn/doc/sphinxext/sphinx_gallery/docs_resolv.py", line 48, in _get_data
    with open(url, 'r') as fid:
IOError: [Errno 2] No such file or directory: '/home/ubuntu/scikit-learn/doc/_build/html/stable/modules/generated/sklearn.feature_selection.SelectKBest.rst.html'
The full traceback has been saved in /tmp/sphinx-err-wz9PtY.log, if you want to report the issue to the developers.
Please also report this if it was a user error, so that a better error message can be provided next time.
A bug report can be filed in the tracker at <https://github.com/sphinx-doc/sphinx/issues>. Thanks!
Embedding documentation hyperlinks in examples..
	processing: feature_stacker.html
make: *** [html] Error 1

Any workaround for this?

@jnothman
Copy link
Member
jnothman commented Jan 7, 2017

Merging with an updated master will fix it for CircleCI

@raghavrv
Copy link
Member
raghavrv commented Jan 7, 2017

+1 for merge after updating master (and all CIs turning green)

@jnothman
Copy link
Member
jnothman commented Jan 7, 2017

@raghavrv, nothing here could cause the doc build to fail.

@jnothman jnothman merged commit fd84a56 into scikit-learn:master Jan 7, 2017
@jnothman
Copy link
Member
jnothman commented Jan 7, 2017

Merged, thanks @dalmia!

@dalmia dalmia deleted the 7863 branch January 8, 2017 01:51
sergeyf pushed a commit to sergeyf/scikit-learn that referenced this pull request Feb 28, 2017
…a' (scikit-learn#7889)

Handle the case where different CV splits have different sets of classes present.
@Przemo10 Przemo10 mentioned this pull request Mar 17, 2017
Sundrique pushed a commit to Sundrique/scikit-learn that referenced this pull request Jun 14, 2017
…a' (scikit-learn#7889)

Handle the case where different CV splits have different sets of classes present.
NelleV pushed a commit to NelleV/scikit-learn that referenced this pull request Aug 11, 2017
…a' (scikit-learn#7889)

Handle the case where different CV splits have different sets of classes present.
paulha pushed a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
…a' (scikit-learn#7889)

Handle the case where different CV splits have different sets of classes present.
@AlJohri
Copy link
AlJohri commented Aug 21, 2017

I started getting a different shape from cross_val_predict with method='decision_function' for this model: svm.SVC(kernel='linear', C=1, probability=False), called as cross_val_predict(original_pipeline, X_train, y_train, cv=2, method='decision_function'). I was getting a 1D array before, and now I'm getting a 2D array whose first column is all zeros. This is a binary classifier.
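For context on the reported shape change (a small sketch, not a fix): calling decision_function directly on a binary SVC yields a 1D array, which is why a 2D result with a zero column from cross_val_predict is surprising here.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.svm import SVC

X, y = make_classification(n_samples=40, random_state=0)  # binary problem
dec = SVC(kernel='linear', C=1).fit(X, y).decision_function(X)
print(dec.shape)   # (40,): 1D for binary classification
```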

Successfully merging this pull request may close these issues.

order of returned probabilites unclear for cross_val_predict with method=predict_proba