34
34
>>> from sklearn.semi_supervised import LabelPropagation
35
35
>>> label_prop_model = LabelPropagation()
36
36
>>> iris = datasets.load_iris()
37
- >>> random_unlabeled_points = np.where(np. random.randint(0, 2,
38
- ... size= len(iris.target)))
37
+ >>> rng = np.random.RandomState(42)
38
+ >>> random_unlabeled_points = rng.rand( len(iris.target)) < 0.3
39
39
>>> labels = np.copy(iris.target)
40
40
>>> labels[random_unlabeled_points] = -1
41
41
>>> label_prop_model.fit(iris.data, labels)
53
53
"""
54
54
55
55
# Authors: Clay Woolam <clay@woolam.org>
56
+ # Utkarsh Upadhyay <mail@musicallyut.in>
56
57
# License: BSD
57
58
from abc import ABCMeta , abstractmethod
58
59
67
68
from ..utils .extmath import safe_sparse_dot
68
69
from ..utils .multiclass import check_classification_targets
69
70
from ..utils .validation import check_X_y , check_is_fitted , check_array
70
-
71
-
72
- # Helper functions
73
-
74
- def _not_converged (y_truth , y_prediction , tol = 1e-3 ):
75
- """basic convergence check"""
76
- return np .abs (y_truth - y_prediction ).sum () > tol
71
+ from ..exceptions import ConvergenceWarning
77
72
78
73
79
74
class BaseLabelPropagation (six .with_metaclass (ABCMeta , BaseEstimator ,
@@ -97,7 +92,7 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
97
92
alpha : float
98
93
Clamping factor
99
94
100
- max_iter : float
95
+ max_iter : integer
101
96
Change maximum number of iterations allowed
102
97
103
98
tol : float
@@ -264,12 +259,14 @@ def fit(self, X, y):
264
259
265
260
l_previous = np .zeros ((self .X_ .shape [0 ], n_classes ))
266
261
267
- remaining_iter = self .max_iter
268
262
unlabeled = unlabeled [:, np .newaxis ]
269
263
if sparse .isspmatrix (graph_matrix ):
270
264
graph_matrix = graph_matrix .tocsr ()
271
- while (_not_converged (self .label_distributions_ , l_previous , self .tol )
272
- and remaining_iter > 1 ):
265
+
266
+ for self .n_iter_ in range (self .max_iter ):
267
+ if np .abs (self .label_distributions_ - l_previous ).sum () < self .tol :
268
+ break
269
+
273
270
l_previous = self .label_distributions_
274
271
self .label_distributions_ = safe_sparse_dot (
275
272
graph_matrix , self .label_distributions_ )
@@ -285,7 +282,12 @@ def fit(self, X, y):
285
282
# clamp
286
283
self .label_distributions_ = np .multiply (
287
284
alpha , self .label_distributions_ ) + y_static
288
- remaining_iter -= 1
285
+ else :
286
+ warnings .warn (
287
+ 'max_iter=%d was reached without convergence.' % self .max_iter ,
288
+ category = ConvergenceWarning
289
+ )
290
+ self .n_iter_ += 1
289
291
290
292
normalizer = np .sum (self .label_distributions_ , axis = 1 )[:, np .newaxis ]
291
293
self .label_distributions_ /= normalizer
@@ -294,7 +296,6 @@ def fit(self, X, y):
294
296
transduction = self .classes_ [np .argmax (self .label_distributions_ ,
295
297
axis = 1 )]
296
298
self .transduction_ = transduction .ravel ()
297
- self .n_iter_ = self .max_iter - remaining_iter
298
299
return self
299
300
300
301
@@ -324,7 +325,7 @@ class LabelPropagation(BaseLabelPropagation):
324
325
This parameter will be removed in 0.21.
325
326
'alpha' is fixed to zero in 'LabelPropagation'.
326
327
327
- max_iter : float
328
+ max_iter : integer
328
329
Change maximum number of iterations allowed
329
330
330
331
tol : float
@@ -358,8 +359,8 @@ class LabelPropagation(BaseLabelPropagation):
358
359
>>> from sklearn.semi_supervised import LabelPropagation
359
360
>>> label_prop_model = LabelPropagation()
360
361
>>> iris = datasets.load_iris()
361
- >>> random_unlabeled_points = np.where(np. random.randint(0, 2,
362
- ... size= len(iris.target)))
362
+ >>> rng = np.random.RandomState(42)
363
+ >>> random_unlabeled_points = rng.rand( len(iris.target)) < 0.3
363
364
>>> labels = np.copy(iris.target)
364
365
>>> labels[random_unlabeled_points] = -1
365
366
>>> label_prop_model.fit(iris.data, labels)
@@ -441,7 +442,7 @@ class LabelSpreading(BaseLabelPropagation):
441
442
alpha=0 means keeping the initial label information; alpha=1 means
442
443
replacing all initial information.
443
444
444
- max_iter : float
445
+ max_iter : integer
445
446
maximum number of iterations allowed
446
447
447
448
tol : float
@@ -475,8 +476,8 @@ class LabelSpreading(BaseLabelPropagation):
475
476
>>> from sklearn.semi_supervised import LabelSpreading
476
477
>>> label_prop_model = LabelSpreading()
477
478
>>> iris = datasets.load_iris()
478
- >>> random_unlabeled_points = np.where(np. random.randint(0, 2,
479
- ... size= len(iris.target)))
479
+ >>> rng = np.random.RandomState(42)
480
+ >>> random_unlabeled_points = rng.rand( len(iris.target)) < 0.3
480
481
>>> labels = np.copy(iris.target)
481
482
>>> labels[random_unlabeled_points] = -1
482
483
>>> label_prop_model.fit(iris.data, labels)
0 commit comments