3434>>> from sklearn.semi_supervised import LabelPropagation
3535>>> label_prop_model = LabelPropagation()
3636>>> iris = datasets.load_iris()
37- >>> random_unlabeled_points = np.where(np. random.randint(0, 2,
38- ... size= len(iris.target)))
37+ >>> rng = np.random.RandomState(42)
38+ >>> random_unlabeled_points = rng.rand( len(iris.target)) < 0.3
3939>>> labels = np.copy(iris.target)
4040>>> labels[random_unlabeled_points] = -1
4141>>> label_prop_model.fit(iris.data, labels)
5353"""
5454
5555# Authors: Clay Woolam <clay@woolam.org>
56+ # Utkarsh Upadhyay <mail@musicallyut.in>
5657# License: BSD
5758from abc import ABCMeta , abstractmethod
5859
6768from ..utils .extmath import safe_sparse_dot
6869from ..utils .multiclass import check_classification_targets
6970from ..utils .validation import check_X_y , check_is_fitted , check_array
70-
71-
72- # Helper functions
73-
74- def _not_converged (y_truth , y_prediction , tol = 1e-3 ):
75- """basic convergence check"""
76- return np .abs (y_truth - y_prediction ).sum () > tol
71+ from ..exceptions import ConvergenceWarning
7772
7873
7974class BaseLabelPropagation (six .with_metaclass (ABCMeta , BaseEstimator ,
@@ -97,7 +92,7 @@ class BaseLabelPropagation(six.with_metaclass(ABCMeta, BaseEstimator,
9792 alpha : float
9893 Clamping factor
9994
100- max_iter : float
95+ max_iter : integer
10196 Change maximum number of iterations allowed
10297
10398 tol : float
@@ -264,12 +259,14 @@ def fit(self, X, y):
264259
265260 l_previous = np .zeros ((self .X_ .shape [0 ], n_classes ))
266261
267- remaining_iter = self .max_iter
268262 unlabeled = unlabeled [:, np .newaxis ]
269263 if sparse .isspmatrix (graph_matrix ):
270264 graph_matrix = graph_matrix .tocsr ()
271- while (_not_converged (self .label_distributions_ , l_previous , self .tol )
272- and remaining_iter > 1 ):
265+
266+ for self .n_iter_ in range (self .max_iter ):
267+ if np .abs (self .label_distributions_ - l_previous ).sum () < self .tol :
268+ break
269+
273270 l_previous = self .label_distributions_
274271 self .label_distributions_ = safe_sparse_dot (
275272 graph_matrix , self .label_distributions_ )
@@ -285,7 +282,12 @@ def fit(self, X, y):
285282 # clamp
286283 self .label_distributions_ = np .multiply (
287284 alpha , self .label_distributions_ ) + y_static
288- remaining_iter -= 1
285+ else :
286+ warnings .warn (
287+ 'max_iter=%d was reached without convergence.' % self .max_iter ,
288+ category = ConvergenceWarning
289+ )
290+ self .n_iter_ += 1
289291
290292 normalizer = np .sum (self .label_distributions_ , axis = 1 )[:, np .newaxis ]
291293 self .label_distributions_ /= normalizer
@@ -294,7 +296,6 @@ def fit(self, X, y):
294296 transduction = self .classes_ [np .argmax (self .label_distributions_ ,
295297 axis = 1 )]
296298 self .transduction_ = transduction .ravel ()
297- self .n_iter_ = self .max_iter - remaining_iter
298299 return self
299300
300301
@@ -324,7 +325,7 @@ class LabelPropagation(BaseLabelPropagation):
324325 This parameter will be removed in 0.21.
325326 'alpha' is fixed to zero in 'LabelPropagation'.
326327
327- max_iter : float
328+ max_iter : integer
328329 Change maximum number of iterations allowed
329330
330331 tol : float
@@ -358,8 +359,8 @@ class LabelPropagation(BaseLabelPropagation):
358359 >>> from sklearn.semi_supervised import LabelPropagation
359360 >>> label_prop_model = LabelPropagation()
360361 >>> iris = datasets.load_iris()
361- >>> random_unlabeled_points = np.where(np. random.randint(0, 2,
362- ... size= len(iris.target)))
362+ >>> rng = np.random.RandomState(42)
363+ >>> random_unlabeled_points = rng.rand( len(iris.target)) < 0.3
363364 >>> labels = np.copy(iris.target)
364365 >>> labels[random_unlabeled_points] = -1
365366 >>> label_prop_model.fit(iris.data, labels)
@@ -441,7 +442,7 @@ class LabelSpreading(BaseLabelPropagation):
441442 alpha=0 means keeping the initial label information; alpha=1 means
442443 replacing all initial information.
443444
444- max_iter : float
445+ max_iter : integer
445446 maximum number of iterations allowed
446447
447448 tol : float
@@ -475,8 +476,8 @@ class LabelSpreading(BaseLabelPropagation):
475476 >>> from sklearn.semi_supervised import LabelSpreading
476477 >>> label_prop_model = LabelSpreading()
477478 >>> iris = datasets.load_iris()
478- >>> random_unlabeled_points = np.where(np. random.randint(0, 2,
479- ... size= len(iris.target)))
479+ >>> rng = np.random.RandomState(42)
480+ >>> random_unlabeled_points = rng.rand( len(iris.target)) < 0.3
480481 >>> labels = np.copy(iris.target)
481482 >>> labels[random_unlabeled_points] = -1
482483 >>> label_prop_model.fit(iris.data, labels)
0 commit comments