|
56 | 56 | from sklearn.tree import DecisionTreeClassifier
|
57 | 57 | from sklearn.cluster import KMeans
|
58 | 58 | from sklearn.neighbors import KernelDensity
|
| 59 | +from sklearn.neighbors import KNeighborsClassifier |
59 | 60 | from sklearn.metrics import f1_score
|
60 | 61 | from sklearn.metrics import recall_score
|
61 | 62 | from sklearn.metrics import accuracy_score
|
62 | 63 | from sklearn.metrics import make_scorer
|
63 | 64 | from sklearn.metrics import roc_auc_score
|
| 65 | +from sklearn.metrics.pairwise import euclidean_distances |
64 | 66 | from sklearn.impute import SimpleImputer
|
65 | 67 | from sklearn.pipeline import Pipeline
|
66 | 68 | from sklearn.linear_model import Ridge, SGDClassifier, LinearRegression
|
@@ -1798,3 +1800,50 @@ def get_n_splits(self, *args, **kw):
|
1798 | 1800 | 'inconsistent results. Expected \\d+ '
|
1799 | 1801 | 'splits, got \\d+'):
|
1800 | 1802 | ridge.fit(X[:train_size], y[:train_size])
|
| 1803 | + |
| 1804 | + |
def test_search_cv__pairwise_property_delegated_to_base_estimator():
    """Check that a search CV object mirrors its estimator's ``_pairwise``.

    A plain ``BaseEstimator`` has its ``_pairwise`` attribute set to each
    boolean value in turn; a ``GridSearchCV`` built around it must expose
    the same value through its own ``_pairwise`` attribute.

    Non-regression test for issue #13920.
    """
    failure_msg = "BaseSearchCV _pairwise property must match estimator"
    base = BaseEstimator()

    for expected in (True, False):
        # The same estimator instance is reused; only the flag changes.
        base._pairwise = expected
        search = GridSearchCV(base, {'n_neighbors': [10]})
        assert expected == search._pairwise, failure_msg
| 1820 | + |
| 1821 | + |
def test_search_cv__pairwise_property_equivalence_of_precomputed():
    """Check grid search gives equal predictions with a precomputed metric.

    A ``GridSearchCV`` over ``KNeighborsClassifier`` is fitted twice on the
    same classification problem: once on the raw features and once on the
    matrix returned by ``euclidean_distances(X)`` with
    ``metric='precomputed'``. The predictions of both searches must be
    element-wise equal.

    Non-regression test for issue #13920.
    """
    folds = 2
    param_grid = {'n_neighbors': [10]}
    X, y = make_classification(n_samples=50, random_state=0)

    # Reference run on raw features
    # (default metric is euclidean, i.e. minkowski with p = 2).
    raw_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=folds)
    raw_search.fit(X, y)
    raw_predictions = raw_search.predict(X)

    # Same search on a precomputed euclidean distance matrix; this only
    # matches the reference if _pairwise is handled by the search object.
    distances = euclidean_distances(X)
    precomputed_knn = KNeighborsClassifier(metric='precomputed')
    distance_search = GridSearchCV(precomputed_knn, param_grid, cv=folds)
    distance_search.fit(distances, y)
    distance_predictions = distance_search.predict(distances)

    failure_msg = "GridSearchCV not identical with precomputed metric"
    assert (raw_predictions == distance_predictions).all(), failure_msg
0 commit comments