[MRG+1] Added support for sample_weight in linearSVR, including tests and documentation. Fixes #6862 by imaculate · Pull Request #6907 · scikit-learn/scikit-learn
Merged
merged 22 commits on Jun 23, 2016
Changes from all commits · 22 commits
0fa60b7
Make KernelCenterer a _pairwise operation
fishcorn Jun 17, 2016
0043885
Adding test for PR #6900
fishcorn Jun 17, 2016
069336e
Simplifying imports and test
fishcorn Jun 17, 2016
039b6f3
updating changelog links on homepage (#6901)
jamoque Jun 18, 2016
f69fb7e
first commit
HashCode55 Jun 18, 2016
2d7929d
changed binary average back to macro
HashCode55 Jun 19, 2016
1267f6d
changed binomialNB to multinomialNB
HashCode55 Jun 19, 2016
f911bb6
emphasis on "higher return values are better..." (#6909)
Jun 19, 2016
1534d0c
fix typo in comment of hierarchical clustering (#6912)
b-carter Jun 21, 2016
3c34fb3
[MRG] Allows KMeans/MiniBatchKMeans to use float32 internally by usin…
yenchenlin Jun 21, 2016
2accd0c
Fix sklearn.base.clone for all scipy.sparse formats (#6910)
lesteve Jun 21, 2016
a08a1fd
DOC If git is not installed, need to catch OSError
jnothman Jun 21, 2016
943836c
DOC add what's new for clone fix
jnothman Jun 21, 2016
478614a
fix a typo in ridge.py (#6917)
ryanyu9 Jun 22, 2016
41000d5
pep8
fishcorn Jun 22, 2016
3dfb282
TST: Speed up: cv=2
GaelVaroquaux Jun 22, 2016
99392a8
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
imaculate Jun 22, 2016
74414dc
Merge branch 'master' of https://github.com/scikit-learn/scikit-learn
imaculate Jun 23, 2016
e5e0320
Added support for sample_weight in linearSVR, including tests and doc…
imaculate Jun 18, 2016
e9f2ff7
Changed assert to assert_allclose and assert_almost_equal, reduced th…
imaculate Jun 22, 2016
ae39622
Fixed pep8 violations and sampleweight format
imaculate Jun 22, 2016
65d1d93
rebased with upstream
imaculate Jun 22, 2016
9 changes: 6 additions & 3 deletions sklearn/svm/classes.py
@@ -6,7 +6,7 @@
from ..linear_model.base import LinearClassifierMixin, SparseCoefMixin, \
LinearModel
from ..feature_selection.from_model import _LearntSelectorMixin
-from ..utils import check_X_y
+from ..utils import check_X_y, column_or_1d
from ..utils.validation import _num_samples
from ..utils.multiclass import check_classification_targets

@@ -329,7 +329,7 @@ def __init__(self, epsilon=0.0, tol=1e-4, C=1.0,
self.dual = dual
self.loss = loss

-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
"""Fit the model according to the given training data.

Parameters
@@ -374,7 +374,7 @@ def fit(self, X, y):
X, y, self.C, self.fit_intercept, self.intercept_scaling,
None, penalty, self.dual, self.verbose,
self.max_iter, self.tol, self.random_state, loss=self.loss,
-            epsilon=self.epsilon)
+            epsilon=self.epsilon, sample_weight=sample_weight)
self.coef_ = self.coef_.ravel()

return self
@@ -766,6 +766,9 @@ class SVR(BaseLibSVM, RegressorMixin):
intercept_ : array, shape = [1]
Constants in decision function.

+    sample_weight : array-like, shape = [n_samples]
+        Individual weights for each sample

Examples
--------
>>> from sklearn.svm import SVR
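The change above means LinearSVR.fit now forwards sample_weight to _fit_liblinear, the same liblinear code path LinearSVC already uses for weighted fits. A minimal usage sketch of the new parameter (assuming a build of scikit-learn that includes this patch; uniform weights should reproduce the unweighted fit):

import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.svm import LinearSVR

diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target

# Fit with uniform weights; this should match an unweighted fit.
model = LinearSVR(C=1e3)
model.fit(X, y, sample_weight=np.ones(len(y)))
print(model.score(X, y))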
58 changes: 45 additions & 13 deletions sklearn/svm/tests/test_svm.py
@@ -6,12 +6,11 @@

import numpy as np
import itertools

from numpy.testing import assert_array_equal, assert_array_almost_equal
from numpy.testing import assert_almost_equal
+from numpy.testing import assert_allclose
from scipy import sparse
from nose.tools import assert_raises, assert_true, assert_equal, assert_false

from sklearn import svm, linear_model, datasets, metrics, base
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification, make_blobs
@@ -25,7 +24,6 @@
from sklearn.exceptions import ChangedBehaviorWarning
from sklearn.exceptions import ConvergenceWarning
from sklearn.exceptions import NotFittedError

from sklearn.multiclass import OneVsRestClassifier

# toy sample
@@ -198,8 +196,44 @@ def test_linearsvr():
svr = svm.SVR(kernel='linear', C=1e3).fit(diabetes.data, diabetes.target)
score2 = svr.score(diabetes.data, diabetes.target)

-    assert np.linalg.norm(lsvr.coef_ - svr.coef_) / np.linalg.norm(svr.coef_) < .1
-    assert np.abs(score1 - score2) < 0.1
+    assert_allclose(np.linalg.norm(lsvr.coef_),
+                    np.linalg.norm(svr.coef_), 1, 0.0001)
+    assert_almost_equal(score1, score2, 2)


def test_linearsvr_fit_sampleweight():
    # check correct result when sample_weight is 1: fitting with unit
    # weights should match fitting with no sample_weight
diabetes = datasets.load_diabetes()
n_samples = len(diabetes.target)
unit_weight = np.ones(n_samples)
lsvr = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target,
sample_weight=unit_weight)
Review comment (Member): It may also be worth testing that the method accepts a list (rather than array).

score1 = lsvr.score(diabetes.data, diabetes.target)

lsvr_no_weight = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target)
score2 = lsvr_no_weight.score(diabetes.data, diabetes.target)

assert_allclose(np.linalg.norm(lsvr.coef_),
np.linalg.norm(lsvr_no_weight.coef_), 1, 0.0001)
assert_almost_equal(score1, score2, 2)

    # check that fit(X) = fit([X1, X2, X3], sample_weight=[n1, n2, n3]) where
    # X = X1 repeated n1 times, X2 repeated n2 times and so forth
random_state = check_random_state(0)
random_weight = random_state.randint(0, 10, n_samples)
lsvr_unflat = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target,
sample_weight=random_weight)
score3 = lsvr_unflat.score(diabetes.data, diabetes.target,
sample_weight=random_weight)

X_flat = np.repeat(diabetes.data, random_weight, axis=0)
y_flat = np.repeat(diabetes.target, random_weight, axis=0)
lsvr_flat = svm.LinearSVR(C=1e3).fit(X_flat, y_flat)
score4 = lsvr_flat.score(X_flat, y_flat)

assert_almost_equal(score3, score4, 2)
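
Per the review comment above, it may also be worth asserting that a plain Python list is accepted for sample_weight. A hedged sketch of such a check, reusing this test's fixtures (hypothetical, not part of the committed diff):

    # Hypothetical extra check: a plain list should behave exactly like
    # the equivalent ndarray when passed as sample_weight.
    list_weight = random_weight.tolist()
    lsvr_list = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target,
                                         sample_weight=list_weight)
    score5 = lsvr_list.score(diabetes.data, diabetes.target,
                             sample_weight=list_weight)
    assert_almost_equal(score3, score5, 2)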


def test_svr_errors():
@@ -277,14 +311,13 @@ def test_probability():

for clf in (svm.SVC(probability=True, random_state=0, C=1.0),
svm.NuSVC(probability=True, random_state=0)):

clf.fit(iris.data, iris.target)

prob_predict = clf.predict_proba(iris.data)
assert_array_almost_equal(
np.sum(prob_predict, 1), np.ones(iris.data.shape[0]))
assert_true(np.mean(np.argmax(prob_predict, 1)
-                    == clf.predict(iris.data)) > 0.9)
+                            == clf.predict(iris.data)) > 0.9)
Review comment (Member): Usually we wouldn't go fixing up cosmetic things when submitting an unrelated PR. It makes the PR somewhat harder to review. But at least this PR is small and focussed.


assert_almost_equal(clf.predict_proba(iris.data),
np.exp(clf.predict_log_proba(iris.data)), 8)
@@ -509,9 +542,9 @@ def test_linearsvc_parameters():
for loss, penalty, dual in itertools.product(losses, penalties, duals):
clf = svm.LinearSVC(penalty=penalty, loss=loss, dual=dual)
if ((loss, penalty) == ('hinge', 'l1') or
-            (loss, penalty, dual) == ('hinge', 'l2', False) or
-            (penalty, dual) == ('l1', True) or
-            loss == 'foo' or penalty == 'bar'):
+                (loss, penalty, dual) == ('hinge', 'l2', False) or
+                (penalty, dual) == ('l1', True) or
+                loss == 'foo' or penalty == 'bar'):

assert_raises_regexp(ValueError,
"Unsupported set of arguments.*penalty='%s.*"
@@ -569,7 +602,7 @@ def test_linear_svx_uppercase_loss_penality_raises_error():
svm.LinearSVC(loss="SQuared_hinge").fit, X, y)

assert_raise_message(ValueError, ("The combination of penalty='L2'"
" and loss='squared_hinge' is not supported"),
" and loss='squared_hinge' is not supported"),
svm.LinearSVC(penalty="L2").fit, X, y)


@@ -634,7 +667,6 @@ def test_crammer_singer_binary():


def test_linearsvc_iris():

# Test that LinearSVC gives plausible predictions on the iris dataset
# Also, test symbolic class names (classes_).
target = iris.target_names[iris.target]
@@ -773,7 +805,7 @@ def test_timeout():


def test_unfitted():
X = "foo!" # input validation not required when SVM not fitted
X = "foo!" # input validation not required when SVM not fitted

clf = svm.SVC()
assert_raises_regexp(Exception, r".*\bSVC\b.*\bnot\b.*\bfitted\b",