8000 FIX Convergence warning and n_iter_ in LabelPropagation (#5893) · CoderPat/scikit-learn@e5b892e · GitHub
[go: up one dir, main page]

Skip to content

Commit e5b892e

Browse files
musically-utjnothman
authored andcommitted
FIX Convergence warning and n_iter_ in LabelPropagation (scikit-learn#5893)
1 parent dc9ab80 commit e5b892e

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

sklearn/semi_supervised/label_propagation.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
>>> from sklearn.semi_supervised import LabelPropagation
3535
>>> label_prop_model = LabelPropagation()
3636
>>> iris = datasets.load_iris()
37-
>>> random_unlabeled_points = np.where(np.random.randint(0, 2,
38-
... size=len(iris.target)))
37+
>>> rng = np.random.RandomState(42)
38+
>>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3
3939
>>> labels = np.copy(iris.target)
4040
>>> labels[random_unlabeled_points] = -1
4141
>>> label_prop_model.fit(iris.data, labels)
@@ -53,6 +53,7 @@
5353
"""
5454

5555
# Authors: Clay Woolam <clay@woolam.org>
56+
# Utkarsh Upadhyay <mail@musicallyut.in>
5657
# License: BSD
5758
from abc import ABCMeta, abstractmethod
5859

@@ -67,13 +68,7 @@
6768
from ..utils.extmath import safe_sparse_dot
6869
from ..utils.multiclass import check_classification_targets
6970
from ..utils.validation import check_X_y, check_is_fitted, check_array
70-
71-
72-
# Helper functions
73-
74-
def _not_converged(y_truth, y_prediction, tol=1e-3):
75-
"""basic convergence check"""
76-
return np.abs(y_truth - y_prediction).sum() > tol
71+
from ..exceptions import ConvergenceWarning
7772

7873

7974
class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
@@ -97,7 +92,7 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
9792
alpha : float
9893
Clamping factor
9994
100-
max_iter : float
95+
max_iter : integer
10196
Change maximum number of iterations allowed
10297
10398
tol : float
@@ -264,12 +259,14 @@ def fit(self, X, y):
264259

265260
l_previous = np.zeros((self.X_.shape[0], n_classes))
266261

267-
remaining_iter = self.max_iter
268262
unlabeled = unlabeled[:, np.newaxis]
269263
if sparse.isspmatrix(graph_matrix):
270264
graph_matrix = graph_matrix.tocsr()
271-
while (_not_converged(self.label_distributions_, l_previous, self. 8000 tol)
272-
and remaining_iter > 1):
265+
266+
for self.n_iter_ in range(self.max_iter):
267+
if np.abs(self.label_distributions_ - l_previous).sum() < self.tol:
268+
break
269+
273270
l_previous = self.label_distributions_
274271
self.label_distributions_ = safe_sparse_dot(
275272
graph_matrix, self.label_distributions_)
@@ -285,7 +282,12 @@ def fit(self, X, y):
285282
# clamp
286283
self.label_distributions_ = np.multiply(
287284
alpha, self.label_distributions_) + y_static
288-
remaining_iter -= 1
285+
else:
286+
warnings.warn(
287+
'max_iter=%d was reached without convergence.' % self.max_iter,
288+
category=ConvergenceWarning
289+
)
290+
self.n_iter_ += 1
289291

290292
normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis]
291293
self.label_distributions_ /= normalizer
@@ -294,7 +296,6 @@ def fit(self, X, y):
294296
transduction = self.classes_[np.argmax(self.label_distributions_,
295297
axis=1)]
296298
self.transduction_ = transduction.ravel()
297-
self.n_iter_ = self.max_iter - remaining_iter
298299
return self
299300

300301

@@ -324,7 +325,7 @@ class LabelPropagation(BaseLabelPropagation):
324325
This parameter will be removed in 0.21.
325326
'alpha' is fixed to zero in 'LabelPropagation'.
326327
327-
max_iter : float
328+
max_iter : integer
328329
Change maximum number of iterations allowed
329330
330331
tol : float
@@ -358,8 +359,8 @@ class LabelPropagation(BaseLabelPropagation):
358359
>>> from sklearn.semi_supervised import LabelPropagation
359360
>>> label_prop_model = LabelPropagation()
360361
>>> iris = datasets.load_iris()
361-
>>> random_unlabeled_points = np.where(np.random.randint(0, 2,
362-
... size=len(iris.target)))
362+
>>> rng = np.random.RandomState(42)
363+
>>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3
363364
>>> labels = np.copy(iris.target)
364365
>>> labels[random_unlabeled_points] = -1
365366
>>> label_prop_model.fit(iris.data, labels)
@@ -441,7 +442,7 @@ class LabelSpreading(BaseLabelPropagation):
441442
alpha=0 means keeping the initial label information; alpha=1 means
442443
replacing all initial information.
443444
444-
max_iter : float
445+
max_iter : integer
445446
maximum number of iterations allowed
446447
447448
tol : float
@@ -475,8 +476,8 @@ class LabelSpreading(BaseLabelPropagation):
475476
>>> from sklearn.semi_supervised import LabelSpreading
476477
>>> label_prop_model = LabelSpreading()
477478
>>> iris = datasets.load_iris()
478-
>>> random_unlabeled_points = np.where(np.random.randint(0, 2,
479-
... size=len(iris.target)))
479+
>>> rng = np.random.RandomState(42)
480+
>>> random_unlabeled_points = rng.rand(len(iris.target)) < 0.3
480481
>>> labels = np.copy(iris.target)
481482
>>> labels[random_unlabeled_points] = -1
482483
>>> label_prop_model.fit(iris.data, labels)

sklearn/semi_supervised/tests/test_label_propagation.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.semi_supervised import label_propagation
1010
from sklearn.metrics.pairwise import rbf_kernel
1111
from sklearn.datasets import make_classification
12+
from sklearn.exceptions import ConvergenceWarning
1213
from numpy.testing import assert_array_almost_equal
1314
from numpy.testing import assert_array_equal
1415

@@ -70,7 +71,7 @@ def test_alpha_deprecation():
7071
y[::3] = -1
7172

7273
lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1)
73-
lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_
74+
lp_default_y = lp_default.fit(X, y).transduction_
7475

7576
lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1)
7677
lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_
@@ -108,7 +109,8 @@ def test_label_propagation_closed_form():
108109
labelled_idx = (Y[:, (-1,)] == 0).nonzero()[0]
109110

110111
clf = label_propagation.LabelPropagation(max_iter=10000,
111-
gamma=0.1).fit(X, y)
112+
gamma=0.1)
113+
clf.fit(X, y)
112114
# adopting notation from Zhu et al 2002
113115
T_bar = clf._build_graph()
114116
Tuu = T_bar[np.meshgrid(unlabelled_idx, unlabelled_idx, indexing='ij')]
@@ -145,3 +147,22 @@ def test_convergence_speed():
145147
# this should converge quickly:
146148
assert mdl.n_iter_ < 10
147149
assert_array_equal(mdl.predict(X), [0, 1, 1])
150+
151+
152+
def test_convergence_warning():
153+
# This is a non-regression test for #5774
154+
X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
155+
y = np.array([0, 1, -1])
156+
mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=1)
157+
assert_warns(ConvergenceWarning, mdl.fit, X, y)
158+
assert_equal(mdl.n_iter_, mdl.max_iter)
159+
160+
mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=1)
161+
assert_warns(ConvergenceWarning, mdl.fit, X, y)
162+
assert_equal(mdl.n_iter_, mdl.max_iter)
163+
164+
mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=500)
165+
assert_no_warnings(mdl.fit, X, y)
166+
167+
mdl = label_propagation.LabelPropagation(kernel='rbf', max_iter=500)
168+
assert_no_warnings(mdl.fit, X, y)

0 commit comments

Comments
 (0)
0