8000 FIX use safe_sparse_dot for callable kernel in LabelSpreading (#15868) · scikit-learn/scikit-learn@d163d5a · GitHub
[go: up one dir, main page]

Skip to content

Commit d163d5a

Browse files
nik-smqinhanmin2014
authored andcommitted
FIX use safe_sparse_dot for callable kernel in LabelSpreading (#15868)
1 parent 74a0874 commit d163d5a

File tree

3 files changed

+49
-1
lines changed

3 files changed

+49
-1
lines changed

doc/whats_new/v0.22.rst

+8
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,14 @@ Changelog
5353
- |Fix| :func:`metrics.classification_report` does no longer ignore the
5454
value of the ``zero_division`` keyword argument. :pr:`15879`
5555
by :user:`Bibhash Chandra Mitra <Bibyutatsu>`.
56+
57+
:mod:`sklearn.semi_supervised`
58+
..............................
59+
60+
- |Fix| :class:`semi_supervised.LabelPropagation` and
61+
:class:`semi_supervised.LabelSpreading` now allow callable kernel function to
62+
return sparse weight matrix.
63+
:pr:`15868` by :user:`Niklas Smedemark-Margulies <nik-sm>`.
5664

5765
:mod:`sklearn.utils`
5866
....................

sklearn/semi_supervised/_label_propagation.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,8 @@ class labels
195195
for weight_matrix in weight_matrices])
196196
else:
197197
weight_matrices = weight_matrices.T
198-
probabilities = np.dot(weight_matrices, self.label_distributions_)
198+
probabilities = safe_sparse_dot(
199+
weight_matrices, self.label_distributions_)
199200
normalizer = np.atleast_2d(np.sum(probabilities, axis=1)).T
200201
probabilities /= normalizer
201202
return probabilities

sklearn/semi_supervised/tests/test_label_propagation.py

+39
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
import numpy as np
44
import pytest
55

6+
from scipy.sparse import issparse
67
from sklearn.utils._testing import assert_warns
78
from sklearn.utils._testing import assert_no_warnings
89
from sklearn.semi_supervised import _label_propagation as label_propagation
910
from sklearn.metrics.pairwise import rbf_kernel
11+
from sklearn.model_selection import train_test_split
12+
from sklearn.neighbors import NearestNeighbors
1013
from sklearn.datasets import make_classification
1114
from sklearn.exceptions import ConvergenceWarning
1215
from numpy.testing import assert_array_almost_equal
@@ -152,3 +155,39 @@ def test_convergence_warning():
152155

153156
mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=500)
154157
assert_no_warnings(mdl.fit, X, y)
158+
159+
160+
def test_predict_sparse_callable_kernel():
161 8000 +
# This is a non-regression test for #15866
162+
163+
# Custom sparse kernel (top-K RBF)
164+
def topk_rbf(X, Y=None, n_neighbors=10, gamma=1e-5):
165+
nn = NearestNeighbors(n_neighbors=10, metric='euclidean', n_jobs=-1)
166+
nn.fit(X)
167+
W = -1 * nn.kneighbors_graph(Y, mode='distance').power(2) * gamma
168+
np.exp(W.data, out=W.data)
169+
assert issparse(W)
170+
return W.T
171+
172+
n_classes = 4
173+
n_samples = 500
174+
n_test = 10
175+
X, y = make_classification(n_classes=n_classes,
176+
n_samples=n_samples,
177+
n_features=20,
178+
n_informative=20,
179+
n_redundant=0,
180+
n_repeated=0,
181+
random_state=0)
182+
183+
X_train, X_test, y_train, y_test = train_test_split(X, y,
184+
test_size=n_test,
185+
random_state=0)
186+
187+
model = label_propagation.LabelSpreading(kernel=topk_rbf)
188+
model.fit(X_train, y_train)
189+
assert model.score(X_test, y_test) >= 0.9
190+
191+
model = label_propagation.LabelPropagation(kernel=topk_rbf)
192+
model.fit(X_train, y_train)
193+
assert model.score(X_test, y_test) >= 0.9

0 commit comments

Comments
 (0)
0