10000 Merge pull request #3931 from dsullivan7/cwwarn · scikit-learn/scikit-learn@69d393a · GitHub
[go: up one dir, main page]

Skip to content

Commit 69d393a

Browse files
committed
Merge pull request #3931 from dsullivan7/cwwarn
[MRG+1] raising warning for class_weight in fit method
2 parents f2c200e + 6cb191c commit 69d393a

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

sklearn/linear_model/stochastic_gradient.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
10000
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import scipy.sparse as sp
9+
import warnings
910

1011
from abc import ABCMeta, abstractmethod
1112

@@ -543,12 +544,19 @@ def fit(self, X, y, coef_init=None, intercept_init=None,
543544
544545
sample_weight : array-like, shape (n_samples,), optional
545546
Weights applied to individual samples.
546-
If not provided, uniform weights are assumed.
547+
If not provided, uniform weights are assumed. These weights will
548+
be multiplied with class_weight (passed through the
549+
contructor) if class_weight is specified
547550
548551
Returns
549552
-------
550553
self : returns an instance of self.
551554
"""
555+
if class_weight is not None:
556+
warnings.warn("You are trying to set class_weight through the fit "
557+
"method, which will be deprecated in version "
558+
"v0.17 of scikit-learn. Pass the class_weight into "
559+
"the constructor instead.", DeprecationWarning)
552560
return self._fit(X, y, alpha=self.alpha, C=1.0,
553561
loss=self.loss, learning_rate=self.learning_rate,
554562
coef_init=coef_init, intercept_init=intercept_init,

sklearn/linear_model/tests/test_sgd.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.utils.testing import assert_false, assert_true
1515
from sklearn.utils.testing import assert_equal
1616
from sklearn.utils.testing import assert_raises_regexp
17+
from sklearn.utils.testing import assert_warns_message
1718

1819
from sklearn import linear_model, datasets, metrics
1920
from sklearn.base import clone
@@ -597,6 +598,37 @@ def test_wrong_class_weight_format(self):
597598
clf = self.factory(alpha=0.1, n_iter=1000, class_weight=[0.5])
598599
clf.fit(X, Y)
599600

601+
def test_class_weight_warning(self):
602+
"""Tests that class_weight passed through fit raises warning.
603+
This test should be removed after deprecating support for this"""
604+
605+
clf = self.factory()
606+
warning_message = ("You are trying to set class_weight through the "
607+
"fit "
608+
"method, which will be deprecated in version "
609+
"v0.17 of scikit-learn. Pass the class_weight into "
610+
"the constructor instead.")
611+
assert_warns_message(DeprecationWarning,
612+
warning_message,
613+
clf.fit, X4, Y4,
614+
class_weight=1)
615+
616+
def test_weights_multiplied(self):
617+
"""Tests that class_weight and sample_weight are multiplicative"""
618+
class_weights = {1: .6, 2: .3}
619+
sample_weights = np.random.random(Y4.shape[0])
620+
multiplied_together = np.copy(sample_weights)
621+
multiplied_together[Y4 == 1] *= class_weights[1]
622+
multiplied_together[Y4 == 2] *= class_weights[2]
623+
624+
clf1 = self.factory(alpha=0.1, n_iter=20, class_weight=class_weights)
625+
clf2 = self.factory(alpha=0.1, n_iter=20)
626+
627+
clf1.fit(X4, Y4, sample_weight=sample_weights)
628+
clf2.fit(X4, Y4, sample_weight=multiplied_together)
629+
630+
assert_array_equal(clf1.coef_, clf2.coef_)
631+
600632
def test_auto_weight(self):
601633
"""Test class weights for imbalanced data"""
602634
# compute reference metrics on iris dataset that is quite balanced by

0 commit comments

Comments
 (0)
0