[MRG+1] Added support for sample_weight in linearSVR, including tests and documentation. Fixes #6862 (#6907)
Merged: agramfort merged 22 commits into scikit-learn:master from imaculate:linearsvr_sampleweight on Jun 23, 2016.
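The change itself is small: LinearSVR.fit gains an optional sample_weight argument, as SVR.fit already has. A minimal usage sketch (the dataset and the particular weight values below are illustrative choices, not taken from the PR):

import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.svm import LinearSVR

diabetes = load_diabetes()
# Hypothetical weighting: count the first 100 samples twice as much.
weights = np.ones(len(diabetes.target))
weights[:100] = 2.0

reg = LinearSVR(C=1e3)
reg.fit(diabetes.data, diabetes.target, sample_weight=weights)
print(reg.score(diabetes.data, diabetes.target, sample_weight=weights))

Leaving sample_weight at its default of None keeps the previous, unweighted behaviour.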
Changes from all commits (22 commits)
0fa60b7  Make KernelCenterer a _pairwise operation (fishcorn)
0043885  Adding test for PR #6900 (fishcorn)
069336e  Simplifying imports and test (fishcorn)
039b6f3  updating changelog links on homepage (#6901) (jamoque)
f69fb7e  first commit (HashCode55)
2d7929d  changed binary average back to macro (HashCode55)
1267f6d  changed binomialNB to multinomialNB (HashCode55)
f911bb6  emphasis on "higher return values are better..." (#6909)
1534d0c  fix typo in comment of hierarchical clustering (#6912) (b-carter)
3c34fb3  [MRG] Allows KMeans/MiniBatchKMeans to use float32 internally by usin… (yenchenlin)
2accd0c  Fix sklearn.base.clone for all scipy.sparse formats (#6910) (lesteve)
a08a1fd  DOC If git is not installed, need to catch OSError (jnothman)
943836c  DOC add what's new for clone fix (jnothman)
478614a  fix a typo in ridge.py (#6917) (ryanyu9)
41000d5  pep8 (fishcorn)
3dfb282  TST: Speed up: cv=2 (GaelVaroquaux)
99392a8  Merge branch 'master' of https://github.com/scikit-learn/scikit-learn (imaculate)
74414dc  Merge branch 'master' of https://github.com/scikit-learn/scikit-learn (imaculate)
e5e0320  Added support for sample_weight in linearSVR, including tests and doc… (imaculate)
e9f2ff7  Changed assert to assert_allclose and assert_almost_equal, reduced th… (imaculate)
ae39622  Fixed pep8 violations and sampleweight format (imaculate)
65d1d93  rebased with upstream (imaculate)
@@ -6,12 +6,11 @@

import numpy as np
import itertools

from numpy.testing import assert_array_equal, assert_array_almost_equal
from numpy.testing import assert_almost_equal
from numpy.testing import assert_allclose
from scipy import sparse
from nose.tools import assert_raises, assert_true, assert_equal, assert_false

from sklearn import svm, linear_model, datasets, metrics, base
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification, make_blobs
@@ -25,7 +24,6 @@
from sklearn.exceptions import ChangedBehaviorWarning
from sklearn.exceptions import ConvergenceWarning
from sklearn.exceptions import NotFittedError

from sklearn.multiclass import OneVsRestClassifier

# toy sample
@@ -198,8 +196,44 @@ def test_linearsvr():
    svr = svm.SVR(kernel='linear', C=1e3).fit(diabetes.data, diabetes.target)
    score2 = svr.score(diabetes.data, diabetes.target)

    assert np.linalg.norm(lsvr.coef_ - svr.coef_) / np.linalg.norm(svr.coef_) < .1
    assert np.abs(score1 - score2) < 0.1
    assert_allclose(np.linalg.norm(lsvr.coef_),
                    np.linalg.norm(svr.coef_), 1, 0.0001)
    assert_almost_equal(score1, score2, 2)


def test_linearsvr_fit_sampleweight():
    # check correct result when sample_weight is 1
    # check that SVR(kernel='linear') and LinearSVC() give
    # comparable results
    diabetes = datasets.load_diabetes()
    n_samples = len(diabetes.target)
    unit_weight = np.ones(n_samples)
    lsvr = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target,
                                    sample_weight=unit_weight)
    score1 = lsvr.score(diabetes.data, diabetes.target)

    lsvr_no_weight = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target)
    score2 = lsvr_no_weight.score(diabetes.data, diabetes.target)

    assert_allclose(np.linalg.norm(lsvr.coef_),
                    np.linalg.norm(lsvr_no_weight.coef_), 1, 0.0001)
    assert_almost_equal(score1, score2, 2)

    # check that fit(X) = fit([X1, X2, X3],sample_weight = [n1, n2, n3]) where
    # X = X1 repeated n1 times, X2 repeated n2 times and so forth
    random_state = check_random_state(0)
    random_weight = random_state.randint(0, 10, n_samples)
    lsvr_unflat = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target,
                                           sample_weight=random_weight)
    score3 = lsvr_unflat.score(diabetes.data, diabetes.target,
                               sample_weight=random_weight)

    X_flat = np.repeat(diabetes.data, random_weight, axis=0)
    y_flat = np.repeat(diabetes.target, random_weight, axis=0)
    lsvr_flat = svm.LinearSVR(C=1e3).fit(X_flat, y_flat)
    score4 = lsvr_flat.score(X_flat, y_flat)

    assert_almost_equal(score3, score4, 2)


def test_svr_errors():
@@ -277,14 +311,13 @@ def test_probability():

    for clf in (svm.SVC(probability=True, random_state=0, C=1.0),
                svm.NuSVC(probability=True, random_state=0)):

        clf.fit(iris.data, iris.target)

        prob_predict = clf.predict_proba(iris.data)
        assert_array_almost_equal(
            np.sum(prob_predict, 1), np.ones(iris.data.shape[0]))
        assert_true(np.mean(np.argmax(prob_predict, 1)
                    == clf.predict(iris.data)) > 0.9)
                    == clf.predict(iris.data)) > 0.9)
Review comment: Usually we wouldn't go fixing up cosmetic things when submitting an unrelated PR. It makes the PR somewhat harder to review. But at least this PR is small and focussed.
        assert_almost_equal(clf.predict_proba(iris.data),
                            np.exp(clf.predict_log_proba(iris.data)), 8)
@@ -509,9 +542,9 @@ def test_linearsvc_parameters():
    for loss, penalty, dual in itertools.product(losses, penalties, duals):
        clf = svm.LinearSVC(penalty=penalty, loss=loss, dual=dual)
        if ((loss, penalty) == ('hinge', 'l1') or
                (loss, penalty, dual) == ('hinge', 'l2', False) or
                (penalty, dual) == ('l1', True) or
                loss == 'foo' or penalty == 'bar'):
                (loss, penalty, dual) == ('hinge', 'l2', False) or
                (penalty, dual) == ('l1', True) or
                loss == 'foo' or penalty == 'bar'):

            assert_raises_regexp(ValueError,
                                 "Unsupported set of arguments.*penalty='%s.*"
@@ -569,7 +602,7 @@ def test_linear_svx_uppercase_loss_penality_raises_error():
                         svm.LinearSVC(loss="SQuared_hinge").fit, X, y)

    assert_raise_message(ValueError, ("The combination of penalty='L2'"
                         " and loss='squared_hinge' is not supported"),
                         " and loss='squared_hinge' is not supported"),
                         svm.LinearSVC(penalty="L2").fit, X, y)
@@ -634,7 +667,6 @@ def test_crammer_singer_binary():


def test_linearsvc_iris():

    # Test that LinearSVC gives plausible predictions on the iris dataset
    # Also, test symbolic class names (classes_).
    target = iris.target_names[iris.target]
@@ -773,7 +805,7 @@ def test_timeout():


def test_unfitted():
    X = "foo!"      # input validation not required when SVM not fitted
    X = "foo!"  # input validation not required when SVM not fitted

    clf = svm.SVC()
    assert_raises_regexp(Exception, r".*\bSVC\b.*\bnot\b.*\bfitted\b",
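The second half of test_linearsvr_fit_sampleweight above relies on the usual equivalence between integer sample weights and sample repetition: fitting on X with sample_weight=w should give roughly the same model as fitting on data in which row i of X is repeated w[i] times. A standalone sketch of that check, outside the test suite (variable names here are mine, and the agreement is only approximate because of solver tolerance):

import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.svm import LinearSVR
from sklearn.utils import check_random_state

diabetes = load_diabetes()
rng = check_random_state(0)
w = rng.randint(0, 10, len(diabetes.target))  # integer weights; 0 drops a sample

# Fit once with sample_weight and once on the explicitly repeated data.
weighted = LinearSVR(C=1e3).fit(diabetes.data, diabetes.target, sample_weight=w)
X_rep = np.repeat(diabetes.data, w, axis=0)
y_rep = np.repeat(diabetes.target, w, axis=0)
repeated = LinearSVR(C=1e3).fit(X_rep, y_rep)

# The two fits should agree closely, e.g. in their R^2 scores.
print(weighted.score(diabetes.data, diabetes.target, sample_weight=w))
print(repeated.score(X_rep, y_rep))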
Review comment: It may also be worth testing that the method accepts a list (rather than array).
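A sketch of what such a check could look like, following the conventions of the test file above (it assumes the module-level imports shown in the diff: np, svm, datasets and assert_array_almost_equal); the test name and body are hypothetical, not part of the merged PR:

def test_linearsvr_sample_weight_as_list():
    # Hypothetical follow-up test: a plain Python list should be accepted
    # for sample_weight and give the same fit as the equivalent ndarray.
    diabetes = datasets.load_diabetes()
    weight_list = [1.0] * len(diabetes.target)
    weight_array = np.asarray(weight_list)

    lsvr_list = svm.LinearSVR(C=1e3, random_state=0).fit(
        diabetes.data, diabetes.target, sample_weight=weight_list)
    lsvr_array = svm.LinearSVR(C=1e3, random_state=0).fit(
        diabetes.data, diabetes.target, sample_weight=weight_array)

    assert_array_almost_equal(lsvr_list.coef_, lsvr_array.coef_)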