8000 FIX : sparse SVC clone with callable kernel · scikit-learn/scikit-learn@89d73c7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 89d73c7

Browse files
committed
FIX : sparse SVC clone with callable kernel
1 parent e00044c commit 89d73c7

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

sklearn/svm/base.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,12 @@ def _sparse_fit(self, X, y, sample_weight=None):
270270
"boolean masks (use `indices=True` in CV)."
271271
% (sample_weight.shape, X.shape))
272272

273+
kernel = self.kernel
274+
if hasattr(kernel, '__call__'):
275+
kernel = 'precomputed'
276+
273277
solver_type = LIBSVM_IMPL.index(self.impl)
274-
kernel_type = self._sparse_kernels.index(self.kernel)
278+
kernel_type = self._sparse_kernels.index(kernel)
275279

276280
self.class_weight_, self.class_weight_label_ = \
277281
_get_class_weight(self.class_weight, y)
@@ -378,7 +382,12 @@ def _dense_predict(self, X):
378382

379383
def _sparse_predict(self, X):
380384
X = sp.csr_matrix(X, dtype=np.float64)
381-
kernel_type = self._sparse_kernels.index(self.kernel)
385+
386+
kernel = self.kernel
387+
if hasattr(kernel, '__call__'):
388+
kernel = 'precomputed'
389+
390+
kernel_type = self._sparse_kernels.index(kernel)
382391

383392
C = 0.0 # C is not useful here
384393

@@ -468,7 +477,12 @@ def _compute_kernel(self, X):
468477

469478
def _sparse_predict_proba(self, X):
470479
X.data = np.asarray(X.data, dtype=np.float64, order='C')
471-
kernel_type = self._sparse_kernels.index(self.kernel)
480+
481+
kernel = self.kernel
482+
if hasattr(kernel, '__call__'):
483+
kernel = 'precomputed'
484+
485+
kernel_type = self._sparse_kernels.index(kernel)
472486

473487
return libsvm_sparse.libsvm_sparse_predict_proba(
474488
X.data, X.indices, X.indptr,

sklearn/svm/tests/test_sparse.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
from scipy import linalg
33
from scipy import sparse
4-
from sklearn import datasets, svm, linear_model
4+
from sklearn import datasets, svm, linear_model, base
55
from numpy.testing import assert_array_almost_equal, \
66
assert_array_equal, assert_equal
77

@@ -229,6 +229,16 @@ def test_sparse_scale_C():
229229
assert_true(error_with_scale > 1e-3)
230230

231231

232+
def test_sparse_svc_clone_with_callable_kernel():
233+
a = svm.SVC(C=1, kernel=lambda x, y: x * y.T, probability=True)
234+
b = base.clone(a)
235+
236+
b.fit(X_sp, Y)
237+
b.predict(X_sp)
238+
b.predict_proba(X_sp)
239+
# b.decision_function(X_sp) # XXX : should be supported
240+
241+
232242
if __name__ == '__main__':
233243
import nose
234244
nose.runmodule()

sklearn/svm/tests/test_svm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,9 +661,10 @@ def test_linearsvc_verbose():
661661
os.dup2(stdout, 1) # restore original stdout
662662

663663

664-
def test_svc_pickle_with_callable_kernel():
664+
def test_svc_clone_with_callable_kernel():
665665
a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True)
666666
b = base.clone(a)
667+
667668
b.fit(X, Y)
668669
b.predict(X)
669670
b.predict_proba(X)

0 commit comments

Comments
 (0)
0