8000 Merge pull request #6900 from fishcorn/fishcorn-patch-1 · scikit-learn/scikit-learn@d41b706 · GitHub
[go: up one dir, main page]

Skip to content

Commit d41b706

Browse files
Merge pull request #6900 from fishcorn/fishcorn-patch-1
[MRG+1] Make KernelCenterer a _pairwise operation
2 parents 94faf0e + 055bc4c commit d41b706

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

sklearn/preprocessing/data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1585,6 +1585,10 @@ def transform(self, K, y=None, copy=True):
15851585

15861586
return K
15871587

1588+
@property
1589+
def _pairwise(self):
1590+
return True
1591+
15881592

15891593
def add_dummy_feature(X, value=1.0):
15901594
"""Augment dataset with an additional dummy feature.

sklearn/preprocessing/tests/test_data.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@
5252
from sklearn.preprocessing.data import PolynomialFeatures
5353
from sklearn.exceptions import DataConversionWarning
5454

55+
from sklearn.pipeline import Pipeline
56+
from sklearn.cross_validation import cross_val_predict
57+
from sklearn.svm import SVR
58+
5559
from sklearn import datasets
5660

5761
iris = datasets.load_iris()
@@ -1370,6 +1374,26 @@ def test_center_kernel():
13701374
assert_array_almost_equal(K_pred_centered, K_pred_centered2)
13711375

13721376

1377+
def test_cv_pipeline_precomputed():
1378+
"""Cross-validate a regression on four coplanar points with the same
1379+
value. Use precomputed kernel to ensure Pipeline with KernelCenterer
1380+
is treated as a _pairwise operation."""
1381+
X = np.array([[3, 0, 0], [0, 3, 0], [0, 0, 3], [1, 1, 1]])
1382+
y_true = np.ones((4,))
1383+
K = X.dot(X.T)
1384+
kcent = KernelCenterer()
1385+
pipeline = Pipeline([("kernel_centerer", kcent), ("svr", SVR())])
1386+
1387+
# did the pipeline set the _pairwise attribute?
1388+
assert_true(pipeline._pairwise)
1389+
1390+
# test cross-validation, score should be almost perfect
1391+
# NB: this test is pretty vacuous -- it's mainly to test integration
1392+
# of Pipeline and KernelCenterer
1393+
y_pred = cross_val_predict(pipeline, K, y_true, cv=4)
1394+
assert_array_almost_equal(y_true, y_pred)
1395+
1396+
13731397
def test_fit_transform():
13741398
rng = np.random.RandomState(0)
13751399
X = rng.random_sample((5, 4))

0 commit comments

Comments
 (0)
0