|
3 | 3 | import numpy as np
|
4 | 4 | import pytest
|
5 | 5 |
|
| 6 | +from scipy.sparse import issparse |
6 | 7 | from sklearn.utils._testing import assert_warns
|
7 | 8 | from sklearn.utils._testing import assert_no_warnings
|
8 | 9 | from sklearn.semi_supervised import _label_propagation as label_propagation
|
9 | 10 | from sklearn.metrics.pairwise import rbf_kernel
|
| 11 | +from sklearn.model_selection import train_test_split |
| 12 | +from sklearn.neighbors import NearestNeighbors |
10 | 13 | from sklearn.datasets import make_classification
|
11 | 14 | from sklearn.exceptions import ConvergenceWarning
|
12 | 15 | from numpy.testing import assert_array_almost_equal
|
@@ -152,3 +155,39 @@ def test_convergence_warning():
|
152 | 155 |
|
153 | 156 | mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=500)
|
154 | 157 | assert_no_warnings(mdl.fit, X, y)
|
| 158 | + |
| 159 | + |
| 160 | +def test_predict_sparse_callable_kernel(): |
| 161
8000
| + # This is a non-regression test for #15866 |
| 162 | + |
| 163 | + # Custom sparse kernel (top-K RBF) |
| 164 | + def topk_rbf(X, Y=None, n_neighbors=10, gamma=1e-5): |
| 165 | + nn = NearestNeighbors(n_neighbors=10, metric='euclidean', n_jobs=-1) |
| 166 | + nn.fit(X) |
| 167 | + W = -1 * nn.kneighbors_graph(Y, mode='distance').power(2) * gamma |
| 168 | + np.exp(W.data, out=W.data) |
| 169 | + assert issparse(W) |
| 170 | + return W.T |
| 171 | + |
| 172 | + n_classes = 4 |
| 173 | + n_samples = 500 |
| 174 | + n_test = 10 |
| 175 | + X, y = make_classification(n_classes=n_classes, |
| 176 | + n_samples=n_samples, |
| 177 | + n_features=20, |
| 178 | + n_informative=20, |
| 179 | + n_redundant=0, |
| 180 | + n_repeated=0, |
| 181 | + random_state=0) |
| 182 | + |
| 183 | + X_train, X_test, y_train, y_test = train_test_split(X, y, |
| 184 | + test_size=n_test, |
| 185 | + random_state=0) |
| 186 | + |
| 187 | + model = label_propagation.LabelSpreading(kernel=topk_rbf) |
| 188 | + model.fit(X_train, y_train) |
| 189 | + assert model.score(X_test, y_test) >= 0.9 |
| 190 | + |
| 191 | + model = label_propagation.LabelPropagation(kernel=topk_rbf) |
| 192 | + model.fit(X_train, y_train) |
| 193 | + assert model.score(X_test, y_test) >= 0.9 |
0 commit comments