@@ -655,6 +655,46 @@ def test_linearsvc_crammer_singer():
655
655
assert_array_almost_equal (dec_func , cs_clf .decision_function (iris .data ))
656
656
657
657
658
+ def test_linearsvc_fit_sampleweight ():
659
+ # check correct result when sample_weight is 1
660
+ # check that SVR(kernel='linear') and LinearSVC() give
661
+ # comparable results
662
+
663
+ # Test basic routines using LinearSVC
664
+ n_samples = len (X )
665
+ unit_weight = np .ones (n_samples )
666
+ clf = svm .LinearSVC (random_state = 0 ).fit (X , Y )
667
+ clf_unitweight = svm .LinearSVC (random_state = 0 ).fit (X , Y ,
668
+ sample_weight = unit_weight )
669
+
670
+ # sanity check, by default should have intercept
671
+ assert_true (clf_unitweight .fit_intercept )
672
+ assert_array_almost_equal (clf_unitweight .intercept_ , [0 ], decimal = 3 )
673
+
674
+ # check if same as sample_weight=None
675
+ assert_array_equal (clf_unitweight .predict (T ), clf .predict (T ))
676
+ assert_allclose (np .linalg .norm (clf .coef_ ),
677
+ np .linalg .norm (clf_unitweight .coef_ ), 1 , 0.0001 )
678
+
679
+ # check that fit(X) = fit([X1, X2, X3],sample_weight = [n1, n2, n3]) where
680
+ # X = X1 repeated n1 times, X2 repeated n2 times and so forth
681
+
682
+ random_state = check_random_state (0 )
683
+ random_weight = random_state .randint (0 , 10 , n_samples )
684
+ lsvc_unflat = svm .LinearSVC (random_state = 0 ).fit (X , Y ,
685
+ sample_weight = random_weight )
686
+ pred1 = lsvc_unflat .predict (T )
687
+
688
+ X_flat = np .repeat (X , random_weight , axis = 0 )
689
+ y_flat = np .repeat (Y , random_weight , axis = 0 )
690
+ lsvc_flat = svm .LinearSVC (random_state = 0 ).fit (X_flat , y_flat )
691
+ pred2 = lsvc_flat .predict (T )
692
+
693
+ assert_array_equal (pred1 , pred2 )
694
+ assert_allclose (np .linalg .norm (lsvc_unflat .coef_ ),
695
+ np .linalg .norm (lsvc_flat .coef_ ), 1 , 0.0001 )
696
+
697
+
658
698
def test_crammer_singer_binary ():
659
699
# Test Crammer-Singer formulation in the binary case
660
700
X , y = make_classification (n_classes = 2 , random_state = 0 )
0 commit comments