8000 [MRG+2] Fix SVC predict_proba fails with new-style kernel strings (#1… · scikit-learn/scikit-learn@74b69df · GitHub
[go: up one dir, main page]

Skip to content

Commit 74b69df

Browse files
qmickqinhanmin2014
authored andcommitted
[MRG+2] Fix SVC predict_proba fails with new-style kernel strings (#10412)
1 parent 2e85c86 commit 74b69df

File tree

3 files changed

+16
-27
lines changed

3 files changed

+16
-27
lines changed

sklearn/svm/base.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -228,15 +228,6 @@ def _dense_fit(self, X, y, sample_weight, solver_type, kernel,
228228

229229
libsvm.set_verbosity_wrap(self.verbose)
230230

231-
if six.PY2:
232-
# In python2 ensure kernel is ascii bytes to prevent a TypeError
233-
if isinstance(kernel, six.types.UnicodeType):
234-
kernel = str(kernel)
235-
if six.PY3:
236-
# In python3 ensure kernel is utf8 unicode to prevent a TypeError
237-
if isinstance(kernel, bytes):
238-
kernel = str(kernel, 'utf8')
239-
240231
# we don't pass **self.get_params() to allow subclasses to
241232
# add other parameters to __init__
242233
self.support_, self.support_vectors_, self.n_support_, \

sklearn/svm/libsvm.pyx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ LIBSVM_KERNEL_TYPES = ['linear', 'poly', 'rbf', 'sigmoid', 'precomputed']
5454
def fit(
5555
np.ndarray[np.float64_t, ndim=2, mode='c'] X,
5656
np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
57-
int svm_type=0, str kernel='rbf', int degree=3,
57+
int svm_type=0, kernel='rbf', int degree=3,
5858
double gamma=0.1, double coef0=0., double tol=1e-3,
5959
double C=1., double nu=0.5, double epsilon=0.1,
6060
np.ndarray[np.float64_t, ndim=1, mode='c']
@@ -342,7 +342,7 @@ def predict_proba(
342342
np.ndarray[np.float64_t, ndim=1, mode='c'] intercept,
343343
np.ndarray[np.float64_t, ndim=1, mode='c'] probA=np.empty(0),
344344
np.ndarray[np.float64_t, ndim=1, mode='c'] probB=np.empty(0),
345-
int svm_type=0, str kernel='rbf', int degree=3,
345+
int svm_type=0, kernel='rbf', int degree=3,
346346
double gamma=0.1, double coef0=0.,
347347
np.ndarray[np.float64_t, ndim=1, mode='c']
348348
class_weight=np.empty(0),
@@ -462,7 +462,7 @@ def decision_function(
462462
def cross_validation(
463463
np.ndarray[np.float64_t, ndim=2, mode='c'] X,
464464
np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
465-
int n_fold, svm_type=0, str kernel='rbf', int degree=3,
465+
int n_fold, svm_type=0, kernel='rbf', int degree=3,
466466
double gamma=0.1, double coef0=0., double tol=1e-3,
467467
double C=1., double nu=0.5, double epsilon=0.1,
468468
np.ndarray[np.float64_t, ndim=1, mode='c']

sklearn/svm/tests/test_svm.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -505,27 +505,25 @@ def test_bad_input():
505505

506506

507507
def test_unicode_kernel():
508-
# Test that a unicode kernel name does not cause a TypeError on clf.fit
508+
# Test that a unicode kernel name does not cause a TypeError
509509
if six.PY2:
510510
# Test unicode (same as str on python3)
511-
clf = svm.SVC(kernel=unicode('linear'))
512-
clf.fit(X, Y)
513-
514-
# Test ascii bytes (str is bytes in python2)
515-
clf = svm.SVC(kernel=str('linear'))
516-
clf.fit(X, Y)
517-
else:
518-
# Test unicode (str is unicode in python3)
519-
clf = svm.SVC(kernel=str('linear'))
520-
clf.fit(X, Y)
521-
522-
# Test ascii bytes (same as str on python2)
523-
clf = svm.SVC(kernel=bytes('linear', 'ascii'))
511+
clf = svm.SVC(kernel=u'linear', probability=True)
524512
clf.fit(X, Y)
513+
clf.predict_proba(T)
514+
svm.libsvm.cross_validation(iris.data,
515+
iris.target.astype(np.float64), 5,
516+
kernel=u'linear',
517+
random_seed=0)
525518

526519
# Test default behavior on both versions
527-
clf = svm.SVC(kernel='linear')
520+
clf = svm.SVC(kernel='linear', probability=True)
528521
clf.fit(X, Y)
522+
clf.predict_proba(T)
523+
svm.libsvm.cross_validation(iris.data,
524+
iris.target.astype(np.float64), 5,
525+
kernel='linear',
526+
random_seed=0)
529527

530528

531529
def test_sparse_precomputed():

0 commit comments

Comments
 (0)
0