From dadc2fc8f2cf85a4577a393c35b4c1e3826070a8 Mon Sep 17 00:00:00 2001
From: Vlad Niculae
Date: Tue, 29 Oct 2013 16:46:30 +0100
Subject: [PATCH] FIX Projected Gradient NMF stopping condition

---
 sklearn/decomposition/nmf.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/sklearn/decomposition/nmf.py b/sklearn/decomposition/nmf.py
index c347b85f67d2a..16f96fb435b11 100644
--- a/sklearn/decomposition/nmf.py
+++ b/sklearn/decomposition/nmf.py
@@ -499,17 +499,18 @@ def fit_transform(self, X, y=None):
                  - safe_sparse_dot(X, H.T, dense_output=True))
         gradH = (np.dot(np.dot(W.T, W), H)
                  - safe_sparse_dot(W.T, X, dense_output=True))
-        init_grad = norm(np.r_[gradW, gradH.T])
-        tolW = max(0.001, self.tol) * init_grad  # why max?
+        init_grad = (gradW ** 2).sum() + (gradH ** 2).sum()
+        tolW = max(0.001, self.tol) * np.sqrt(init_grad)  # why max?
         tolH = tolW
 
-        tol = self.tol * init_grad
+        tol = init_grad * self.tol ** 2
 
         for n_iter in range(1, self.max_iter + 1):
-            # stopping condition
-            # as discussed in paper
-            proj_norm = norm(np.r_[gradW[np.logical_or(gradW < 0, W > 0)],
-                             gradH[np.logical_or(gradH < 0, H > 0)]])
+            # stopping condition on the norm of the projected gradient
+            proj_norm = (
+                ((gradW * np.logical_or(gradW < 0, W > 0)) ** 2).sum()
+                + ((gradH * np.logical_or(gradH < 0, H > 0)) ** 2).sum())
+
             if proj_norm < tol:
                 break
 