Merge branch 'master' of https://github.com/scikit-learn/scikit-learn · scikit-learn/scikit-learn@33097ab · GitHub

Commit 33097ab

2 parents 74414dc + 3cc7fea

3 files changed: +52 −17 lines


doc/whats_new.rst

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ Bug fixes

     - Fix a bug where some formats of ``scipy.sparse`` matrix, and estimators
       with them as parameters, could not be passed to :func:`base.clone`.
-      By `Loic Eseve`_.
+      By `Loic Esteve`_.


 API changes summary
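
For context, a minimal sketch of what this changelog entry describes, assuming a hypothetical estimator that stores a scipy.sparse matrix as a constructor parameter (the class and data below are made up for illustration; the commit only fixes the typo in the contributor's name):

    import scipy.sparse as sp
    from sklearn.base import BaseEstimator, clone

    class SparseParamEstimator(BaseEstimator):
        # hypothetical estimator whose constructor parameter is a sparse matrix
        def __init__(self, weights=None):
            self.weights = weights

    est = SparseParamEstimator(weights=sp.csr_matrix([[0.0, 1.0, 0.0]]))
    cloned = clone(est)  # previously failed for some sparse formats
    print((cloned.weights != est.weights).nnz)  # 0: the parameter round-trips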

sklearn/svm/classes.py

Lines changed: 6 additions & 3 deletions
@@ -6,7 +6,7 @@
 from ..linear_model.base import LinearClassifierMixin, SparseCoefMixin, \
     LinearModel
 from ..feature_selection.from_model import _LearntSelectorMixin
-from ..utils import check_X_y
+from ..utils import check_X_y, column_or_1d
 from ..utils.validation import _num_samples
 from ..utils.multiclass import check_classification_targets

@@ -329,7 +329,7 @@ def __init__(self, epsilon=0.0, tol=1e-4, C=1.0,
         self.dual = dual
         self.loss = loss

-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
         """Fit the model according to the given training data.

         Parameters
@@ -374,7 +374,7 @@ def fit(self, X, y):
             X, y, self.C, self.fit_intercept, self.intercept_scaling,
             None, penalty, self.dual, self.verbose,
             self.max_iter, self.tol, self.random_state, loss=self.loss,
-            epsilon=self.epsilon)
+            epsilon=self.epsilon, sample_weight=sample_weight)
         self.coef_ = self.coef_.ravel()

         return self
@@ -766,6 +766,9 @@ class SVR(BaseLibSVM, RegressorMixin):
     intercept_ : array, shape = [1]
         Constants in decision function.

+    sample_weight : array-like, shape = [n_samples]
+        Individual weights for each sample
+
     Examples
     --------
     >>> from sklearn.svm import SVR
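
Taken together, these changes give LinearSVR.fit a sample_weight argument that is forwarded to the liblinear fit routine. A minimal usage sketch on the diabetes data also used by the tests (unit weights, which should reproduce the unweighted fit up to solver tolerance):

    import numpy as np
    from sklearn import datasets
    from sklearn.svm import LinearSVR

    diabetes = datasets.load_diabetes()
    X, y = diabetes.data, diabetes.target

    # unit weights: expected to match an unweighted fit
    model = LinearSVR(C=1e3).fit(X, y, sample_weight=np.ones(len(y)))
    print(model.score(X, y))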

sklearn/svm/tests/test_svm.py

Lines changed: 45 additions & 13 deletions
@@ -6,12 +6,11 @@

 import numpy as np
 import itertools
-
 from numpy.testing import assert_array_equal, assert_array_almost_equal
 from numpy.testing import assert_almost_equal
+from numpy.testing import assert_allclose
 from scipy import sparse
 from nose.tools import assert_raises, assert_true, assert_equal, assert_false
-
 from sklearn import svm, linear_model, datasets, metrics, base
 from sklearn.model_selection import train_test_split
 from sklearn.datasets import make_classification, make_blobs

@@ -25,7 +24,6 @@
 from sklearn.exceptions import ChangedBehaviorWarning
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.exceptions import NotFittedError
-
 from sklearn.multiclass import OneVsRestClassifier

 # toy sample

@@ -198,8 +196,44 @@ def test_linearsvr():
     svr = svm.SVR(kernel='linear', C=1e3).fit(diabetes.data, diabetes.target)
     score2 = svr.score(diabetes.data, diabetes.target)

-    assert np.linalg.norm(lsvr.coef_ - svr.coef_) / np.linalg.norm(svr.coef_) < .1
-    assert np.abs(score1 - score2) < 0.1
+    assert_allclose(np.linalg.norm(lsvr.coef_),
+                    np.linalg.norm(svr.coef_), 1, 0.0001)
+    assert_almost_equal(score1, score2, 2)
+
+
+def test_linearsvr_fit_sampleweight():
+    # check correct result when sample_weight is 1
+    # check that SVR(kernel='linear') and LinearSVC() give
+    # comparable results
+    diabetes = datasets.load_diabetes()
+    n_samples = len(diabetes.target)
+    unit_weight = np.ones(n_samples)
+    lsvr = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target,
+                                    sample_weight=unit_weight)
+    score1 = lsvr.score(diabetes.data, diabetes.target)
+
+    lsvr_no_weight = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target)
+    score2 = lsvr_no_weight.score(diabetes.data, diabetes.target)
+
+    assert_allclose(np.linalg.norm(lsvr.coef_),
+                    np.linalg.norm(lsvr_no_weight.coef_), 1, 0.0001)
+    assert_almost_equal(score1, score2, 2)
+
+    # check that fit(X) = fit([X1, X2, X3], sample_weight=[n1, n2, n3]) where
+    # X = X1 repeated n1 times, X2 repeated n2 times and so forth
+    random_state = check_random_state(0)
+    random_weight = random_state.randint(0, 10, n_samples)
+    lsvr_unflat = svm.LinearSVR(C=1e3).fit(diabetes.data, diabetes.target,
+                                           sample_weight=random_weight)
+    score3 = lsvr_unflat.score(diabetes.data, diabetes.target,
+                               sample_weight=random_weight)
+
+    X_flat = np.repeat(diabetes.data, random_weight, axis=0)
+    y_flat = np.repeat(diabetes.target, random_weight, axis=0)
+    lsvr_flat = svm.LinearSVR(C=1e3).fit(X_flat, y_flat)
+    score4 = lsvr_flat.score(X_flat, y_flat)
+
+    assert_almost_equal(score3, score4, 2)


 def test_svr_errors():

@@ -277,14 +311,13 @@ def test_probability():

     for clf in (svm.SVC(probability=True, random_state=0, C=1.0),
                 svm.NuSVC(probability=True, random_state=0)):
-
         clf.fit(iris.data, iris.target)

         prob_predict = clf.predict_proba(iris.data)
         assert_array_almost_equal(
             np.sum(prob_predict, 1), np.ones(iris.data.shape[0]))
         assert_true(np.mean(np.argmax(prob_predict, 1)
-                    == clf.predict(iris.data)) > 0.9)
+                            == clf.predict(iris.data)) > 0.9)

         assert_almost_equal(clf.predict_proba(iris.data),
                             np.exp(clf.predict_log_proba(iris.data)), 8)

@@ -509,9 +542,9 @@ def test_linearsvc_parameters():
     for loss, penalty, dual in itertools.product(losses, penalties, duals):
         clf = svm.LinearSVC(penalty=penalty, loss=loss, dual=dual)
         if ((loss, penalty) == ('hinge', 'l1') or
-            (loss, penalty, dual) == ('hinge', 'l2', False) or
-            (penalty, dual) == ('l1', True) or
-            loss == 'foo' or penalty == 'bar'):
+                (loss, penalty, dual) == ('hinge', 'l2', False) or
+                (penalty, dual) == ('l1', True) or
+                loss == 'foo' or penalty == 'bar'):

             assert_raises_regexp(ValueError,
                                  "Unsupported set of arguments.*penalty='%s.*"

@@ -569,7 +602,7 @@ def test_linear_svx_uppercase_loss_penality_raises_error():
                          svm.LinearSVC(loss="SQuared_hinge").fit, X, y)

     assert_raise_message(ValueError, ("The combination of penalty='L2'"
-                         " and loss='squared_hinge' is not supported"),
+                                      " and loss='squared_hinge' is not supported"),
                          svm.LinearSVC(penalty="L2").fit, X, y)

@@ -634,7 +667,6 @@ def test_crammer_singer_binary():


 def test_linearsvc_iris():
-
     # Test that LinearSVC gives plausible predictions on the iris dataset
     # Also, test symbolic class names (classes_).
     target = iris.target_names[iris.target]

@@ -773,7 +805,7 @@ def test_timeout():


 def test_unfitted():
-    X = "foo!" # input validation not required when SVM not fitted
+    X = "foo!"  # input validation not required when SVM not fitted

     clf = svm.SVC()
     assert_raises_regexp(Exception, r".*\bSVC\b.*\bnot\b.*\bfitted\b",
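
The new test_linearsvr_fit_sampleweight exercises the identity that fitting with integer sample weights should behave like fitting on data with each row physically repeated that many times. A self-contained sketch of the same check on synthetic data (the data, weights, and tolerance here are illustrative; liblinear solutions only agree up to solver precision):

    import numpy as np
    from sklearn.svm import LinearSVR

    rng = np.random.RandomState(0)
    X = rng.rand(50, 3)
    y = X.dot([1.0, -2.0, 0.5]) + 0.1 * rng.randn(50)
    w = rng.randint(1, 5, size=50)  # integer weights, so repetition is exact

    weighted = LinearSVR(C=1e3, random_state=0).fit(X, y, sample_weight=w)
    repeated = LinearSVR(C=1e3, random_state=0).fit(np.repeat(X, w, axis=0),
                                                    np.repeat(y, w))
    # the two fits should agree up to solver tolerance
    print(np.allclose(weighted.coef_, repeated.coef_, atol=1e-2))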

0 commit comments