8000 FIX add pairwise property to basesearchcv (#15524) · scikit-learn/scikit-learn@3ef8357 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3ef8357

Browse files
xun-tangjnothman
authored andcommitted
FIX add pairwise property to basesearchcv (#15524)
1 parent bf5307b commit 3ef8357

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

doc/whats_new/v0.22.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,19 @@ Changelog
740740
:class:`preprocessing.KernelCenterer`
741741
:pr:`14336` by :user:`Gregory Dexter <gdex1>`.
742742

743+
:mod:`sklearn.model_selection`
744+
..................
745+
746+
- |Fix| :class:`model_selection.GridSearchCV` and
747+
`model_selection.RandomizedSearchCV` now supports the
748+
:term:`_pairwise` property, which prevents an error during cross-validation
749+
for estimators with pairwise inputs (such as
750+
:class:`neighbors.KNeighborsClassifier` when :term:`metric` is set to
751+
'precomputed').
752+
:pr:`13925` by :user:`Isaac S. Robson <isrobson>` and :pr:`15524` by
753+
:user:`Xun Tang <xun-tang>`.
754+
755+
743756
:mod:`sklearn.svm`
744757
..................
745758

sklearn/model_selection/_search.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,11 @@ def __init__(self, estimator, scoring=None, n_jobs=None, iid='deprecated',
414414
def _estimator_type(self):
415415
return self.estimator._estimator_type
416416

417+
@property
418+
def _pairwise(self):
419+
# allows cross-validation to see 'precomputed' metrics
420+
return getattr(self.estimator, '_pairwise', False)
421+
417422
def score(self, X, y=None):
418423
"""Returns the score on the given data, if the estimator has been refit.
419424

sklearn/model_selection/tests/test_search.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,13 @@
5656
from sklearn.tree import DecisionTreeClassifier
5757
from sklearn.cluster import KMeans
5858
from sklearn.neighbors import KernelDensity
59+
from sklearn.neighbors import KNeighborsClassifier
5960
from sklearn.metrics import f1_score
6061
from sklearn.metrics import recall_score
6162
from sklearn.metrics import accuracy_score
6263
from sklearn.metrics import make_scorer
6364
from sklearn.metrics import roc_auc_score
65+
from sklearn.metrics.pairwise import euclidean_distances
6466
from sklearn.impute import SimpleImputer
6567
from sklearn.pipeline import Pipeline
6668
from sklearn.linear_model import Ridge, SGDClassifier, LinearRegression
@@ -1798,3 +1800,50 @@ def get_n_splits(self, *args, **kw):
17981800
'inconsistent results. Expected \\d+ '
17991801
'splits, got \\d+'):
18001802
ridge.fit(X[:train_size], y[:train_size])
1803+
1804+
1805+
def test_search_cv__pairwise_property_delegated_to_base_estimator():
1806+
"""
1807+
Test implementation of BaseSearchCV has the _pairwise property
1808+
which matches the _pairwise property of its estimator.
1809+
This test make sure _pairwise is delegated to the base estimator.
1810+
1811+
Non-regression test for issue #13920.
1812+
"""
1813+
est = BaseEstimator()
1814+
attr_message = "BaseSearchCV _pairwise property must match estimator"
1815+
1816+
for _pairwise_setting in [True, False]:
1817+
setattr(est, '_pairwise', _pairwise_setting)
1818+
cv = GridSearchCV(est, {'n_neighbors': [10]})
1819+
assert _pairwise_setting == cv._pairwise, attr_message
1820+
1821+
1822+
def test_search_cv__pairwise_property_equivalence_of_precomputed():
1823+
"""
1824+
Test implementation of BaseSearchCV has the _pairwise property
1825+
which matches the _pairwise property of its estimator.
1826+
This test ensures the equivalence of 'precomputed'.
1827+
1828+
Non-regression test for issue #13920.
1829+
"""
1830+
n_samples = 50
1831+
n_splits = 2
1832+
X, y = make_classification(n_samples=n_samples, random_state=0)
1833+
grid_params = {'n_neighbors': [10]}
1834+
1835+
# defaults to euclidean metric (minkowski p = 2)
1836+
clf = KNeighborsClassifier()
1837+
cv = GridSearchCV(clf, grid_params, cv=n_splits)
1838+
cv.fit(X, y)
1839+
preds_original = cv.predict(X)
1840+
1841+
# precompute euclidean metric to validate _pairwise is working
1842+
X_precomputed = euclidean_distances(X)
1843+
clf = KNeighborsClassifier(metric='precomputed')
1844+
cv = GridSearchCV(clf, grid_params, cv=n_splits)
1845+
cv.fit(X_precomputed, y)
1846+
preds_precomputed = cv.predict(X_precomputed)
1847+
1848+
attr_message = "GridSearchCV not identical with precomputed metric"
1849+
assert (preds_original == preds_precomputed).all(), attr_message

0 commit comments

Comments
 (0)
0