-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[WIP] gamma=auto in SVC #8361 #8535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -405,7 +405,7 @@ def test_weight(): | |
weights=[0.833, 0.167], random_state=2) | ||
|
||
for clf in (linear_model.LogisticRegression(), | ||
svm.LinearSVC(random_state=0), svm.SVC()): | ||
svm.LinearSVC(random_state=0), svm.SVC(gamma="scale")): | ||
clf.set_params(class_weight={0: .1, 1: 10}) | ||
clf.fit(X_[:100], y_[:100]) | ||
y_pred = clf.predict(X_[100:]) | ||
|
@@ -415,7 +415,7 @@ def test_weight(): | |
def test_sample_weights(): | ||
# Test weights on individual samples | ||
# TODO: check on NuSVR, OneClass, etc. | ||
clf = svm.SVC() | ||
clf = svm.SVC(gamma="scale") | ||
clf.fit(X, Y) | ||
assert_array_equal(clf.predict([X[2]]), [1.]) | ||
|
||
|
@@ -424,7 +424,7 @@ def test_sample_weights(): | |
assert_array_equal(clf.predict([X[2]]), [2.]) | ||
|
||
# test that rescaling all samples is the same as changing C | ||
clf = svm.SVC() | ||
clf = svm.SVC(gamma="scale") | ||
clf.fit(X, Y) | ||
dual_coef_no_weight = clf.dual_coef_ | ||
clf.set_params(C=100) | ||
|
@@ -472,7 +472,7 @@ def test_bad_input(): | |
assert_raises(ValueError, clf.fit, X, Y2) | ||
|
||
# Test with arrays that are non-contiguous. | ||
for clf in (svm.SVC(), svm.LinearSVC(random_state=0)): | ||
for clf in (svm.SVC(gamma="scale"), svm.LinearSVC(random_state=0)): | ||
Xf = np.asfortranarray(X) | ||
assert_false(Xf.flags['C_CONTIGUOUS']) | ||
yf = np.ascontiguousarray(np.tile(Y, (2, 1)).T) | ||
|
@@ -487,18 +487,18 @@ def test_bad_input(): | |
assert_raises(ValueError, clf.fit, X, Y) | ||
|
||
# sample_weight bad dimensions | ||
clf = svm.SVC() | ||
clf = svm.SVC(gamma="scale") | ||
assert_raises(ValueError, clf.fit, X, Y, sample_weight=range(len(X) - 1)) | ||
|
||
# predict with sparse input when trained with dense | ||
clf = svm.SVC().fit(X, Y) | ||
clf = svm.SVC(gamma="scale").fit(X, Y) | ||
assert_raises(ValueError, clf.predict, sparse.lil_matrix(X)) | ||
|
||
Xt = np.array(X).T | ||
clf.fit(np.dot(X, Xt), Y) | ||
assert_raises(ValueError, clf.predict, X) | ||
|
||
clf = svm.SVC() | ||
clf = svm.SVC(gamma="scale") | ||
clf.fit(X, Y) | ||
assert_raises(ValueError, clf.predict, Xt) | ||
|
||
|
@@ -844,7 +844,7 @@ def test_timeout(): | |
def test_unfitted(): | ||
X = "foo!" # input validation not required when SVM not fitted | ||
|
||
clf = svm.SVC() | ||
clf = svm.SVC(gamma="scale") | ||
assert_raises_regexp(Exception, r".*\bSVC\b.*\bnot\b.*\bfitted\b", | ||
clf.predict, X) | ||
|
||
|
@@ -974,3 +974,21 @@ def test_ovr_decision_function(): | |
# Test if the first point has lower decision value on every quadrant | ||
# compared to the second point | ||
assert_true(np.all(pred_class_deci_val[:, 0] < pred_class_deci_val[:, 1])) | ||
|
||
def test_gamma_auto(): | ||
X, y = [[0.0], [1.0]], [0, 1] | ||
|
||
msg = ("The default gamma parameter value 'auto', calculated as 1 / n_features," | ||
" is depreciated in version 0.19 and will be replaced by 'scale'," | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We use "deprecated", not "depreciated". |
||
" calculated as 1 / (n_features * X.std()) in version 0.21.") | ||
|
||
assert_warns_message(DeprecationWarning, | ||
msg, | ||
svm.SVC(gamma='auto').fit, X, y) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. But this means that a user can't intentionally pass 'auto' without receiving a warning, which isn't great. We could solve this by making the default actually |
||
|
||
def test_gamma_scale(): | ||
X, y = [[0.0], [1.0]], [0, 1] | ||
|
||
clf = svm.SVC(gamma='scale').fit(X, y) | ||
assert_equal(clf._gamma, 2.0) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please check for more than one. |
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There should be a newline at the end of the file. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can use this PR to also set a
random_state
, to reduce diff of #8563. It will anyway have a merge conflict for all these lines in that PR...
@jnothman Ok with you?