-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG + 1] Add class_weight to PA Classifier, remove from PA Regressor #4767
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
from sklearn.utils.testing import assert_less | ||
from sklearn.utils.testing import assert_greater | ||
from sklearn.utils.testing import assert_array_almost_equal, assert_array_equal | ||
from sklearn.utils.testing import assert_almost_equal | ||
from sklearn.utils.testing import assert_raises | ||
|
||
from sklearn.base import ClassifierMixin | ||
|
@@ -125,6 +126,77 @@ def test_classifier_undefined_methods(): | |
assert_raises(AttributeError, lambda x: getattr(clf, x), meth) | ||
|
||
|
||
def test_class_weights(): | ||
# Test class weights. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this taken from the SGDClassifier tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Most of it, with some mods here and there. Some others might have been adapted from d-tree's tests if I recall correctly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be great to reuse the SGD tests for PA (since they share the implementation). There'd be more work there, so I think this shouldn't be a show stopper for this PR. @amueller what do you think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. eg, here: #4838 (comment) |
||
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0], | ||
[1.0, 1.0], [1.0, 0.0]]) | ||
y2 = [1, 1, 1, -1, -1] | ||
|
||
clf = PassiveAggressiveClassifier(C=0.1, n_iter=100, class_weight=None, | ||
random_state=100) | ||
clf.fit(X2, y2) | ||
assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([1])) | ||
|
||
# we give a small weights to class 1 | ||
clf = PassiveAggressiveClassifier(C=0.1, n_iter=100, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not entirely convinced it's better, but I can see some reasons for just doing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree it might be slightly clearer @vene but I see this paradigm only very rarely in other tests in git grep... You think it's necessary for merge? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not important, it just seems slightly better to me from a maintenance point of view. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't have a strong opinion on this. Either way would be fine. |
||
class_weight={1: 0.001}, | ||
random_state=100) | ||
clf.fit(X2, y2) | ||
|
||
# now the hyperplane should rotate clock-wise and | ||
# the prediction on this point should shift | ||
assert_array_equal(clf.predict([[0.2, -1.0]]), np.array([-1])) | ||
|
||
|
||
def test_partial_fit_weight_class_balanced(): | ||
# partial_fit with class_weight='balanced' not supported | ||
clf = PassiveAggressiveClassifier(class_weight="balanced") | ||
assert_raises(ValueError, clf.partial_fit, X, y, classes=np.unique(y)) | ||
|
||
|
||
def test_equal_class_weight(): | ||
X2 = [[1, 0], [1, 0], [0, 1], [0, 1]] | ||
y2 = [0, 0, 1, 1] | ||
clf = PassiveAggressiveClassifier(C=0.1, n_iter=1000, class_weight=None) | ||
clf.fit(X2, y2) | ||
|
||
# Already balanced, so "balanced" weights should have no effect | ||
clf_balanced = PassiveAggressiveClassifier(C=0.1, n_iter=1000, | ||
class_weight="balanced") | ||
clf_balanced.fit(X2, y2) | ||
|
||
clf_weighted = PassiveAggressiveClassifier(C=0.1, n_iter=1000, | ||
class_weight={0: 0.5, 1: 0.5}) | ||
clf_weighted.fit(X2, y2) | ||
|
||
# should be similar up to some epsilon due to learning rate schedule | ||
assert_almost_equal(clf.coef_, clf_weighted.coef_, decimal=2) | ||
assert_almost_equal(clf.coef_, clf_balanced.coef_, decimal=2) | ||
|
||
|
||
def test_wrong_class_weight_label(): | ||
# ValueError due to wrong class_weight label. | ||
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0], | ||
[1.0, 1.0], [1.0, 0.0]]) | ||
y2 = [1, 1, 1, -1, -1] | ||
|
||
clf = PassiveAggressiveClassifier(class_weight={0: 0.5}) | ||
assert_raises(ValueError, clf.fit, X2, y2) | ||
|
||
|
||
def test_wrong_class_weight_format(): | ||
# ValueError due to wrong class_weight argument type. | ||
X2 = np.array([[-1.0, -1.0], [-1.0, 0], [-.8, -1.0], | ||
[1.0, 1.0], [1.0, 0.0]]) | ||
y2 = [1, 1, 1, -1, -1] | ||
|
||
clf = PassiveAggressiveClassifier(class_weight=[0.5]) | ||
assert_raises(ValueError, clf.fit, X2, y2) | ||
|
||
clf = PassiveAggressiveClassifier(class_weight="the larch") | ||
assert_raises(ValueError, clf.fit, X2, y2) | ||
|
||
|
||
def test_regressor_mse(): | ||
y_bin = y.copy() | ||
y_bin[y != 1] = -1 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class_weight='balanced'`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure I follow, this is a logical test...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my comment was on the next line, I was referring to the exception message.