8000 temporary test cd kl · scikit-learn/scikit-learn@a68d728 · GitHub
[go: up one dir, main page]

Skip to content

Commit a68d728

Browse files
committed
temporary test cd kl
1 parent b41c132 commit a68d728

File tree

2 files changed

+114
-13
lines changed

2 files changed

+114
-13
lines changed

sklearn/decomposition/cdnmf_fast.pyx

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# License: BSD 3 clause
77

88
cimport cython
9-
from libc.math cimport fabs
9+
from libc.math cimport fabs, log
1010

1111

1212
def _update_cdnmf_fast(double[:, ::1] W, double[:, :] HHt, double[:, :] XHt,
@@ -39,3 +39,58 @@ def _update_cdnmf_fast(double[:, ::1] W, double[:, :] HHt, double[:, :] XHt,
3939
W[i, t] = max(W[i, t] - grad / hess, 0.)
4040

4141
return violation
42+
43+
44+
def _update_cdkl_fast(double[:, :] X, double[:, ::1] W, double[:, ::1] Ht,
45+
double[:, :] WH, Py_ssize_t[::1] permutation,
46+
double reset, double epsilon, int max_iter):
47+
48+
cdef double violation = 0
49+
cdef int n_components = W.shape[1]
50+
cdef int n_samples = W.shape[0] # n_features for H update
51+
cdef int n_features = Ht.shape[0] # n_samples for H update
52+
cdef double grad, hess, num, s, temp, div
53+
cdef int j, i, rr, r, n_iter
54+
55+
with nogil:
56+
for rr in range(n_components):
57+
r = permutation[rr]
58+
for i in range(n_samples):
59+
for n_iter in range(max_iter):
60+
61+
grad = 0.
62+
hess = 0.
63+
64+
for j in range(n_features):
65+
div = Ht[j, r] / WH[i, j]
66+
temp = X[i, j] * div
67+
grad += Ht[j, r] - temp
68+
hess += temp * div
69+
70+
if grad == 0:
71+
break
72+
73+
if hess == 0:
74+
s = reset - W[i, r]
75+
with gil:
76+
print(reset, s)
77+
else:
78+
s = max(-grad / hess , -W[i, r])
79+
80+
# projected gradient
81+
pg = min(0., grad) if W[i, r] == 0 else grad
82+
violation += fabs(pg)
83+
84+
# maintain WH
85+
for j in range(n_features):
86+
WH[i, j] += s * Ht[j, r]
87+
88+
# stopping condition
89+
#if epsilon > 0 and fabs(s) < epsilon * W[i, r]:
90+
# W[i, r] += s
91+
# break
92+
93+
# update
94+
W[i, r] += s
95+
96+
return violation

sklearn/decomposition/nmf.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..utils.validation import check_is_fitted, check_non_negative
2929
from ..utils import deprecated
3030
from ..exceptions import ConvergenceWarning
31-
from .cdnmf_fast import _update_cdnmf_fast
31+
from .cdnmf_fast import _update_cdnmf_fast, _update_cdkl_fast
3232

3333
EPSILSON = 1e-9
3434

@@ -609,9 +609,28 @@ def _update_coordinate_descent(X, W, Ht, l1_reg, l2_reg, shuffle,
609609
return _update_cdnmf_fast(W, HHt, XHt, permutation)
610610

611611

612-
def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
613-
l1_reg_H=0, l2_reg_W=0, l2_reg_H=0, update_H=True,
614-
verbose=0, shuffle=False, random_state=None):
612+
def _update_coordinate_descent_kl(X, W, Ht, l1_reg, l2_reg, shuffle,
613+
random_state, WH, reset):
614+
epsilon = 0.
615+
max_iter = 1
616+
n_components = Ht.shape[1]
617+
618+
# TODO add regularization
619+
620+
if shuffle:
621+
permutation = random_state.permutation(n_components)
622+
else:
623+
permutation = np.arange(n_components)
624+
# The following seems to be required on 64-bit Windows w/ Python 3.5.
625+
permutation = np.asarray(permutation, dtype=np.intp)
626+
return _update_cdkl_fast(X, W, Ht, WH, permutation, reset,
627+
epsilon, max_iter)
628+
629+
630+
def _fit_coordinate_descent(X, W, H, beta_loss='frobenius', tol=1e-4,
631+
max_iter=200, l1_reg_W=0, l1_reg_H=0, l2_reg_W=0,
632+
l2_reg_H=0, update_H=True, verbose=0,
633+
shuffle=False, random_state=None):
615634
"""Compute Non-negative Matrix Factorization (NMF) with Coordinate Descent
616635
617636
The objective function is minimized with an alternating minimization of W
@@ -680,21 +699,46 @@ def _fit_coordinate_descent(X, W, H, tol=1e-4, max_iter=200, l1_reg_W=0,
680699
"""
681700
# so W and Ht are both in C order in memory
682701
Ht = check_array(H.T, order='C')
683-
X = check_array(X, accept_sparse='csr')
702+
703+
beta_loss = _beta_loss_to_float(beta_loss)
704+
if beta_loss == 2:
705+
X = check_array(X, accept_sparse='csr')
706+
elif beta_loss == 1:
707+
X = check_array(X, dtype=np.float64, accept_sparse=None)
708+
WH = fast_dot(W, Ht.T)
709+
reset = X.mean() * 0.01
710+
else:
711+
raise ValueError("cd solver only accepts beta_loss in ('frobenius', "
712+
"'kullback-leibler', 2, 1). Got %s" % str(beta_loss))
684713

685714
rng = check_random_state(random_state)
686715

687716
for n_iter in range(max_iter):
688717
violation = 0.
689718

690719
# Update W
691-
violation += _update_coordinate_descent(X, W, Ht, l1_reg_W,
692-
l2_reg_W, shuffle, rng)
720+
if beta_loss == 2:
721+
violation += _update_coordinate_descent(X, W, Ht, l1_reg_W,
722+
l2_reg_W,
723+
shuffle, rng)
724+
else:
725+
violation += _update_coordinate_descent_kl(X, W, Ht, l1_reg_W,
726+
l2_reg_W,
727+
shuffle, rng, WH,
728+
reset)
693729
# Update H
694730
if update_H:
695-
violation += _update_coordinate_descent(X.T, Ht, W, l1_reg_H,
696-
l2_reg_H, shuffle, rng)
697-
731+
if beta_loss == 2:
732+
violation += _update_coordinate_descent(X.T, Ht, W,
733+
l1_reg_H,
734+
l2_reg_H,
735+
shuffle, rng)
736+
else:
737+
violation += _update_coordinate_descent_kl(X.T, Ht, W,
738+
l1_reg_H,
739+
l2_reg_H,
740+
shuffle, rng,
741+
WH.T, reset)
698742
if n_iter == 0:
699743
violation_init = violation
700744

@@ -1185,7 +1229,8 @@ def non_negative_factorization(X, W=None, H=None, n_components=None,
11851229
sparseness, beta,
11861230
eta)
11871231
elif solver == 'cd':
1188-
W, H, n_iter = _fit_coordinate_descent(X, W, H, tol, max_iter,
1232+
W, H, n_iter = _fit_coordinate_descent(X, W, H, beta_loss,
1233+
tol, max_iter,
11891234
l1_reg_W, l1_reg_H,
11901235
l2_reg_W, l2_reg_H,
11911236
update_H=update_H,
@@ -1306,7 +1351,8 @@ class NMF(BaseEstimator, TransformerMixin):
13061351
For 0 < l1_ratio < 1, the penalty is a combination of L1 and L2.
13071352
13081353
.. versionadded:: 0.17
1309-
Regularization parameter *l1_ratio* used in the Coordinate Descent solver.
1354+
Regularization parameter *l1_ratio* used in the
1355+
Coordinate Descent solver.
13101356
13111357
shuffle : boolean, default: False
13121358
If true, randomize the order of coordinates in the CD solver.

0 commit comments

Comments
 (0)
0