@@ -49,18 +49,22 @@ def check_svm_model_equal(dense_svm, sparse_svm, X_train, y_train, X_test):
49
49
assert_true (sparse .issparse (sparse_svm .dual_coef_ ))
50
50
assert_array_almost_equal (dense_svm .support_vectors_ ,
51
51
sparse_svm .support_vectors_ .toarray ())
52
- assert_array_almost_equal (dense_svm .dual_coef_ , sparse_svm .dual_coef_ .toarray ())
52
+ assert_array_almost_equal (dense_svm .dual_coef_ ,
53
+ sparse_svm .dual_coef_ .toarray ())
53
54
if dense_svm .kernel == "linear" :
54
55
assert_true (sparse .issparse (sparse_svm .coef_ ))
55
- assert_array_almost_equal (dense_svm .coef_ , sparse_svm .coef_ .toarray ())
56
+ assert_array_almost_equal (dense_svm .coef_ ,
57
+ sparse_svm .coef_ .toarray ())
56
58
assert_array_almost_equal (dense_svm .support_ , sparse_svm .support_ )
57
- assert_array_almost_equal (dense_svm .predict (X_test_dense ), sparse_svm .predict (X_test ))
59
+ assert_array_almost_equal (dense_svm .predict (X_test_dense ),
60
+ sparse_svm .predict (X_test ))
58
61
assert_array_almost_equal (dense_svm .decision_function (X_test_dense ),
59
62
sparse_svm .decision_function (X_test ))
60
63
assert_array_almost_equal (dense_svm .decision_function (X_test_dense ),
61
64
sparse_svm .decision_function (X_test_dense ))
62
- if isinstance (dense_svm , svm .OneClassSVM ):
63
- msg = "cannot use sparse input in 'OneClassSVM' trained on dense data"
65
+ if isinstance (dense_svm , (svm .OneClassSVM , svm .SVDD )):
66
+ msg = "cannot use sparse input in '%s' trained on dense data" \
67
+ % (dense_svm .__class__ .__name__ ,)
64
68
else :
65
69
assert_array_almost_equal (dense_svm .predict_proba (X_test_dense ),
66
70
sparse_svm .predict_proba (X_test ), 4 )
@@ -278,6 +282,24 @@ def test_sparse_oneclasssvm():
278
282
check_svm_model_equal (clf , sp_clf , * dataset )
279
283
280
284
285
+ def test_sparse_svdd ():
286
+ """Check that sparse SVDD gives the same result as dense SVDD
287
+ """
288
+ # many class dataset:
289
+ X_blobs , _ = make_blobs (n_samples = 100 , centers = 10 , random_state = 0 )
290
+ X_blobs = sparse .csr_matrix (X_blobs )
291
+
292
+ datasets = [[X_sp , None , T ], [X2_sp , None , T2 ],
293
+ [X_blobs [:80 ], None , X_blobs [80 :]],
294
+ [iris .data , None , iris .data ]]
295
+ kernels = ["linear" , "poly" , "rbf" , "sigmoid" ]
296
+ for dataset in datasets :
297
+ for kernel in kernels :
298
+ clf = svm .SVDD (kernel = kernel , random_state = 0 )
299
+ sp_clf = svm .SVDD (kernel = kernel , random_state = 0 )
300
+ check_svm_model_equal (clf , sp_clf , * dataset )
301
+
302
+
281
303
def test_sparse_realdata ():
282
304
# Test on a subset from the 20newsgroups dataset.
283
305
# This catches some bugs if input is not correctly converted into
0 commit comments