[MRG] Fix predict_proba not fitted check in SGDClassifier by aniruddhadave · Pull Request #10961 · scikit-learn/scikit-learn · GitHub
Merged
merged 6 commits on Apr 17, 2018

Conversation

aniruddhadave (Contributor)

Reference Issues/PRs

Fixes #10938

What does this implement/fix? Explain your changes.

Fixes the not-fitted check in the predict_proba method so that it does not throw a NotFittedError when the method is merely referenced. The classifier's fitted state is now checked only when the method is actually called.

Any other comments?

- Remove the not-fitted check from the predict_proba method of SGDClassifier
- Perform the check only when predict_proba is actually called (a minimal sketch of this pattern follows below)
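
For context, a minimal sketch of the pattern this fix relies on (illustrative only, not the actual scikit-learn implementation; the class name is made up): the loss check runs when the attribute is accessed, while the fitted check runs only when probabilities are actually requested.

from sklearn.exceptions import NotFittedError


class SketchSGDClassifier:
    """Illustration of access-time vs call-time checks; not the real class."""

    def __init__(self, loss="hinge"):
        self.loss = loss

    def _check_proba(self):
        # Access-time check: only the loss is inspected here, so referencing
        # predict_proba (e.g. via hasattr) works on an unfitted estimator.
        if self.loss not in ("log", "modified_huber"):
            raise AttributeError(
                "probability estimates are not available for loss=%r" % self.loss)

    @property
    def predict_proba(self):
        self._check_proba()
        return self._predict_proba

    def _predict_proba(self, X):
        # Call-time check: NotFittedError is raised only when predictions
        # are actually requested.
        if not hasattr(self, "coef_"):
            raise NotFittedError("This SketchSGDClassifier instance is not fitted yet.")
        # ... the real probability computation would go here ...

With this shape, hasattr(clf, "predict_proba") is False for losses without probability support, while calling the method on an unfitted estimator raises NotFittedError instead of failing at attribute access.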
@jnothman (Member) left a comment

I think that is the right fix. Please add a non-leash regression test

@aniruddhadave (Contributor, Author)

I understand non-regression testing as described in the contributing guidelines, but how is a non-leash regression test different from that?

@lesteve (Member) commented Apr 13, 2018

I think @jnothman meant a non-regression test (maybe autocorrect or something?). Great if you understand what it is from the contributing guidelines. Can you add a non-regression test then?

@lesteve (Member) commented Apr 13, 2018

Suggestion for non-regression test (add a test function in sklearn/linear_model/tests/test_sgd.py and mention the issue number for completeness):

from sklearn.linear_model import SGDClassifier

clf = SGDClassifier()
clf.predict_proba
clf.predict_log_proba

@aniruddhadave (Contributor, Author)

@lesteve There is already a test case for the predict_proba method (test_sgd_proba); shouldn't it be modified instead of writing a separate test case?

@lesteve (Member) commented Apr 14, 2018

Not really important, but I would rather put that in a separate test function. It is a bit of a special case to just test that you can access the predict_proba and predict_log_proba attributes.
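
A minimal sketch of what such a separate test function could look like (the test name here is invented; the contributor's actual version appears in the review fragments below):

from sklearn.linear_model import SGDClassifier


def test_sgd_proba_access_before_fit():  # hypothetical name
    # Losses with probability support: the proba methods should be
    # accessible (though not yet callable) before fitting.
    for loss in ["log", "modified_huber"]:
        clf = SGDClassifier(loss=loss)
        assert hasattr(clf, "predict_proba")
        assert hasattr(clf, "predict_log_proba")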

	-Test if the predict_proba and predict_log_proba
	 methods can be accessed before fitting
	-Test if the not fitted check is performed before calling
	 the proba methods
@jnothman (Member) left a comment

Otherwise LGTM


# Checks if not fitted check is performed while calling
# the methods
assert_raises(NotFittedError, clf.predict_proba,[[3,2]])
Member

Space after commas please

# is accessible for referencing before fitting
# the SGD classifier
clf = SGDClassifier()
assert_false(hasattr(clf,"predict_proba"))
Member

Please just use bare assert not ... rather than assert_false. Thanks

	-Remove extra line
	-Add space after comma
	-Use assert instead of assert_false
@lesteve (Member) left a comment

Some comments.


# Checks if not fitted check is performed while calling
# the methods
assert_raises(NotFittedError, clf.predict_proba, [[3, 2]])
Member

I would expect the NotFittedError case to be already tested in test_common.py so I would remove this from the test.


for loss in ["log", "modified_huber"]:
    clf = SGDClassifier(loss=loss)
    assert_true(hasattr(clf, "predict_proba"))
Member

Use bare assert here too.

# is accessible for referencing before fitting
# the SGD classifier
clf = SGDClassifier()
assert not(hasattr(clf, "predict_proba"))
Member

Maybe check this for all the losses that do not support predict_proba; you can use SGD.loss_functions to get all the possible losses, I think.

A slight improvement would be to also check the error message:

with pytest.raises(AttributeError,
                   match='probability estimates are not available for loss={!r}'.format(loss)):
    clf.predict_proba
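
Combining the two suggestions, one possible shape for that check (a sketch only; the test name is invented and the unsupported losses are hard-coded here rather than derived from loss_functions):

import pytest

from sklearn.linear_model import SGDClassifier


@pytest.mark.parametrize("loss", ["hinge", "squared_hinge", "perceptron"])
def test_sgd_proba_unavailable_losses(loss):  # hypothetical name
    clf = SGDClassifier(loss=loss)
    # Accessing the attribute should fail with the loss-specific message,
    # not with NotFittedError.
    msg = "probability estimates are not available for loss={!r}".format(loss)
    with pytest.raises(AttributeError, match=msg):
        clf.predict_proba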

aniruddhadave and others added 2 commits April 16, 2018 22:04
Remove test for NotFittedError
@lesteve (Member) commented Apr 16, 2018

I pushed some minor tweaks, I think this can be merged when CIs are green.

@jnothman (Member)

Perhaps this needs a changelog entry.
Please add an entry to the change log at doc/whats_new/v0.20.rst under API changes. Like the other entries there, please reference this pull request with :issue: and credit yourself (and other contributors if applicable) with :user:

@lesteve (Member) commented Apr 17, 2018

Perhaps this needs a changelog entry.

I am a bit undecided, but I would say that this is a rather obscure problem (accessing predict_proba, rather than calling it, before fit), and I feel it does not really deserve a changelog entry. I am going to merge this one.
