ENH refactor NMF and add CD solver · scikit-learn/scikit-learn@26896c4 · GitHub

Commit 26896c4

ENH refactor NMF and add CD solver
1 parent 64d3f45 commit 26896c4

File tree

6 files changed: +6627 −318 lines changed
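The "CD solver" in the commit title refers to coordinate descent for NMF: the factorization X ≈ WH with W, H >= 0 is fit by minimizing the squared Frobenius error, updating one block of coordinates of W or H at a time while everything else is held fixed. As a rough illustration of that idea only (a plain-NumPy sketch, not the solver implemented in this commit):

    import numpy as np

    def cd_update_H(X, W, H, n_sweeps=1):
        # One (or more) sweeps of exact coordinate descent on H with W fixed,
        # minimizing 0.5 * ||X - W H||_F**2 subject to H >= 0.
        # Illustrative sketch only -- not the solver added by this commit.
        WtW = np.dot(W.T, W)           # (r, r) Gram matrix
        WtX = np.dot(W.T, X)           # (r, n_features)
        grad = np.dot(WtW, H) - WtX    # gradient of the objective w.r.t. H
        for _ in range(n_sweeps):
            for t in range(H.shape[0]):
                if WtW[t, t] == 0:
                    continue
                # exact minimizer for row t with all other rows held fixed
                new_row = np.maximum(0.0, H[t] - grad[t] / WtW[t, t])
                # rank-1 refresh so later rows see the updated H
                grad += np.outer(WtW[:, t], new_row - H[t])
                H[t] = new_row
        return H

The W factor is handled symmetrically, e.g. by applying the same routine to X.T, H.T and W.T, and the two updates are alternated until the reconstruction error stops improving.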

benchmarks/bench_plot_nmf.py  (+9 −15)

@@ -16,7 +16,7 @@
 from sklearn.externals.six.moves import xrange
 
 
-def alt_nnmf(V, r, max_iter=1000, tol=1e-3, R=None):
+def alt_nnmf(V, r, max_iter=1000, tol=1e-3, init='random'):
     '''
     A, S = nnmf(X, r, tol=1e-3, R=None)
 
@@ -33,8 +33,8 @@ def alt_nnmf(V, r, max_iter=1000, tol=1e-3, R=None):
     tol : double
         tolerance threshold for early exit (when the update factor is within
         tol of 1., the function exits)
-    R : integer, optional
-        random seed
+    init : string
+        Method used to initialize the procedure.
 
     Returns
     -------
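With this change the benchmark forwards an init string straight to _initialize_nmf instead of rolling its own random start. The accepted strings match scikit-learn's public NMF options in released versions (stated from the public API, not from this diff): 'random' draws scaled non-negative random factors, while 'nndsvd', 'nndsvda' and 'nndsvdar' are the SVD-based schemes. A minimal sketch of what the 'random' scheme amounts to, which is essentially what the removed hand-written branch did (an approximation of, not a copy of, _initialize_nmf):

    import numpy as np

    def random_nmf_init(X, n_components, random_state=0):
        # Non-negative random factors scaled to the magnitude of X.
        # Hypothetical helper for illustration; in scikit-learn the real
        # work is done by sklearn.decomposition.nmf._initialize_nmf.
        rng = np.random.RandomState(random_state)
        scale = np.sqrt(X.mean() / n_components)
        W = scale * np.abs(rng.randn(X.shape[0], n_components))
        H = scale * np.abs(rng.randn(n_components, X.shape[1]))
        return W, H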
@@ -52,12 +52,7 @@ def alt_nnmf(V, r, max_iter=1000, tol=1e-3, R=None):
     # Nomenclature in the function follows Lee & Seung
     eps = 1e-5
     n, m = V.shape
-    if R == "svd":
-        W, H = _initialize_nmf(V, r)
-    elif R is None:
-        R = np.random.mtrand._rand
-        W = np.abs(R.standard_normal((n, r)))
-        H = np.abs(R.standard_normal((r, m)))
+    W, H = _initialize_nmf(V, r, init, random_state=0)
 
     for i in xrange(max_iter):
         updateH = np.dot(W.T, V) / (np.dot(np.dot(W.T, W), H) + eps)
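The loop whose first line is visible in the context above is the classic Lee & Seung multiplicative update that alt_nnmf benchmarks against the scikit-learn estimator. Written out in full, one iteration looks roughly like this (a sketch of the textbook rule; eps guards against division by zero exactly as in the benchmark):

    import numpy as np

    def multiplicative_step(V, W, H, eps=1e-5):
        # One Lee & Seung multiplicative update for min ||V - W H||_F
        # with V, W, H all non-negative.
        H = H * (np.dot(W.T, V) / (np.dot(np.dot(W.T, W), H) + eps))
        W = W * (np.dot(V, H.T) / (np.dot(W, np.dot(H, H.T)) + eps))
        return W, H

alt_nnmf repeats this step until the largest multiplicative correction factor is within tol of 1., as described in its docstring.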
@@ -78,17 +73,15 @@ def report(error, time):
 
 
 def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
-    it = 0
     timeset = defaultdict(lambda: [])
     err = defaultdict(lambda: [])
 
-    max_it = len(samples_range) * len(features_range)
     for n_samples in samples_range:
         for n_features in features_range:
             print("%2d samples, %2d features" % (n_samples, n_features))
             print('=======================')
             X = np.abs(make_low_rank_matrix(n_samples, n_features,
-                       effective_rank=rank, tail_strength=0.2))
+                                            effective_rank=rank, tail_strength=0.2))
 
             gc.collect()
             print("benchmarking nndsvd-nmf: ")
@@ -122,7 +115,7 @@ def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
             gc.collect()
             print("benchmarking random-nmf")
             tstart = time()
-            m = NMF(n_components=30, init=None, max_iter=1000,
+            m = NMF(n_components=30, init='random', max_iter=1000,
                     tol=tolerance).fit(X)
             tend = time() - tstart
             timeset['random-nmf'].append(tend)
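For reference, the estimator timed in this branch of the benchmark is used like any other scikit-learn transformer. A minimal usage sketch (solver='cd' selects the coordinate-descent solver; that keyword is taken from released scikit-learn versions that include this refactor, not verified against this exact commit):

    import numpy as np
    from sklearn.decomposition import NMF

    X = np.abs(np.random.RandomState(0).randn(100, 40))  # any non-negative data

    model = NMF(n_components=30, init='random', solver='cd',
                tol=1e-5, max_iter=1000, random_state=0)
    W = model.fit_transform(X)       # (100, 30)
    H = model.components_            # (30, 40)

    error = np.linalg.norm(X - np.dot(W, H))  # same metric the benchmark records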
@@ -132,7 +125,7 @@ def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
             gc.collect()
             print("benchmarking alt-random-nmf")
             tstart = time()
-            W, H = alt_nnmf(X, r=30, R=None, tol=tolerance)
+            W, H = alt_nnmf(X, r=30, init='random', tol=tolerance)
             tend = time() - tstart
             timeset['alt-random-nmf'].append(tend)
             err['alt-random-nmf'].append(np.linalg.norm(X - np.dot(W, H)))
@@ -151,7 +144,8 @@ def benchmark(samples_range, features_range, rank=50, tolerance=1e-5):
     timeset, err = benchmark(samples_range, features_range)
 
     for i, results in enumerate((timeset, err)):
-        fig = plt.figure('scikit-learn Non-Negative Matrix Factorization benchmark results')
+        fig = plt.figure('scikit-learn Non-Negative Matrix Factorization'
+                         'benchmark results')
         ax = fig.gca(projection='3d')
         for c, (label, timings) in zip('rbgcm', sorted(results.iteritems())):
             X, Y = np.meshgrid(samples_range, features_range)
