 from sklearn.metrics.pairwise import rbf_kernel
 from sklearn.utils import check_random_state
 from sklearn.utils._testing import assert_warns
-from sklearn.utils._testing import assert_warns_message, assert_raise_message
+from sklearn.utils._testing import assert_raise_message
 from sklearn.utils._testing import ignore_warnings
 from sklearn.utils._testing import assert_no_warnings
+from sklearn.utils.validation import _num_samples
 from sklearn.utils import shuffle
 from sklearn.exceptions import ConvergenceWarning
 from sklearn.exceptions import NotFittedError, UndefinedMetricWarning
@@ -125,7 +126,7 @@ def test_precomputed():

     kfunc = lambda x, y: np.dot(x, y.T)
     clf = svm.SVC(kernel=kfunc)
-    clf.fit(X, Y)
+    clf.fit(np.array(X), Y)
     pred = clf.predict(T)

     assert_array_equal(clf.dual_coef_, [[-0.25, .25]])
@@ -542,8 +543,8 @@ def test_negative_weights_svc_leave_just_one_label(Classifier,

 @pytest.mark.parametrize(
     "Classifier, model",
-    [(svm.SVC, {'when-left': [0.3998, 0.4], 'when-right': [0.4, 0.3999]}),
-     (svm.NuSVC, {'when-left': [0.3333, 0.3333],
+    [(svm.SVC, {'when-left': [0.3998, 0.4], 'when-right': [0.4, 0.3999]}),
+     (svm.NuSVC, {'when-left': [0.3333, 0.3333],
                   'when-right': [0.3333, 0.3333]})],
     ids=['SVC', 'NuSVC']
 )
@@ -681,9 +682,9 @@ def test_unicode_kernel():
     clf.fit(X, Y)
     clf.predict_proba(T)
     _libsvm.cross_validation(iris.data,
-                             iris.target.astype(np.float64), 5,
-                             kernel='linear',
-                             random_seed=0)
+                             iris.target.astype(np.float64), 5,
+                             kernel='linear',
+                             random_seed=0)


 def test_sparse_precomputed():
@@ -980,7 +981,7 @@ def test_svc_bad_kernel():
 def test_timeout():
     a = svm.SVC(kernel=lambda x, y: np.dot(x, y.T), probability=True,
                 random_state=0, max_iter=1)
-    assert_warns(ConvergenceWarning, a.fit, X, Y)
+    assert_warns(ConvergenceWarning, a.fit, np.array(X), Y)


 def test_unfitted():
@@ -1026,8 +1027,9 @@ def test_svr_coef_sign():
     for svr in [svm.SVR(kernel='linear'), svm.NuSVR(kernel='linear'),
                 svm.LinearSVR()]:
         svr.fit(X, y)
-        assert_array_almost_equal(svr.predict(X),
-                                  np.dot(X, svr.coef_.ravel()) + svr.intercept_)
+        assert_array_almost_equal(
+            svr.predict(X), np.dot(X, svr.coef_.ravel()) + svr.intercept_
+        )


 def test_linear_svc_intercept_scaling():
@@ -1094,7 +1096,7 @@ def test_ovr_decision_function():
         base_points * [-1, 1],   # Q2
         base_points * [-1, -1],  # Q3
         base_points * [1, -1]    # Q4
-    ))
+    ))
     y_test = [0] * 2 + [1] * 2 + [2] * 2 + [3] * 2

@@ -1248,3 +1250,43 @@ def test_svm_probA_proB_deprecated(SVMClass, data, deprecated_prob):
|
1248 | 1250 | "removed in version 0.25.").format(deprecated_prob)
|
1249 | 1251 | with pytest.warns(FutureWarning, match=msg):
|
1250 | 1252 | getattr(clf, deprecated_prob)
|
| 1253 | + |
| 1254 | + |
| 1255 | +@pytest.mark.parametrize("Estimator", [svm.SVC, svm.SVR]) |
| 1256 | +def test_custom_kernel_not_array_input(Estimator): |
| 1257 | + """Test using a custom kernel that is not fed with array-like for floats""" |
| 1258 | + data = ["A A", "A", "B", "B B", "A B"] |
| 1259 | + X = np.array([[2, 0], [1, 0], [0, 1], [0, 2], [1, 1]]) # count encoding |
| 1260 | + y = np.array([1, 1, 2, 2, 1]) |
| 1261 | + |
| 1262 | + def string_kernel(X1, X2): |
| 1263 | + assert isinstance(X1[0], str) |
| 1264 | + n_samples1 = _num_samples(X1) |
| 1265 | + n_samples2 = _num_samples(X2) |
| 1266 | + K = np.zeros((n_samples1, n_samples2)) |
| 1267 | + for ii in range(n_samples1): |
| 1268 | + for jj in range(ii, n_samples2): |
| 1269 | + K[ii, jj] = X1[ii].count('A') * X2[jj].count('A') |
| 1270 | + K[ii, jj] += X1[ii].count('B') * X2[jj].count('B') |
| 1271 | + K[jj, ii] = K[ii, jj] |
| 1272 | + return K |
| 1273 | + |
| 1274 | + K = string_kernel(data, data) |
| 1275 | + assert_array_equal(np.dot(X, X.T), K) |
| 1276 | + |
| 1277 | + svc1 = Estimator(kernel=string_kernel).fit(data, y) |
| 1278 | + svc2 = Estimator(kernel='linear').fit(X, y) |
| 1279 | + svc3 = Estimator(kernel='precomputed').fit(K, y) |
| 1280 | + |
| 1281 | + assert svc1.score(data, y) == svc3.score(K, y) |
| 1282 | + assert svc1.score(data, y) == svc2.score(X, y) |
| 1283 | + if hasattr(svc1, 'decision_function'): # classifier |
| 1284 | + assert_allclose(svc1.decision_function(data), |
| 1285 | + svc2.decision_function(X)) |
| 1286 | + assert_allclose(svc1.decision_function(data), |
| 1287 | + svc3.decision_function(K)) |
| 1288 | + assert_array_equal(svc1.predict(data), svc2.predict(X)) |
| 1289 | + assert_array_equal(svc1.predict(data), svc3.predict(K)) |
| 1290 | + else: # regressor |
| 1291 | + assert_allclose(svc1.predict(data), svc2.predict(X)) |
| 1292 | + assert_allclose(svc1.predict(data), svc3.predict(K)) |
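
For reference, the usage pattern this new test locks in: with a callable kernel, the SVM estimators now accept raw, non-array inputs such as a list of strings, deferring all featurization to the kernel. A minimal sketch against this branch; the toy documents, labels, and the count_kernel helper below are hypothetical illustrations, not part of the PR:

import numpy as np
from sklearn import svm

docs = ["A A", "B", "A B"]    # hypothetical toy documents
labels = np.array([1, 2, 1])

def count_kernel(U, V):
    # Gram matrix of letter-count dot products between two string collections;
    # handles rectangular (n_test, n_train) shapes, as needed at predict time
    return np.array([[u.count('A') * v.count('A') + u.count('B') * v.count('B')
                      for v in V] for u in U], dtype=float)

clf = svm.SVC(kernel=count_kernel).fit(docs, labels)  # docs is never cast to float
print(clf.predict(["A", "B B"]))  # kernel evaluated between test and training strings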