|
28 | 28 | from ..utils.validation import check_is_fitted, check_non_negative
|
29 | 29 | from ..utils import deprecated
|
30 | 30 | from ..exceptions import ConvergenceWarning
|
31 |
| -from .cdnmf_fast import _update_cdnmf_fast |
| 31 | +from .cdnmf_fast import _update_cdnmf_fast, _update_cdkl_fast |
32 | 32 |
|
33 | 33 | EPSILSON = 1e-9
|
34 | 34 |
|
@@ -609,9 +609,28 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle,
|
609 | 609 | return _update_cdnmf_fast(W, HHt, XHt, permutation)
|
610 | 610 |
|
611 | 611 |
|
612 |
| -def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0, |
613 |
| - l1_reg_H=0, l2_reg_W=0, l2_reg_H=0, update_H=True, |
614 |
| - verbose=0, shuffle=False, random_state=None): |
| 612 | +def _update_coordinate_descent_kl(X, W, Ht, l1_reg, l2_reg, shuffle, |
| 613 | + random_state, WH, reset): |
| 614 | + epsilon = 0. |
| 615 | + max_iter = 1 |
| 616 | + n_components = Ht.shape[1] |
| 617 | + |
| 618 | + # TODO add regularization |
| 619 | + |
| 620 | + if shuffle: |
| 621 | + permutation = random_state.permutation(n_components) |
| 622 | + else: |
| 623 | + permutation = np.arange(n_components) |
| 624 | + # The following seems to be required on 64-bit Windows w/ Python 3.5. |
| 625 | + permutation = np.asarray(permutation, dtype=np.intp) |
| 626 | + return _update_cdkl_fast(X, W, Ht, WH, permutation, reset, |
| 627 | + epsilon, max_iter) |
| 628 | + |
| 629 | + |
| 630 | +def _fit_coordinate_descent(X, W, H, beta_loss='frobenius', tol=1e-4, |
| 631 | + max_iter=200, l1_reg_W=0, l1_reg_H=0, l2_reg_W=0, |
| 632 | + l2_reg_H=0, update_H=True, verbose=0, |
| 633 | + shuffle=False, random_state=None): |
615 | 634 | """Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent
|
616 | 635 |
|
617 | 636 | The objective function is minimized with an alternating minimization of W
|
@@ -680,21 +699,46 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
|
680 | 699 | """
|
681 | 700 | # so W and Ht are both in C order in memory
|
682 | 701 | Ht = check_array(H.T, order='C')
|
683 |
| - X = check_array(X, accept_sparse='csr') |
| 702 | + |
| 703 | + beta_loss = _beta_loss_to_float(beta_loss) |
| 704 | + if beta_loss == 2: |
| 705 | + X = check_array(X, accept_sparse='csr') |
| 706 | + elif beta_loss == 1: |
| 707 | + X = check_array(X, dtype=np.float64, accept_sparse=None) |
| 708 | + WH = fast_dot(W, Ht.T) |
| 709 | + reset = X.mean() * 0.01 |
| 710 | + else: |
| 711 | + raise ValueError("cd solver only accepts beta_loss in ('frobenius', " |
| 712 | + "'kullback-leibler', 2, 1). Got %s" % str(beta_loss)) |
684 | 713 |
|
685 | 714 | rng = check_random_state(random_state)
|
686 | 715 |
|
687 | 716 | for n_iter in range(max_iter):
|
688 | 717 | violation = 0.
|
689 | 718 |
|
690 | 719 | # Update W
|
691 |
| - violation += _update_coordinate_descent(X, W, Ht, l1_reg_W, |
692 |
| - l2_reg_W, shuffle, rng) |
| 720 | + if beta_loss == 2: |
| 721 | + violation += _update_coordinate_descent(X, W, Ht, l1_reg_W, |
| 722 | + l2_reg_W, |
| 723 | + shuffle, rng) |
| 724 | + else: |
| 725 | + violation += _update_coordinate_descent_kl(X, W, Ht, l1_reg_W, |
| 726 | + l2_reg_W, |
| 727 | + shuffle, rng, WH, |
| 728 | + reset) |
693 | 729 | # Update H
|
694 | 730 | if update_H:
|
695 |
| - violation += _update_coordinate_descent(X.T, Ht, W, l1_reg_H, |
696 |
| - l2_reg_H, shuffle, rng) |
697 |
| - |
| 731 | + if beta_loss == 2: |
| 732 | + violation += _update_coordinate_descent(X.T, Ht, W, |
| 733 | + l1_reg_H, |
| 734 | + l2_reg_H, |
| 735 | + shuffle, rng) |
| 736 | + else: |
| 737 | + violation += _update_coordinate_descent_kl(X.T, Ht, W, |
| 738 | + l1_reg_H, |
| 739 | + l2_reg_H, |
| 740 | + shuffle, rng, |
| 741 | + WH.T, reset) |
698 | 742 | if n_iter == 0:
|
699 | 743 | violation_init = violation
|
700 | 744 |
|
@@ -1185,7 +1229,8 @@ def non_negative_factorization(X, W=None, H=None, n_components=None,
|
1185 | 1229 | sparseness, beta,
|
1186 | 1230 | eta)
|
1187 | 1231 | elif solver == 'cd':
|
1188 |
| - W, H, n_iter = _fit_coordinate_descent(X, W, H, tol, max_iter, |
| 1232 | + W, H, n_iter = _fit_coordinate_descent(X, W, H, beta_loss, |
| 1233 | + tol, max_iter, |
1189 | 1234 | l1_reg_W, l1_reg_H,
|
1190 | 1235 | l2_reg_W, l2_reg_H,
|
1191 | 1236 | update_H=update_H,
|
@@ -1306,7 +1351,8 @@ class NMF(BaseEstimator, TransformerMixin):
|
1306 | 1351 | For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
|
1307 | 1352 |
|
1308 | 1353 | .. versionadded:: 0.17
|
1309 |
| - Regularization parameter *l1_ratio* used in the Coordinate Descent solver. |
| 1354 | + Regularization parameter *l1_ratio* used in the |
| 1355 | + Coordinate Descent solver. |
1310 | 1356 |
|
1311 | 1357 | shuffle : boolean, default: False
|
1312 | 1358 | If true, randomize the order of coordinates in the CD solver.
|
|
0 commit comments