8000 TST added sparse-dense test for the SVDD · scikit-learn/scikit-learn@2bf41cb · GitHub
[go: up one dir, main page]

Skip to content

Commit 2bf41cb

Browse files
committed
TST added sparse-dense test for the SVDD
1 parent efc7ba5 commit 2bf41cb

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

sklearn/svm/tests/test_sparse.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,22 @@ def check_svm_model_equal(dense_svm, sparse_svm, X_train, y_train, X_test):
4949
assert_true(sparse.issparse(sparse_svm.dual_coef_))
5050
assert_array_almost_equal(dense_svm.support_vectors_,
5151
sparse_svm.support_vectors_.toarray())
52-
assert_array_almost_equal(dense_svm.dual_coef_, sparse_svm.dual_coef_.toarray())
52+
assert_array_almost_equal(dense_svm.dual_coef_,
53+
sparse_svm.dual_coef_.toarray())
5354
if dense_svm.kernel == "linear":
5455
assert_true(sparse.issparse(sparse_svm.coef_))
55-
assert_array_almost_equal(dense_svm.coef_, sparse_svm.coef_.toarray())
56+
assert_array_almost_equal(dense_svm.coef_,
57+
sparse_svm.coef_.toarray())
5658
assert_array_almost_equal(dense_svm.support_, sparse_svm.support_)
57-
assert_array_almost_equal(dense_svm.predict(X_test_dense), sparse_svm.predict(X_test))
59+
assert_array_almost_equal(dense_svm.predict(X_test_dense),
60+
sparse_svm.predict(X_test))
5861
assert_array_almost_equal(dense_svm.decision_function(X_test_dense),
5962
sparse_svm.decision_function(X_test))
6063
assert_array_almost_equal(dense_svm.decision_function(X_test_dense),
6164
sparse_svm.decision_function(X_test_dense))
62-
if isinstance(dense_svm, svm.OneClassSVM):
63-
msg = "cannot use sparse input in 'OneClassSVM' trained on dense data"
65+
if isinstance(dense_svm, (svm.OneClassSVM, svm.SVDD)):
66+
msg = "cannot use sparse input in '%s' trained on dense data" \
67+
% (dense_svm.__class__.__name__,)
6468
else:
6569
assert_array_almost_equal(dense_svm.predict_proba(X_test_dense),
6670
sparse_svm.predict_proba(X_test), 4)
@@ -278,6 +282,24 @@ def test_sparse_oneclasssvm():
278282
check_svm_model_equal(clf, sp_clf, *dataset)
279283

280284

285+
def test_sparse_svdd():
286+
"""Check that sparse SVDD gives the same result as dense SVDD
287+
"""
288+
# many class dataset:
289+
X_blobs, _ = make_blobs(n_samples=100, centers=10, random_state=0)
290+
X_blobs = sparse.csr_matrix(X_blobs)
291+
292+
datasets = [[X_sp, None, T], [X2_sp, None, T2],
293+
[X_blobs[:80], None, X_blobs[80:]],
294+
[iris.data, None, iris.data]]
295+
kernels = ["linear", "poly", "rbf", "sigmoid"]
296+
for dataset in datasets:
297+
for kernel in kernels:
298+
clf = svm.SVDD(kernel=kernel, random_state=0)
299+
sp_clf = svm.SVDD(kernel=kernel, random_state=0)
300+
check_svm_model_equal(clf, sp_clf, *dataset)
301+
302+
281303
def test_sparse_realdata():
282304
# Test on a subset from the 20newsgroups dataset.
283305
# This catches some bugs if input is not correctly converted into

0 commit comments

Comments
 (0)
0