Commit 4d6cab8 · scikit-learn/scikit-learn
REFACTOR: dict_learning and dict_learning_online refactored into a single function
1 parent 425fc14 commit 4d6cab8

File tree: 2 files changed, +96 −123 lines

sklearn/decomposition/dict_learning.py

Lines changed: 93 additions & 121 deletions
@@ -309,19 +309,28 @@ def sparse_encode(X, dictionary, gram=None, cov=None, algorithm='lasso_lars',
     return code


-def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
-                 random_state=None):
+def _proj_l2(v):
+    """Project v onto the l2 unit ball, in-place.
+    """
+    vv = np.dot(v, v)
+    if vv > 1.:
+        v /= sqrt(vv)
+    return v
+
+
+def _update_dict(dictionary, B, A, verbose=False, return_r2=False,
+                 random_state=None, online=False):
     """Update the dense dictionary factor in place.

     Parameters
     ----------
     dictionary : array of shape (n_features, n_components)
         Value of the dictionary at the previous iteration.

-    Y : array of shape (n_features, n_samples)
+    B : array of shape (n_features, n_components)
         Data matrix.

-    code : array of shape (n_components, n_samples)
+    A : array of shape (n_components, n_components)
         Sparse coding of the data against which to optimize the dictionary.

     verbose:
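Note on the new signature: B and A are the accumulated sufficient statistics of online dictionary learning (Mairal et al., 2009), not the raw data and codes; the unchanged description lines above ("Data matrix.", "Sparse coding of the data...") date from the old Y/code parameters. A minimal sketch of how the statistics are formed, with shapes as in the new docstring (variable names follow the dict_learning_online loop further down):

import numpy as np

rng = np.random.RandomState(0)
n_batch, n_features, n_components = 16, 8, 4

this_X = rng.randn(n_batch, n_features)       # one mini-batch of samples
this_code = rng.randn(n_components, n_batch)  # its sparse codes, one column per sample

# Statistics accumulated across batches and handed to _update_dict:
A = np.dot(this_code, this_code.T)   # (n_components, n_components), code Gram matrix
B = np.dot(this_X.T, this_code.T)    # (n_features, n_components), data-code product

# _update_dict then works on the statistic-space residual
# R = B - dictionary.dot(A) instead of the data-space residual
# X.T - dictionary.dot(code) used by the old batch version.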
@@ -343,35 +352,32 @@ def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
         Updated dictionary.

     """
-    n_components = len(code)
-    n_samples = Y.shape[0]
+    n_features, n_components = B.shape
     random_state = check_random_state(random_state)
     # Residuals, computed 'in-place' for efficiency
-    R = -np.dot(dictionary, code)
-    R += Y
+    R = -np.dot(dictionary, A)
+    R += B
     R = np.asfortranarray(R)
-    ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
+    ger, = linalg.get_blas_funcs(('ger',), (dictionary, A))
     for k in range(n_components):
         # R <- 1.0 * U_k * V_k^T + R
-        R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-        dictionary[:, k] = np.dot(R, code[k, :].T)
+        R = ger(1.0, dictionary[:, k], A[k, :], a=R, overwrite_a=True)
+        dictionary[:, k] = np.dot(R, A[k, :])
         # Scale k'th atom
-        atom_norm_square = np.dot(dictionary[:, k], dictionary[:, k])
-        if atom_norm_square < 1e-20:
+        if A[k, k] < 1e-20:
             if verbose == 1:
                 sys.stdout.write("+")
                 sys.stdout.flush()
             elif verbose:
                 print("Adding new random atom")
-            dictionary[:, k] = random_state.randn(n_samples)
+            dictionary[:, k] = random_state.randn(n_features)
             # Setting corresponding coefs to 0
-            code[k, :] = 0.0
-            dictionary[:, k] /= sqrt(np.dot(dictionary[:, k],
-                                            dictionary[:, k]))
+            A[k, :] = 0.
         else:
-            dictionary[:, k] /= sqrt(atom_norm_square)
-            # R <- -1.0 * U_k * V_k^T + R
-            R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
+            dictionary[:, k] /= A[k, k]
+            _proj_l2(dictionary[:, k])
+            # R <- -1.0 * U_k * V_k^T + R
+            R = ger(-1.0, dictionary[:, k], A[k, :], a=R, overwrite_a=True)
     if return_r2:
         R **= 2
         # R is fortran-ordered. For numpy version < 1.6, sum does not
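_proj_l2 projects an atom onto the l2 unit ball rather than normalizing it to exactly unit norm, matching the ||d_k||_2 <= 1 constraint of the online algorithm. Restated standalone with a small usage check:

from math import sqrt
import numpy as np

def _proj_l2(v):
    """Project v onto the l2 unit ball, in-place."""
    vv = np.dot(v, v)
    if vv > 1.:
        v /= sqrt(vv)
    return v

print(_proj_l2(np.array([3., 4.])))    # norm 5   -> rescaled to [0.6, 0.8]
print(_proj_l2(np.array([0.3, 0.4])))  # norm 0.5 -> left as [0.3, 0.4]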
@@ -472,98 +478,39 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
     SparsePCA
     MiniBatchSparsePCA
     """
-    if method not in ('lars', 'cd'):
-        raise ValueError('Coding method %r not supported as a fit algorithm.'
-                         % method)
-    method = 'lasso_' + method
+    return dict_learning_online(
+        X, n_components=n_components, alpha=alpha, n_iter=max_iter,
+        return_code=True, dict_init=dict_init, callback=callback,
+        batch_size=len(X), verbose=verbose, shuffle=False,
+        return_n_iter=return_n_iter, n_jobs=n_jobs, method=method,
+        return_inner_stats=False, tol=tol)

-    t0 = time.time()
-    # Avoid integer division problems
-    alpha = float(alpha)
-    random_state = check_random_state(random_state)

-    if n_jobs == -1:
-        n_jobs = cpu_count()
+def _compute_residuals_from_code(X, V, U):
+    """Computes ||X.T - VU||_F^2 directly.

-    # Init the code and the dictionary with SVD of Y
-    if code_init is not None and dict_init is not None:
-        code = np.array(code_init, order='F')
-        # Don't copy V, it will happen below
-        dictionary = dict_init
-    else:
-        code, S, dictionary = linalg.svd(X, full_matrices=False)
-        dictionary = S[:, np.newaxis] * dictionary
-    r = len(dictionary)
-    if n_components <= r:  # True even if n_components=None
-        code = code[:, :n_components]
-        dictionary = dictionary[:n_components, :]
-    else:
-        code = np.c_[code, np.zeros((len(code), n_components - r))]
-        dictionary = np.r_[dictionary,
-                           np.zeros((n_components - r, dictionary.shape[1]))]
-
-    # Fortran-order dict, as we are going to access its row vectors
-    dictionary = np.array(dictionary, order='F')
-
-    residuals = 0
-
-    errors = []
-    current_cost = np.nan
-
-    if verbose == 1:
-        print('[dict_learning]', end=' ')
-
-    # If max_iter is 0, number of iterations returned should be zero
-    ii = -1
-
-    for ii in range(max_iter):
-        dt = (time.time() - t0)
-        if verbose == 1:
-            sys.stdout.write(".")
-            sys.stdout.flush()
-        elif verbose:
-            print("Iteration % 3i "
-                  "(elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)"
-                  % (ii, dt, dt / 60, current_cost))
-
-        # Update code
-        code = sparse_encode(X, dictionary, algorithm=method, alpha=alpha,
-                             init=code, n_jobs=n_jobs)
-        # Update dictionary
-        dictionary, residuals = _update_dict(dictionary.T, X.T, code.T,
-                                             verbose=verbose, return_r2=True,
-                                             random_state=random_state)
-        dictionary = dictionary.T
-
-        # Cost function
-        current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code))
-        errors.append(current_cost)
-
-        if ii > 0:
-            dE = errors[-2] - errors[-1]
-            # assert(dE >= -tol * errors[-1])
-            if dE < tol * errors[-1]:
-                if verbose == 1:
-                    # A line return
-                    print("")
-                elif verbose:
-                    print("--- Convergence reached after %d iterations" % ii)
-                break
-        if ii % 5 == 0 and callback is not None:
-            callback(locals())
-
-    if return_n_iter:
-        return code, dictionary, errors, ii + 1
-    else:
-        return code, dictionary, errors
+    Parameters
+    ----------
+    X : ndarray, shape (n_samples, n_features)
+        The input data.
+    V : ndarray, shape (n_features, n_components)
+        The dictionary.
+    U : ndarray, shape (n_components, n_samples)
+        The codes.
+    """
+    residuals = V.dot(U)
+    residuals -= X.T
+    residuals **= 2
+    residuals = np.sum(residuals)
+    return residuals


 def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                          return_code=True, dict_init=None, callback=None,
-                         batch_size=3, verbose=False, shuffle=True, n_jobs=1,
-                         method='lars', iter_offset=0, random_state=None,
-                         return_inner_stats=False, inner_stats=None,
-                         return_n_iter=False):
+                         batch_size=None, verbose=False, shuffle=True,
+                         n_jobs=1, method='lars', iter_offset=0, tol=0.,
+                         random_state=None, return_inner_stats=False,
+                         inner_stats=None, return_n_iter=False):
     """Solves a dictionary learning matrix factorization problem online.

     Finds the best dictionary and the corresponding sparse code for
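dict_learning thus reduces to a single call: one batch of size len(X) with shuffle=False makes dict_learning_online behave as full-batch alternating minimization. _compute_residuals_from_code supplies the residual term of the cost previously obtained via _update_dict(..., return_r2=True). Written out (a sketch, with V and U oriented as at the call site, where the codes arrive transposed):

import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(30, 10)   # (n_samples, n_features)
V = rng.randn(10, 5)    # dictionary, (n_features, n_components)
U = rng.randn(5, 30)    # codes, (n_components, n_samples)

residuals = np.sum((V.dot(U) - X.T) ** 2)          # ||X.T - VU||_F^2
alpha = 1.0
err = .5 * residuals + alpha * np.sum(np.abs(U))   # cost appended to `errors`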
@@ -711,6 +658,9 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                              copy=False)
     X_train = check_array(X_train, order='C', dtype=np.float64, copy=False)

+    if batch_size is None:
+        batch_size = n_samples
+    online = batch_size < n_samples
     batches = gen_batches(n_samples, batch_size)
     batches = itertools.cycle(batches)

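The default batch_size=None now means a single batch spanning all of X, so online is False and the updates below reduce to the batch algorithm; passing a smaller batch_size restores the previous online behavior. A sketch of the dispatch:

n_samples = 100

for batch_size in (None, 3):
    if batch_size is None:
        batch_size = n_samples
    online = batch_size < n_samples
    print(batch_size, online)   # 100 False (full batch), then 3 True (online)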
@@ -726,6 +676,8 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
     # If n_iter is zero, we need to return zero.
     ii = iter_offset - 1

+    err = 0.
+    errors = []
     for ii, batch in zip(range(iter_offset, iter_offset + n_iter), batches):
         this_X = X_train[batch]
         dt = (time.time() - t0)
@@ -741,26 +693,46 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                                   alpha=alpha, n_jobs=n_jobs).T

         # Update the auxiliary variables
-        if ii < batch_size - 1:
-            theta = float((ii + 1) * batch_size)
+        if online:
+            if ii < batch_size - 1:
+                theta = float((ii + 1) * batch_size)
+            else:
+                theta = float(batch_size ** 2 + ii + 1 - batch_size)
+            beta = (theta + 1 - batch_size) / (theta + 1)
         else:
-            theta = float(batch_size ** 2 + ii + 1 - batch_size)
-        beta = (theta + 1 - batch_size) / (theta + 1)
-
+            beta = 0.
         A *= beta
         A += np.dot(this_code, this_code.T)
         B *= beta
         B += np.dot(this_X.T, this_code.T)

         # Update dictionary
         dictionary = _update_dict(dictionary, B, A, verbose=verbose,
-                                  random_state=random_state)
-        # XXX: Can the residuals be of any use?
-
-        # Maybe we need a stopping criteria based on the amount of
-        # modification in the dictionary
-        if callback is not None:
-            callback(locals())
+                                  random_state=random_state, online=True,
+                                  return_r2=False)
+
+        # Check convergence
+        if not online and callback is None:
+            residuals = _compute_residuals_from_code(this_X, dictionary,
+                                                     this_code)
+            err = .5 * residuals + alpha * np.sum(np.abs(this_code))
+            errors.append(err)
+            if len(errors) > 1:
+                dE = errors[-2] - errors[-1]
+                # assert(dE >= -tol * errors[-1])
+                if np.abs(dE) < tol * errors[-1]:
+                    if verbose == 1:
+                        # A line return
+                        print("")
+                    elif verbose:
+                        print(
+                            "--- Convergence reached after %d iterations" % ii)
+                    break
+        elif callback is not None:
+            # Maybe we need a stopping criteria based on the amount of
+            # modification in the dictionary
+            if not callback(locals()):
+                break

     if return_inner_stats:
         if return_n_iter:
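The forgetting factor beta discounts the accumulated statistics before each batch is absorbed (A *= beta; B *= beta); in batch mode beta = 0., so A and B are rebuilt from the full batch at every iteration and the update is exactly the batch one. A sketch of the schedule, with the formulas copied from the hunk above:

def forgetting_factor(ii, batch_size, online):
    if not online:
        return 0.       # batch mode: statistics rebuilt from scratch
    if ii < batch_size - 1:
        theta = float((ii + 1) * batch_size)
    else:
        theta = float(batch_size ** 2 + ii + 1 - batch_size)
    return (theta + 1 - batch_size) / (theta + 1)

# Early batches are forgotten quickly; beta grows toward 1 over time:
print([round(forgetting_factor(ii, 3, True), 3) for ii in (0, 1, 5, 50)])
# [0.25, 0.571, 0.769, 0.948]

Note also that callback now doubles as a stopping criterion: the loop breaks when it returns a falsy value, whereas its return value was previously ignored.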
@@ -778,14 +750,14 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
             dt = (time.time() - t0)
             print('done (total time: % 3is, % 4.1fmn)' % (dt, dt / 60))
         if return_n_iter:
-            return code, dictionary.T, ii - iter_offset + 1
+            return code, dictionary.T, errors, ii - iter_offset + 1
         else:
-            return code, dictionary.T
+            return code, dictionary.T, errors

     if return_n_iter:
-        return dictionary.T, ii - iter_offset + 1
+        return dictionary.T, errors, ii - iter_offset + 1
     else:
-        return dictionary.T
+        return dictionary.T, errors


 class SparseCodingMixin(TransformerMixin):
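Both public functions now return the errors history as an extra element, so callers unpack one more value (the list stays empty in online mode or when a callback is supplied, since the cost is only computed in the convergence branch). A minimal usage sketch, assuming this branch of scikit-learn is installed:

import numpy as np
from sklearn.decomposition import dict_learning_online

X = np.random.RandomState(0).randn(20, 10)
code, dictionary, errors = dict_learning_online(X, n_components=5, alpha=1,
                                                random_state=0)
print(code.shape, dictionary.shape, len(errors))   # (20, 5) (5, 10) ...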

sklearn/decomposition/tests/test_dict_learning.py

Lines changed: 3 additions & 2 deletions
@@ -13,6 +13,7 @@
 from sklearn.utils.testing import assert_raises
 from sklearn.utils.testing import ignore_warnings
 from sklearn.utils.testing import TempMemmap
+from sklearn.utils.testing import assert_almost_equal

 from sklearn.decomposition import DictionaryLearning
 from sklearn.decomposition import MiniBatchDictionaryLearning
@@ -128,8 +129,8 @@ def test_dict_learning_split():
 def test_dict_learning_online_shapes():
     rng = np.random.RandomState(0)
     n_components = 8
-    code, dictionary = dict_learning_online(X, n_components=n_components,
-                                            alpha=1, random_state=rng)
+    code, dictionary, _ = dict_learning_online(X, n_components=n_components,
+                                               alpha=1, random_state=rng)
     assert_equal(code.shape, (n_samples, n_components))
     assert_equal(dictionary.shape, (n_components, n_features))
     assert_equal(np.dot(code, dictionary).shape, X.shape)
