8000 FIX Projected Gradient NMF stopping condition · scikit-learn/scikit-learn@dadc2fc · GitHub
[go: up one dir, main page]

Skip to content

Commit dadc2fc

Browse files
committed
FIX Projected Gradient NMF stopping condition
1 parent 8a0b317 commit dadc2fc

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

sklearn/decomposition/nmf.py

Lines changed: 8 additions & 7 deletions
@@ -499,17 +499,18 @@ def fit_transform(self, X, y=None):
                      - safe_sparse_dot(X, H.T, dense_output=True))
             gradH = (np.dot(np.dot(W.T, W), H)
                      - safe_sparse_dot(W.T, X, dense_output=True))
-            init_grad = norm(np.r_[gradW, gradH.T])
-            tolW = max(0.001, self.tol) * init_grad  # why max?
+            init_grad = (gradW ** 2).sum() + (gradH ** 2).sum()
+            tolW = max(0.001, self.tol) * np.sqrt(init_grad)  # why max?
             tolH = tolW

-            tol = self.tol * init_grad
+            tol = init_grad * self.tol ** 2

             for n_iter in range(1, self.max_iter + 1):
-                # stopping condition
-                # as discussed in paper
-                proj_norm = norm(np.r_[gradW[np.logical_or(gradW < 0, W > 0)],
-                                       gradH[np.logical_or(gradH < 0, H > 0)]])
+                # stopping condition on the norm of the projected gradient
+                proj_norm = (
+                    ((gradW * np.logical_or(gradW < 0, W > 0)) ** 2).sum() +
+                    ((gradH * np.logical_or(gradH < 0, H > 0)) ** 2).sum())

                 if proj_norm < tol:
                     break

0 commit comments

Comments
 (0)
0