Make knn kernel undirected. · scikit-learn/scikit-learn@d99d853 · GitHub
[go: up one dir, main page]

Skip to content

Commit d99d853

Browse files
committed
Make knn kernel undirected.
Also, fix/update the tests. Fixes #8008.
1 parent a4fe183 commit d99d853

File tree

2 files changed

+31
-22
lines changed

2 files changed

+31
-22
lines changed

sklearn/semi_supervised/label_propagation.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,16 @@ def _get_kernel(self, X, y=None):
136136
self.nn_fit = NearestNeighbors(self.n_neighbors,
137137
n_jobs=self.n_jobs).fit(X)
138138
if y is None:
139-
return self.nn_fit.kneighbors_graph(self.nn_fit._fit_X,
140-
self.n_neighbors,
141-
mode='connectivity')
139+
# Nearest neighbors returns a directed matrix.
140+
dir_graph = self.nn_fit.kneighbors_graph(self.nn_fit._fit_X,
141+
self.n_neighbors,
142+
mode='connectivity')
143+
# Making the matrix symmetric
144+
un_graph = dir_graph + dir_graph.T
145+
# Since it is a connectivity matrix, all values should be
146+
# either 0 or 1
147+
un_graph[un_graph > 1] = 1
148+
return un_graph
142149
else:
143150
return self.nn_fit.kneighbors(y, return_distance=False)
144151
elif callable(self.kernel):

sklearn/semi_supervised/tests/test_label_propagation.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,8 @@ def test_distribution():
3939
labels = [0, 1, -1]
4040
for estimator, parameters in ESTIMATORS:
4141
clf = estimator(**parameters).fit(samples, labels)
42-
if parameters['kernel'] == 'knn':
43-
continue # unstable test; changes in k-NN ordering break it
44-
assert_array_almost_equal(clf.predict_proba([[1., 0.0]]),
45-
np.array([[1., 0.]]), 2)
46-
else:
47-
assert_array_almost_equal(np.asarray(clf.label_distributions_[2]),
48-
np.array([.5, .5]), 2)
42+
assert_array_almost_equal(np.asarray(clf.label_distributions_[2]),
43+
np.array([.5, .5]), decimal=3)
4944

5045

5146
def test_predict():
@@ -62,20 +57,23 @@ def test_predict_proba():
6257
for estimator, parameters in ESTIMATORS:
6358
clf = estimator(**parameters).fit(samples, labels)
6459
assert_array_almost_equal(clf.predict_proba([[1., 1.]]),
65-
np.array([[0.5, 0.5]]))
60+
np.array([[0.5, 0.5]]), decimal=3)
6661

6762

6863
def test_alpha_deprecation():
6964
X, y = make_classification(n_samples=100)
7065
y[::3] = -1
7166

72-
lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1)
73-
lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_
67+
for kernel in ['rbf', 'knn']:
68+
lp_default = label_propagation.LabelPropagation(kernel=kernel,
69+
gamma=0.1)
70+
lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_
7471

75-
lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1)
76-
lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_
72+
lp_0 = label_propagation.LabelPropagation(alpha=0, kernel=kernel,
73+
gamma=0.1)
74+
lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_
7775

78-
assert_array_equal(lp_default_y, lp_0_y)
76+
assert_array_equal(lp_default_y, lp_0_y)
7977

8078

8179
def test_label_spreading_closed_form():
@@ -94,7 +92,8 @@ def test_label_spreading_closed_form():
9492
expected /= expected.sum(axis=1)[:, np.newaxis]
9593
clf = label_propagation.LabelSpreading(max_iter=10000, alpha=alpha)
9694
clf.fit(X, y)
97-
assert_array_almost_equal(expected, clf.label_distributions_, 4)
95+
assert_array_almost_equal(expected, clf.label_distributions_,
96+
decimal=4)
9897

9998

10099
def test_label_propagation_closed_form():
@@ -139,9 +138,12 @@ def test_convergence_speed():
139138
# This is a non-regression test for #5774
140139
X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
141140
y = np.array([0, 1, -1])
142-
mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=5000)
143-
mdl.fit(X, y)
144141

145-
# this should converge quickly:
146-
assert mdl.n_iter_ < 10
147-
assert_array_equal(mdl.predict(X), [0, 1, 1])
142+
for kernel in ['rbf', 'knn']:
143+
mdl = label_propagation.LabelSpreading(kernel=kernel, max_iter=5000,
144+
n_neighbors=2)
145+
mdl.fit(X, y)
146+
147+
# this should converge quickly:
148+
assert mdl.n_iter_ < 10
149+
assert_array_almost_equal(mdl.predict_proba([[0.5, 0.5]]), [[0.5, 0.5]], decimal=3)

0 commit comments

Comments
 (0)
0