8000 Fresh branch for linearsvr_fit_sample_weight with weights and documen… · scikit-learn/scikit-learn@8e403d6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8e403d6

Browse files
committed
Fresh branch for linearsvr_fit_sample_weight with weights and documentation
1 parent f5be780 commit 8e403d6

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

sklearn/svm/classes.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 + 8000 165,7 @@ def __init__(self, penalty='l2', loss='squared_hinge', dual=True, tol=1e-4,
165165
self.penalty = penalty
166166
self.loss = loss
167167

168-
def fit(self, X, y):
168+
def fit(self, X, y, sample_weight=None):
169169
"""Fit the model according to the given training data.
170170
171171
Parameters
@@ -177,6 +177,11 @@ def fit(self, X, y):
177177
y : array-like, shape = [n_samples]
178178
Target vector relative to X
179179
180+
sample_weight : array-like, shape = [n_samples], optional
181+
Array of weights that are assigned to individual
182+
samples. If not provided,
183+
then each sample is given unit weight.
184+
180185
Returns
181186
-------
182187
self : object
@@ -210,7 +215,7 @@ def fit(self, X, y):
210215
X, y, self.C, self.fit_intercept, self.intercept_scaling,
211216
self.class_weight, self.penalty, self.dual, self.verbose,
212217
self.max_iter, self.tol, self.random_state, self.multi_class,
213-
self.loss)
218+
self.loss, sample_weight=sample_weight)
214219

215220
if self.multi_class == "crammer_singer" and len(self.classes_) == 2:
216221
self.coef_ = (self.coef_[1] - self.coef_[0]).reshape(1, -1)
@@ -341,6 +346,11 @@ def fit(self, X, y, sample_weight=None):
341346
y : array-like, shape = [n_samples]
342347
Target vector relative to X
343348
349+
sample_weight : array-like, shape = [n_samples], optional
350+
Array of weights that are assigned to individual
351+
samples. If not provided,
352+
then each sample is given unit weight.
353+
344354
Returns
345355
-------
346356
self : object

sklearn/svm/tests/test_svm.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,46 @@ def test_linearsvc_crammer_singer():
655655
assert_array_almost_equal(dec_func, cs_clf.decision_function(iris.data))
656656

657657

658+
def test_linearsvc_fit_sampleweight():
659+
# check correct result when sample_weight is 1
660+
# check that SVR(kernel='linear') and LinearSVC() give
661+
# comparable results
662+
663+
# Test basic routines using LinearSVC
664+
n_samples = len(X)
665+
unit_weight = np.ones(n_samples)
666+
clf = svm.LinearSVC(random_state=0).fit(X, Y)
667+
clf_unitweight = svm.LinearSVC(random_state=0).fit(X, Y,
668+
sample_weight=unit_weight)
669+
670+
# sanity check, by default should have intercept
671+
assert_true(clf_unitweight.fit_intercept)
672+
assert_array_almost_equal(clf_unitweight.intercept_, [0], decimal=3)
673+
674+
# check if same as sample_weight=None
675+
assert_array_equal(clf_unitweight.predict(T), clf.predict(T))
676+
assert_allclose(np.linalg.norm(clf.coef_),
677+
np.linalg.norm(clf_unitweight.coef_), 1, 0.0001)
678+
679+
# check that fit(X) = fit([X1, X2, X3],sample_weight = [n1, n2, n3]) where
680+
# X = X1 repeated n1 times, X2 repeated n2 times and so forth
681+
682+
random_state = check_random_state(0)
683+
random_weight = random_state.randint(0, 10, n_samples)
684+
lsvc_unflat = svm.LinearSVC(random_state=0).fit(X, Y,
685+
sample_weight=random_weight)
686+
pred1 = lsvc_unflat.predict(T)
687+
688+
X_flat = np.repeat(X, random_weight, axis=0)
689+
y_flat = np.repeat(Y, random_weight, axis=0)
690+
lsvc_flat = svm.LinearSVC(random_state=0).fit(X_flat, y_flat)
691+
pred2 = lsvc_flat.predict(T)
692+
693+
assert_array_equal(pred1, pred2)
694+
assert_allclose(np.linalg.norm(lsvc_unflat.coef_),
695+
np.linalg.norm(lsvc_flat.coef_), 1, 0.0001)
696+
697+
658698
def test_crammer_singer_binary():
659699
# Test Crammer-Singer formulation in the binary case
660700
X, y = make_classification(n_classes=2, random_state=0)

0 commit comments

Comments
 (0)
0