|
23 | 23 | from sklearn.utils._testing import assert_raise_message
|
24 | 24 | from sklearn.utils._testing import ignore_warnings
|
25 | 25 | from sklearn.utils._testing import assert_no_warnings
|
| 26 | +from sklearn.utils.validation import _num_samples |
26 | 27 | from sklearn.utils import shuffle
|
27 | 28 | from sklearn.exceptions import ConvergenceWarning
|
28 | 29 | from sklearn.exceptions import NotFittedError, UndefinedMetricWarning
|
@@ -125,7 +126,7 @@ def test_precomputed():
|
125 | 126 |
|
126 | 127 | kfunc = lambda x, y: np.dot(x, y.T)
|
127 | 128 | clf = svm.SVC(kernel=kfunc)
|
128 |
| - clf.fit(X, Y) |
| 129 | + clf.fit(np.array(X), Y) |
129 | 130 | pred = clf.predict(T)
|
130 | 131 |
|
131 | 132 | assert_array_equal(clf.dual_coef_, [[-0.25, .25]])
|
@@ -980,7 +981,7 @@ def test_svc_bad_kernel():
|
def test_timeout():
    """A one-iteration fit with a callable kernel must warn about non-convergence."""
    # max_iter=1 guarantees the optimizer stops early; probability=True and a
    # fixed random_state keep the run deterministic.
    clf = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True,
                  random_state=0, max_iter=1)
    assert_warns(ConvergenceWarning, clf.fit, np.array(X), Y)
984 | 985 |
|
985 | 986 |
|
986 | 987 | def test_unfitted():
|
@@ -1250,30 +1251,33 @@ def test_svm_probA_proB_deprecated(SVMClass, data, deprecated_prob):
|
1250 | 1251 | getattr(clf, deprecated_prob)
|
1251 | 1252 |
|
1252 | 1253 |
|
1253 |
| -def test_callable_kernel(): |
1254 |
| - data = ["foo", "foof", "b", "a", "qwert", "1234567890", "abcde", "bar", "", "q"] |
1255 |
| - targets = [1, 1, 2, 2, 1, 3, 1, 1, 2, 2] |
1256 |
| - targets = np.array(targets) |
1257 |
| - |
1258 |
| - def string_kernel(X, X2): |
1259 |
| - assert isinstance(X[0], str) |
1260 |
| - len = _num_samples(X) |
1261 |
| - len2 = _num_samples(X2) |
1262 |
| - ret = np.zeros((len, len2)) |
1263 |
| - smaller = np.min(ret.shape) |
1264 |
| - ret[np.arange(smaller), np.arange(smaller)] = 1 |
1265 |
| - return ret |
1266 |
| - |
1267 |
| - svc = svm.SVC(kernel=string_kernel) |
1268 |
| - svc.fit(data, targets) |
1269 |
| - svc.score(data, targets) |
1270 |
| - svc.score(np.array(data), targets) |
1271 |
| - |
1272 |
| - svc.fit(np.array(data), targets) |
1273 |
| - svc.score(data, targets) |
1274 |
| - svc.score(np.array(data), targets) |
1275 |
| - |
1276 |
| - |
1277 |
| -def test_string_kernel(): |
1278 |
| - # meaningful string kernel test |
1279 |
| - assert True |
def test_custom_kernel_not_array_input():
    """Test using a custom kernel that is not fed with array-like for floats.

    The callable kernel must receive the raw (list-of-str) samples, and its
    Gram matrix must agree with both the 'linear' kernel on a numeric count
    encoding of the same strings and the 'precomputed' kernel fed that Gram
    matrix directly.
    """
    data = ["A A", "A", "B", "B B", "A B"]
    X = np.array([[2, 0], [1, 0], [0, 1], [0, 2], [1, 1]])  # count encoding
    y = np.array([1, 1, 2, 2, 1])

    def string_kernel(X1, X2):
        # SVC must forward the raw strings, not coerce them to floats.
        assert isinstance(X1[0], str)
        n_samples1 = _num_samples(X1)
        n_samples2 = _num_samples(X2)
        K = np.zeros((n_samples1, n_samples2))
        # Fill every (ii, jj) entry explicitly. The earlier symmetric
        # mirroring (jj starting at ii with K[jj, ii] = K[ii, jj]) assumed
        # n_samples1 == n_samples2: with fewer rows than columns it raised
        # IndexError, and with more rows it silently left trailing rows at
        # zero. SVC calls kernel(X_test, X_train), where the two sample
        # counts routinely differ, so the kernel must handle rectangles.
        for ii in range(n_samples1):
            for jj in range(n_samples2):
                K[ii, jj] = (X1[ii].count('A') * X2[jj].count('A')
                             + X1[ii].count('B') * X2[jj].count('B'))
        return K

    # The string kernel on the raw data must reproduce the linear kernel
    # Gram matrix of the count encoding.
    K = string_kernel(data, data)
    assert_array_equal(np.dot(X, X.T), K)

    svc1 = svm.SVC(kernel=string_kernel).fit(data, y)
    svc2 = svm.SVC(kernel='linear').fit(X, y)
    svc3 = svm.SVC(kernel='precomputed').fit(K, y)

    # All three formulations define the same decision surface.
    assert svc1.score(data, y) == svc3.score(K, y)
    assert svc1.score(data, y) == svc2.score(X, y)
    assert_array_almost_equal(svc1.decision_function(data),
                              svc2.decision_function(X))
    assert_array_equal(svc1.predict(data), svc2.predict(X))
0 commit comments