FIX Arbitrary SVC kernels (#11296) · scikit-learn/scikit-learn@84bc8d3 · GitHub

Commit 84bc8d3

FIX Arbitrary SVC kernels (#11296)
1 parent 406184e · commit 84bc8d3

3 files changed (+90, -33 lines)

doc/whats_new/v0.23.rst

Lines changed: 7 additions & 1 deletion

@@ -23,7 +23,7 @@ parameters, may produce different models from the previous version. This often
 occurs due to changes in the modelling logic (bug fixes or enhancements), or in
 random sampling procedures.
 
-- models come here
+- list models here
 
 Details are listed in the changelog below.
 
@@ -210,6 +210,12 @@ Changelog
   `probB_`, are now deprecated as they were not useful. :pr:`15558` by
   `Thomas Fan`_.
 
+- |Fix| Fix use of custom kernels that do not take float entries, such as
+  string kernels, in :class:`svm.SVC` and :class:`svm.SVR`. Note that custom
+  kernels are now expected to validate their input, where they previously
+  received valid numeric arrays.
+  :pr:`11296` by `Alexandre Gramfort`_ and :user:`Georgi Peev <georgipeev>`.
+
 :mod:`sklearn.tree`
 ...................
 

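For context, a minimal sketch (not part of this commit) of what the changelog entry above enables: fitting an SVC directly on a list of strings through a callable kernel. The `string_kernel` below is a condensed, illustrative version of the kernel in the test added by this commit.

    import numpy as np
    from sklearn import svm

    # Samples are raw strings rather than numeric feature vectors.
    data = ["A A", "A", "B", "B B", "A B"]
    y = np.array([1, 1, 2, 2, 1])

    def string_kernel(X1, X2):
        # Gram matrix over (count of 'A', count of 'B'): equivalent to a
        # linear kernel on the implicit count encoding of each string.
        K = np.zeros((len(X1), len(X2)))
        for i, s1 in enumerate(X1):
            for j, s2 in enumerate(X2):
                K[i, j] = (s1.count('A') * s2.count('A')
                           + s1.count('B') * s2.count('B'))
        return K

    # Before this fix, fit() coerced X to a float array and raised on
    # strings; callable kernels now receive X unchanged.
    clf = svm.SVC(kernel=string_kernel).fit(data, y)
    print(clf.predict(data))  # labels predicted for the training strings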
sklearn/svm/_base.py

Lines changed: 30 additions & 21 deletions

@@ -14,7 +14,8 @@
 from ..utils import compute_class_weight
 from ..utils.extmath import safe_sparse_dot
 from ..utils.validation import check_is_fitted, _check_large_sparse
-from ..utils.validation import _check_sample_weight
+from ..utils.validation import _num_samples
+from ..utils.validation import _check_sample_weight, check_consistent_length
 from ..utils.multiclass import check_classification_targets
 from ..exceptions import ConvergenceWarning
 from ..exceptions import NotFittedError
@@ -143,9 +144,13 @@ def fit(self, X, y, sample_weight=None):
             raise TypeError("Sparse precomputed kernels are not supported.")
         self._sparse = sparse and not callable(self.kernel)
 
-        X, y = check_X_y(X, y, dtype=np.float64,
-                         order='C', accept_sparse='csr',
-                         accept_large_sparse=False)
+        if callable(self.kernel):
+            check_consistent_length(X, y)
+        else:
+            X, y = check_X_y(X, y, dtype=np.float64,
+                             order='C', accept_sparse='csr',
+                             accept_large_sparse=False)
+
         y = self._validate_targets(y)
 
         sample_weight = np.asarray([]
@@ -154,24 +159,31 @@ def fit(self, X, y, sample_weight=None):
         solver_type = LIBSVM_IMPL.index(self._impl)
 
         # input validation
-        if solver_type != 2 and X.shape[0] != y.shape[0]:
+        n_samples = _num_samples(X)
+        if solver_type != 2 and n_samples != y.shape[0]:
             raise ValueError("X and y have incompatible shapes.\n" +
                              "X has %s samples, but y has %s." %
-                             (X.shape[0], y.shape[0]))
+                             (n_samples, y.shape[0]))
 
-        if self.kernel == "precomputed" and X.shape[0] != X.shape[1]:
+        if self.kernel == "precomputed" and n_samples != X.shape[1]:
             raise ValueError("Precomputed matrix must be a square matrix."
                              " Input is a {}x{} matrix."
                              .format(X.shape[0], X.shape[1]))
 
-        if sample_weight.shape[0] > 0 and sample_weight.shape[0] != X.shape[0]:
+        if sample_weight.shape[0] > 0 and sample_weight.shape[0] != n_samples:
             raise ValueError("sample_weight and X have incompatible shapes: "
                              "%r vs %r\n"
                              "Note: Sparse matrices cannot be indexed w/"
                              "boolean masks (use `indices=True` in CV)."
                              % (sample_weight.shape, X.shape))
 
-        if isinstance(self.gamma, str):
+        kernel = 'precomputed' if callable(self.kernel) else self.kernel
+
+        if kernel == 'precomputed':
+            # unused but needs to be a float for cython code that ignores
+            # it anyway
+            self._gamma = 0.
+        elif isinstance(self.gamma, str):
             if self.gamma == 'scale':
                 # var = E[X^2] - E[X]^2 if sparse
                 X_var = ((X.multiply(X)).mean() - (X.mean()) ** 2
@@ -187,10 +199,6 @@ def fit(self, X, y, sample_weight=None):
         else:
             self._gamma = self.gamma
 
-        kernel = self.kernel
-        if callable(kernel):
-            kernel = 'precomputed'
-
         fit = self._sparse_fit if self._sparse else self._dense_fit
         if self.verbose:  # pragma: no cover
             print('[LibSVM]', end='')
@@ -199,7 +207,7 @@ def fit(self, X, y, sample_weight=None):
         fit(X, y, sample_weight, solver_type, kernel, random_seed=seed)
         # see comment on the other call to np.iinfo in this file
 
-        self.shape_fit_ = X.shape
+        self.shape_fit_ = X.shape if hasattr(X, "shape") else (n_samples, )
 
         # In binary case, we need to flip the sign of coef, intercept and
         # decision function. Use self._intercept_ and self._dual_coef_
@@ -443,8 +451,10 @@ def _sparse_decision_function(self, X):
     def _validate_for_predict(self, X):
         check_is_fitted(self)
 
-        X = check_array(X, accept_sparse='csr', dtype=np.float64, order="C",
-                        accept_large_sparse=False)
+        if not callable(self.kernel):
+            X = check_array(X, accept_sparse='csr', dtype=np.float64,
+                            order="C", accept_large_sparse=False)
+
         if self._sparse and not sp.isspmatrix(X):
             X = sp.csr_matrix(X)
         if self._sparse:
@@ -454,17 +464,16 @@ def _validate_for_predict(self, X):
             raise ValueError(
                 "cannot use sparse input in %r trained on dense data"
                 % type(self).__name__)
-        n_samples, n_features = X.shape
 
         if self.kernel == "precomputed":
            if X.shape[1] != self.shape_fit_[0]:
                 raise ValueError("X.shape[1] = %d should be equal to %d, "
                                  "the number of samples at training time" %
                                  (X.shape[1], self.shape_fit_[0]))
-        elif n_features != self.shape_fit_[1]:
+        elif not callable(self.kernel) and X.shape[1] != self.shape_fit_[1]:
             raise ValueError("X.shape[1] = %d should be equal to %d, "
                              "the number of features at training time" %
-                             (n_features, self.shape_fit_[1]))
+                             (X.shape[1], self.shape_fit_[1]))
         return X
 
     @property
@@ -920,8 +929,8 @@ def _fit_liblinear(X, y, C, fit_intercept, intercept_scaling, class_weight,
     bias = -1.0
     if fit_intercept:
         if intercept_scaling <= 0:
-            raise ValueError("Intercept scaling is %r but needs to be greater than 0."
-                             " To disable fitting an intercept,"
+            raise ValueError("Intercept scaling is %r but needs to be greater "
+                             "than 0. To disable fitting an intercept,"
                              " set fit_intercept=False." % intercept_scaling)
     else:
         bias = intercept_scaling

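The switch from X.shape[0] to _num_samples in the fit() validation above is what lets the estimator accept inputs without a .shape attribute. A small illustration, assuming only that the private helper _num_samples behaves as it does where this patch imports it:

    import numpy as np
    from sklearn.utils.validation import _num_samples

    # For plain Python sequences, _num_samples falls back to len(); with a
    # callable kernel, X may be such a sequence (e.g. a list of strings).
    print(_num_samples(["A A", "B", "A B"]))  # 3
    # For arrays it uses the first dimension of .shape, as X.shape[0] did.
    print(_num_samples(np.zeros((5, 2))))     # 5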
sklearn/svm/tests/test_svm.py

Lines changed: 53 additions & 11 deletions

@@ -20,9 +20,10 @@
 from sklearn.metrics.pairwise import rbf_kernel
 from sklearn.utils import check_random_state
 from sklearn.utils._testing import assert_warns
-from sklearn.utils._testing import assert_warns_message, assert_raise_message
+from sklearn.utils._testing import assert_raise_message
 from sklearn.utils._testing import ignore_warnings
 from sklearn.utils._testing import assert_no_warnings
+from sklearn.utils.validation import _num_samples
 from sklearn.utils import shuffle
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.exceptions import NotFittedError, UndefinedMetricWarning
@@ -125,7 +126,7 @@ def test_precomputed():
 
     kfunc = lambda x, y: np.dot(x, y.T)
     clf = svm.SVC(kernel=kfunc)
-    clf.fit(X, Y)
+    clf.fit(np.array(X), Y)
     pred = clf.predict(T)
 
     assert_array_equal(clf.dual_coef_, [[-0.25, .25]])
@@ -542,8 +543,8 @@ def test_negative_weights_svc_leave_just_one_label(Classifier,
 
 @pytest.mark.parametrize(
     "Classifier, model",
-    [(svm.SVC, {'when-left': [0.3998, 0.4], 'when-right': [0.4, 0.3999]}),
-    (svm.NuSVC, {'when-left': [0.3333, 0.3333],
+    [(svm.SVC, {'when-left': [0.3998, 0.4], 'when-right': [0.4, 0.3999]}),
+     (svm.NuSVC, {'when-left': [0.3333, 0.3333],
                  'when-right': [0.3333, 0.3333]})],
     ids=['SVC', 'NuSVC']
 )
@@ -681,9 +682,9 @@ def test_unicode_kernel():
     clf.fit(X, Y)
     clf.predict_proba(T)
     _libsvm.cross_validation(iris.data,
-                            iris.target.astype(np.float64), 5,
-                            kernel='linear',
-                            random_seed=0)
+                             iris.target.astype(np.float64), 5,
+                             kernel='linear',
+                             random_seed=0)
 
 
 def test_sparse_precomputed():
@@ -980,7 +981,7 @@ def test_svc_bad_kernel():
 def test_timeout():
     a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True,
                 random_state=0, max_iter=1)
-    assert_warns(ConvergenceWarning, a.fit, X, Y)
+    assert_warns(ConvergenceWarning, a.fit, np.array(X), Y)
 
 
 def test_unfitted():
@@ -1026,8 +1027,9 @@ def test_svr_coef_sign():
     for svr in [svm.SVR(kernel='linear'), svm.NuSVR(kernel='linear'),
                 svm.LinearSVR()]:
         svr.fit(X, y)
-        assert_array_almost_equal(svr.predict(X),
-                                  np.dot(X, svr.coef_.ravel()) + svr.intercept_)
+        assert_array_almost_equal(
+            svr.predict(X), np.dot(X, svr.coef_.ravel()) + svr.intercept_
+        )
 
 
 def test_linear_svc_intercept_scaling():
@@ -1094,7 +1096,7 @@ def test_ovr_decision_function():
         base_points * [-1, 1],  # Q2
         base_points * [-1, -1],  # Q3
         base_points * [1, -1]  # Q4
-        ))
+    ))
 
     y_test = [0] * 2 + [1] * 2 + [2] * 2 + [3] * 2
 
@@ -1248,3 +1250,43 @@ def test_svm_probA_proB_deprecated(SVMClass, data, deprecated_prob):
                     "removed in version 0.25.").format(deprecated_prob)
     with pytest.warns(FutureWarning, match=msg):
         getattr(clf, deprecated_prob)
+
+
+@pytest.mark.parametrize("Estimator", [svm.SVC, svm.SVR])
+def test_custom_kernel_not_array_input(Estimator):
+    """Test using a custom kernel that is not fed with array-like for floats"""
+    data = ["A A", "A", "B", "B B", "A B"]
+    X = np.array([[2, 0], [1, 0], [0, 1], [0, 2], [1, 1]])  # count encoding
+    y = np.array([1, 1, 2, 2, 1])
+
+    def string_kernel(X1, X2):
+        assert isinstance(X1[0], str)
+        n_samples1 = _num_samples(X1)
+        n_samples2 = _num_samples(X2)
+        K = np.zeros((n_samples1, n_samples2))
+        for ii in range(n_samples1):
+            for jj in range(ii, n_samples2):
+                K[ii, jj] = X1[ii].count('A') * X2[jj].count('A')
+                K[ii, jj] += X1[ii].count('B') * X2[jj].count('B')
+                K[jj, ii] = K[ii, jj]
+        return K
+
+    K = string_kernel(data, data)
+    assert_array_equal(np.dot(X, X.T), K)
+
+    svc1 = Estimator(kernel=string_kernel).fit(data, y)
+    svc2 = Estimator(kernel='linear').fit(X, y)
+    svc3 = Estimator(kernel='precomputed').fit(K, y)
+
+    assert svc1.score(data, y) == svc3.score(K, y)
+    assert svc1.score(data, y) == svc2.score(X, y)
+    if hasattr(svc1, 'decision_function'):  # classifier
+        assert_allclose(svc1.decision_function(data),
+                        svc2.decision_function(X))
+        assert_allclose(svc1.decision_function(data),
+                        svc3.decision_function(K))
+        assert_array_equal(svc1.predict(data), svc2.predict(X))
+        assert_array_equal(svc1.predict(data), svc3.predict(K))
+    else:  # regressor
+        assert_allclose(svc1.predict(data), svc2.predict(X))
+        assert_allclose(svc1.predict(data), svc3.predict(K))

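One practical consequence of this commit, as the changelog entry notes: the estimator no longer coerces or validates X on behalf of a callable kernel. A minimal sketch (not part of the commit; the name validated_rbf is illustrative) of how a custom kernel that still wants the old numeric validation can opt in explicitly:

    import numpy as np
    from sklearn.utils import check_array
    from sklearn.metrics.pairwise import rbf_kernel

    def validated_rbf(X1, X2):
        # Re-apply the coercion that fit()/predict() performed for callable
        # kernels before this commit.
        X1 = check_array(X1, dtype=np.float64)
        X2 = check_array(X2, dtype=np.float64)
        return rbf_kernel(X1, X2, gamma=0.1)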