|
14 | 14 | from sklearn.utils.testing import assert_false, assert_true
|
15 | 15 | from sklearn.utils.testing import assert_equal
|
16 | 16 | from sklearn.utils.testing import assert_raises_regexp
|
| 17 | +from sklearn.utils.testing import assert_warns_message |
17 | 18 |
|
18 | 19 | from sklearn import linear_model, datasets, metrics
|
19 | 20 | from sklearn.base import clone
|
@@ -597,6 +598,37 @@ def test_wrong_class_weight_format(self):
|
597 | 598 | clf = self.factory(alpha=0.1, n_iter=1000, class_weight=[0.5])
|
598 | 599 | clf.fit(X, Y)
|
599 | 600 |
|
| 601 | + def test_class_weight_warning(self): |
| 602 | + """Tests that class_weight passed through fit raises warning. |
| 603 | + This test should be removed after deprecating support for this""" |
| 604 | + |
| 605 | + clf = self.factory() |
| 606 | + warning_message = ("You are trying to set class_weight through the " |
| 607 | + "fit " |
| 608 | + "method, which will be deprecated in version " |
| 609 | + "v0.17 of scikit-learn. Pass the class_weight into " |
| 610 | + "the constructor instead.") |
| 611 | + assert_warns_message(DeprecationWarning, |
| 612 | + warning_message, |
| 613 | + clf.fit, X4, Y4, |
| 614 | + class_weight=1) |
| 615 | + |
| 616 | + def test_weights_multiplied(self): |
| 617 | + """Tests that class_weight and sample_weight are multiplicative""" |
| 618 | + class_weights = {1: .6, 2: .3} |
| 619 | + sample_weights = np.random.random(Y4.shape[0]) |
| 620 | + multiplied_together = np.copy(sample_weights) |
| 621 | + multiplied_together[Y4 == 1] *= class_weights[1] |
| 622 | + multiplied_together[Y4 == 2] *= class_weights[2] |
| 623 | + |
| 624 | + clf1 = self.factory(alpha=0.1, n_iter=20, class_weight=class_weights) |
| 625 | + clf2 = self.factory(alpha=0.1, n_iter=20) |
| 626 | + |
| 627 | + clf1.fit(X4, Y4, sample_weight=sample_weights) |
| 628 | + clf2.fit(X4, Y4, sample_weight=multiplied_together) |
| 629 | + |
| 630 | + assert_array_equal(clf1.coef_, clf2.coef_) |
| 631 | + |
600 | 632 | def test_auto_weight(self):
|
601 | 633 | """Test class weights for imbalanced data"""
|
602 | 634 | # compute reference metrics on iris dataset that is quite balanced by
|
|
0 commit comments