fix/improve tests · scikit-learn/scikit-learn@f324b1f · GitHub

Commit f324b1f

fix/improve tests
1 parent e009342 commit f324b1f

2 files changed: +40 -35 lines changed


sklearn/svm/_base.py

Lines changed: 7 additions & 6 deletions
@@ -149,7 +149,6 @@ def fit(self, X, y, sample_weight=None):
         X, y = check_X_y(X, y, dtype=np.float64,
                          order='C', accept_sparse='csr',
                          accept_large_sparse=False)
-        X, y = check_X_y(X, y, dtype=np.float64, order='C', accept_sparse='csr')
 
         y = self._validate_targets(y)
 
@@ -177,7 +176,13 @@ def fit(self, X, y, sample_weight=None):
                              "boolean masks (use `indices=True` in CV)."
                              % (sample_weight.shape, X.shape))
 
-        if isinstance(self.gamma, str):
+        kernel = self.kernel
+        if callable(kernel):
+            kernel = 'precomputed'
+
+        if kernel == 'precomputed':
+            self._gamma = 0.  # unused but needs to be a float
+        elif isinstance(self.gamma, str):
             if self.gamma == 'scale':
                 # var = E[X^2] - E[X]^2 if sparse
                 X_var = ((X.multiply(X)).mean() - (X.mean()) ** 2
@@ -193,10 +198,6 @@ def fit(self, X, y, sample_weight=None):
             else:
                 self._gamma = self.gamma
 
-        kernel = self.kernel
-        if callable(kernel):
-            kernel = 'precomputed'
-
         fit = self._sparse_fit if self._sparse else self._dense_fit
         if self.verbose:  # pragma: no cover
             print('[LibSVM]', end='')
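
Note (not part of the commit): with this reordering, a callable kernel is routed through the 'precomputed' branch before any gamma handling, so self._gamma is simply set to a dummy 0. and the gamma='scale' computation (which needs X.var() or X.multiply(X)) is never reached. A minimal usage sketch with assumed toy data, showing the public behaviour this supports:

# Sketch with assumed toy data: a callable kernel on ordinary numeric input.
import numpy as np
from sklearn import svm

X = np.array([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
y = np.array([0, 0, 1, 1])

# The callable is treated like a precomputed kernel internally, so the
# gamma='scale' variance computation is skipped entirely.
clf = svm.SVC(kernel=lambda a, b: np.dot(a, b.T))
clf.fit(X, y)
print(clf.predict(X))  # expected on this separable toy data: [0 0 1 1]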

sklearn/svm/tests/test_svm.py

Lines changed: 33 additions & 29 deletions
@@ -23,6 +23,7 @@
 from sklearn.utils._testing import assert_raise_message
 from sklearn.utils._testing import ignore_warnings
 from sklearn.utils._testing import assert_no_warnings
+from sklearn.utils.validation import _num_samples
 from sklearn.utils import shuffle
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.exceptions import NotFittedError, UndefinedMetricWarning
@@ -125,7 +126,7 @@ def test_precomputed():
 
     kfunc = lambda x, y: np.dot(x, y.T)
     clf = svm.SVC(kernel=kfunc)
-    clf.fit(X, Y)
+    clf.fit(np.array(X), Y)
     pred = clf.predict(T)
 
     assert_array_equal(clf.dual_coef_, [[-0.25, .25]])
@@ -980,7 +981,7 @@ def test_svc_bad_kernel():
 def test_timeout():
     a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True,
                 random_state=0, max_iter=1)
-    assert_warns(ConvergenceWarning, a.fit, X, Y)
+    assert_warns(ConvergenceWarning, a.fit, np.array(X), Y)
 
 
 def test_unfitted():
@@ -1250,30 +1251,33 @@ def test_svm_probA_proB_deprecated(SVMClass, data, deprecated_prob):
         getattr(clf, deprecated_prob)
 
 
-def test_callable_kernel():
-    data = ["foo", "foof", "b", "a", "qwert", "1234567890", "abcde", "bar", "", "q"]
-    targets = [1, 1, 2, 2, 1, 3, 1, 1, 2, 2]
-    targets = np.array(targets)
-
-    def string_kernel(X, X2):
-        assert isinstance(X[0], str)
-        len = _num_samples(X)
-        len2 = _num_samples(X2)
-        ret = np.zeros((len, len2))
-        smaller = np.min(ret.shape)
-        ret[np.arange(smaller), np.arange(smaller)] = 1
-        return ret
-
-    svc = svm.SVC(kernel=string_kernel)
-    svc.fit(data, targets)
-    svc.score(data, targets)
-    svc.score(np.array(data), targets)
-
-    svc.fit(np.array(data), targets)
-    svc.score(data, targets)
-    svc.score(np.array(data), targets)
-
-
-def test_string_kernel():
-    # meaningful string kernel test
-    assert True
+def test_custom_kernel_not_array_input():
+    """Test using a custom kernel that is not fed with array-like for floats"""
+    data = ["A A", "A", "B", "B B", "A B"]
+    X = np.array([[2, 0], [1, 0], [0, 1], [0, 2], [1, 1]])  # count encoding
+    y = np.array([1, 1, 2, 2, 1])
+
+    def string_kernel(X1, X2):
+        assert isinstance(X1[0], str)
+        n_samples1 = _num_samples(X1)
+        n_samples2 = _num_samples(X2)
+        K = np.zeros((n_samples1, n_samples2))
+        for ii in range(n_samples1):
+            for jj in range(ii, n_samples2):
+                K[ii, jj] = X1[ii].count('A') * X2[jj].count('A')
+                K[ii, jj] += X1[ii].count('B') * X2[jj].count('B')
+                K[jj, ii] = K[ii, jj]
+        return K
+
+    K = string_kernel(data, data)
+    assert_array_equal(np.dot(X, X.T), K)
+
+    svc1 = svm.SVC(kernel=string_kernel).fit(data, y)
+    svc2 = svm.SVC(kernel='linear').fit(X, y)
+    svc3 = svm.SVC(kernel='precomputed').fit(K, y)
+
+    assert svc1.score(data, y) == svc3.score(K, y)
+    assert svc1.score(data, y) == svc2.score(X, y)
+    assert_array_almost_equal(svc1.decision_function(data),
+                              svc2.decision_function(X))
+    assert_array_equal(svc1.predict(data), svc2.predict(X))
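
Aside (an assumed sketch, not part of the diff): the new test leans on a callable kernel behaving exactly like fitting on the explicitly precomputed Gram matrix. That equivalence can be checked in isolation with a plain linear kernel on made-up data:

# Sketch with assumed random data: callable kernel vs. explicit Gram matrix.
import numpy as np
from sklearn import svm

rng = np.random.RandomState(0)
X = rng.rand(6, 3)
y = np.array([0, 0, 1, 1, 0, 1])

linear = lambda a, b: np.dot(a, b.T)

svc_callable = svm.SVC(kernel=linear).fit(X, y)
svc_precomputed = svm.SVC(kernel='precomputed').fit(linear(X, X), y)

# For prediction with a precomputed kernel, pass the kernel between the
# test samples and the training samples.
assert np.array_equal(svc_callable.predict(X),
                      svc_precomputed.predict(linear(X, X)))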

0 commit comments
