8000 update what's new + more test · scikit-learn/scikit-learn@4fde439 · GitHub
[go: up one dir, main page]

Skip to content

Commit 4fde439

Browse files
committed
update what's new + more test
1 parent d26c5f1 commit 4fde439

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

doc/whats_new/v0.23.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,14 @@ Changelog
185185
`probB_`, are now deprecated as they were not useful. :pr:`15558` by
186186
`Thomas Fan`_.
187187

188+
- |Fix| Fix use of custom kernel not taking float entries such as string
189+
kernels in :class:`svm.SVC` and :class:`svm.SVR`.
190+
:pr:`11296` by `Alexandre Gramfort`_ and :user:`Georgi Peev <georgipeev>`.
191+
192+
- |API| Do not enforce float entries in X when using custom kernel
193+
in :class:`svm.SVC` and :class:`svm.SVR`.
194+
:pr:`11296` by `Alexandre Gramfort`_ and :user:`Georgi Peev <georgipeev>`.
195+
188196
:mod:`sklearn.tree`
189197
...................
190198

sklearn/svm/tests/test_svm.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,7 +1252,8 @@ def test_svm_probA_proB_deprecated(SVMClass, data, deprecated_prob):
12521252
getattr(clf, deprecated_prob)
12531253

12541254

1255-
def test_custom_kernel_not_array_input():
1255+
@pytest.mark.parametrize("Estimator", [svm.SVC, svm.SVR])
1256+
def test_custom_kernel_not_array_input(Estimator):
12561257
"""Test using a custom kernel that is not fed with array-like for floats"""
12571258
data = ["A A", "A", "B", "B B", "A B"]
12581259
X = np.array([[2, 0], [1, 0], [0, 1], [0, 2], [1, 1]]) # count encoding
@@ -1273,15 +1274,21 @@ def string_kernel(X1, X2):
12731274
K = string_kernel(data, data)
12741275
assert_array_equal(np.dot(X, X.T), K)
12751276

1276-
svc1 = svm.SVC(kernel=string_kernel).fit(data, y)
1277-
svc2 = svm.SVC(kernel='linear').fit(X, y)
1278-
svc3 = svm.SVC(kernel='precomputed').fit(K, y)
1277+
svc1 = Estimator(kernel=string_kernel).fit(data, y)
1278+
svc2 = Estimator(kernel='linear').fit(X, y)
1279+
svc3 = Estimator(kernel='precomputed').fit(K, y)
12791280

12801281
assert svc1.score(data, y) == svc3.score(K, y)
12811282
assert svc1.score(data, y) == svc2.score(X, y)
1282-
assert_array_almost_equal(svc1.decision_function(data),
1283-
svc2.decision_function(X))
1284-
assert_array_almost_equal(svc1.decision_function(data),
1285-
svc3.decision_function(K))
1286-
assert_array_equal(svc1.predict(data), svc2.predict(X))
1287-
assert_array_equal(svc1.predict(data), svc3.predict(K))
1283+
if hasattr(svc1, 'decision_function'): # classifier
1284+
assert_array_almost_equal(svc1.decision_function(data),
1285+
svc2.decision_function(X))
1286+
assert_array_almost_equal(svc1.decision_function(data),
1287+
svc3.decision_function(K))
1288+
assert_array_equal(svc1.predict(data), svc2.predict(X))
1289+
assert_array_equal(svc1.predict(data), svc3.predict(K))
1290+
else: # regressor
1291+
assert_array_almost_equal(svc1.predict(data),
1292+
svc2.predict(X))
1293+
assert_array_almost_equal(svc1.predict(data),
1294+
svc3.predict(K))

0 commit comments

Comments
 (0)
0