8000 [MRG+1] allow callable kernels in cross-validation (#8005) · scikit-learn/scikit-learn@5d0c7f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5d0c7f5

Browse files
amuellerjnothman
authored andcommitted
[MRG+1] allow callable kernels in cross-validation (#8005)
1 parent 6a01e89 commit 5d0c7f5

File tree

5 files changed

+63
-32
lines changed

5 files changed

+63
-32
lines changed

doc/whats_new.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ Enhancements
9999
A ``TypeError`` will be raised for any other kwargs. :issue:`8028`
100100
by :user:`Alexander Booth <alexandercbooth>`.
101101

102+
- :class:`model_selection.GridSearchCV`, :class:`model_selection.RandomizedSearchCV`
103+
and :func:`model_selection.cross_val_score` now allow estimators with callable
104+
kernels which were previously prohibited. :issue:`8005` by `Andreas Müller`_ .
105+
106+
102107
Bug fixes
103108
.........
104109

sklearn/model_selection/tests/test_search.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -443,15 +443,6 @@ def test_grid_search_precomputed_kernel_error_nonsquare():
443443
assert_raises(ValueError, cv.fit, K_train, y_train)
444444

445445

446-
def test_grid_search_precomputed_kernel_error_kernel_function():
447-
# Test that grid search returns an error when using a kernel_function
448-
X_, y_ = make_classification(n_samples=200, n_features=100, random_state=0)
449-
kernel_function = lambda x1, x2: np.dot(x1, x2.T)
450-
clf = SVC(kernel=kernel_function)
451-
cv = GridSearchCV(clf, {'C': [0.1, 1.0]})
452-
assert_raises(ValueError, cv.fit, X_, y_)
453-
454-
455446
class BrokenClassifier(BaseEstimator):
456447
"""Broken classifier that cannot be fit twice"""
457448

sklearn/model_selection/tests/test_validation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,12 @@ def test_cross_val_score_precomputed():
310310
score_precomputed = cross_val_score(svm, linear_kernel, y)
311311
svm = SVC(kernel="linear")
312312
score_linear = cross_val_score(svm, X, y)
313-
assert_array_equal(score_precomputed, score_linear)
313+
assert_array_almost_equal(score_precomputed, score_linear)
314+
315+
# test with callable
316+
svm = SVC(kernel=lambda x, y: np.dot(x, y.T))
317+
score_callable = cross_val_score(svm, X, y)
318+
assert_array_almost_equal(score_precomputed, score_callable)
314319

315320
# Error raised for non-square X
316321
svm = SVC(kernel="precomputed")

sklearn/svm/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,7 @@ def __init__(self, impl, kernel, degree, gamma, coef0,
104104
@property
105105
def _pairwise(self):
106106
# Used by cross_val_score.
107-
kernel = self.kernel
108-
return kernel == "precomputed" or callable(kernel)
107+
return self.kernel == "precomputed"
109108

110109
def fit(self, X, y, sample_weight=None):
111110
"""Fit the SVM model according to the given training data.

sklearn/utils/metaestimators.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -81,31 +81,62 @@ def if_delegate_has_method(delegate):
8181

8282

8383
def _safe_split(estimator, X, y, indices, train_indices=None):
84-
"""Create subset of dataset and properly handle kernels."""
85-
from ..gaussian_process.kernels import Kernel as GPKernel
84+
"""Create subset of dataset and properly handle kernels.
8685
87-
if (hasattr(estimator, 'kernel') and callable(estimator.kernel) and
88-
not isinstance(estimator.kernel, GPKernel)):
89-
# cannot compute the kernel values with custom function
90-
raise ValueError("Cannot use a custom kernel function. "
91-
"Precompute the kernel matrix instead.")
86+
Slice X, y according to indices for cross-validation, but take care of
87+
precomputed kernel-matrices or pairwise affinities / distances.
9288
93-
if not hasattr(X, "shape"):
94-
if getattr(estimator, "_pairwise", False):
89+
If ``estimator._pairwise is True``, X needs to be square and
90+
we slice rows and columns. If ``train_indices`` is not None,
91+
we slice rows using ``indices`` (assumed the test set) and columns
92+
using ``train_indices``, indicating the training set.
93+
94+
Labels y will always be sliced only along the last axis.
95+
96+
Parameters
97+
----------
98+
estimator : object
99+
Estimator to determine whether we should slice only rows or rows and
100+
columns.
101+
102+
X : array-like, sparse matrix or iterable
103+
Data to be sliced. If ``estimator._pairwise is True``,
104+
this needs to be a square array-like or sparse matrix.
105+
106+
y : array-like, sparse matrix or iterable
107+
Targets to be sliced.
108+
109+
indices : array of int
110+
Rows to select from X and y.
111+
If ``estimator._pairwise is True`` and ``train_indices is None``
112+
then ``indices`` will also be used to slice columns.
113+
114+
train_indices : array of int or None, default=None
F438
115+
If ``estimator._pairwise is True`` and ``train_indices is not None``,
116+
then ``train_indices`` will be use to slice the columns of X.
117+
118+
Returns
119+
-------
120+
X_sliced : array-like, sparse matrix or list
121+
Sliced data.
122+
123+
y_sliced : array-like, sparse matrix or list
124+
Sliced targets.
125+
126+
"""
127+
if getattr(estimator, "_pairwise", False):
128+
if not hasattr(X, "shape"):
95129
raise ValueError("Precomputed kernels or affinity matrices have "
96130
"to be passed as arrays or sparse matrices.")
97-
X_subset = [X[index] for index in indices]
98-
else:
99-
if getattr(estimator, "_pairwise", False):
100-
# X is a precomputed square kernel matrix
101-
if X.shape[0] != X.shape[1]:
102-
raise ValueError("X should be a square kernel matrix")
103-
if train_indices is None:
104-
X_subset = X[np.ix_(indices, indices)]
105-
else:
106-
X_subset = X[np.ix_(indices, train_indices)]
131+
# X is a precomputed square kernel matrix
132+
if X.shape[0] != X.shape[1]:
133+
raise ValueError("X should be a square kernel matrix")
134+
if train_indices is None:
135+
X_subset = X[np.ix_(indices, indices)]
107136
else:
108-
X_subset = safe_indexing(X, indices)
137+
X_subset = X[np.ix_(indices, train_indices)]
138+
else:
139+
X_subset = safe_indexing(X, indices)
109140

110141
if y is not None:
111142
y_subset = safe_indexing(y, indices)

0 commit comments

Comments
 (0)
0