Commit 9339fc0

REFACTOR: dict_learning and dict_learning_online refactored into a single function
1 parent: 0fb9a50

File tree

2 files changed (+181, -101 lines)

sklearn/decomposition/dict_learning.py

Lines changed: 122 additions & 100 deletions
@@ -309,19 +309,45 @@ def sparse_encode(X, dictionary, gram=None, cov=None, algorithm='lasso_lars',
     return code
 
 
-def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
+def _project_l2(atom, copy=False):
+    """Projects a dictionary atom onto the l2 unit ball.
+    """
+    if copy:
+        atom = atom.copy()
+    atom_norm_squared = np.dot(atom, atom)
+    if atom_norm_squared > 1.:
+        atom /= sqrt(atom_norm_squared)
+    return atom
+
+
+def _update_dict(dictionary, B, A, verbose=False, return_r2=False,
                  random_state=None):
-    """Update the dense dictionary factor in place.
+    """Update the dense dictionary factor in place. It does one pass of
+    block-coordinate descent (BCD) to update the dictionary by minimizing
+
+        .5 * ||Xt - DC||_F^2 = .5 * (tr(DtDA) - 2 * tr(DtB) + tr(XXt))
+
+    as a function of the dictionary (D), where
+
+        B = XtCt, A = CCt, R = B - DA, C = matrix of codes.
+
+    Note that the update of the kth atom is given by (see eqn 10 of the
+    reference paper below):
+
+        R = R + D[:, k]A[k, :]       # rank-1 update
+        D[:, k] = R[:, k] / A[k, k]  # = D[:, k] + R[:, k] / A[k, k]
+        D[:, k] = proj(D[:, k])      # eqn 10
+        R = R - D[:, k]A[k, :]       # rank-1 update
 
     Parameters
     ----------
     dictionary : array of shape (n_features, n_components)
         Value of the dictionary at the previous iteration.
 
-    Y : array of shape (n_features, n_samples)
+    B : array of shape (n_features, n_components)
         Data matrix.
 
-    code : array of shape (n_components, n_samples)
+    A : array of shape (n_components, n_components)
         Sparse coding of the data against which to optimize the dictionary.
 
     verbose:
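
To see what the new `_project_l2` helper does in isolation, here is a minimal standalone sketch (the name `project_l2_ball` and the test values are ours, not part of the commit): atoms whose norm exceeds 1 are rescaled onto the unit sphere, atoms already inside the l2 ball pass through unchanged.

import numpy as np

def project_l2_ball(atom):
    """Project a vector onto the l2 unit ball (||atom||_2 <= 1)."""
    norm = np.linalg.norm(atom)
    if norm > 1.0:
        return atom / norm
    return atom

rng = np.random.RandomState(0)
long_atom = 3.0 * rng.randn(8)   # norm > 1: rescaled onto the unit sphere
short_atom = 0.1 * rng.randn(8)  # norm <= 1: returned unchanged
assert np.isclose(np.linalg.norm(project_l2_ball(long_atom)), 1.0)
assert np.allclose(project_l2_ball(short_atom), short_atom)

Projecting onto the ball (norm <= 1) rather than the sphere (norm == 1) matches the relaxed constraint `|| V_k ||_2 <= 1` that this commit also introduces in the `dict_learning_online` docstring further down.
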
@@ -342,36 +368,39 @@ def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
     dictionary : array of shape (n_features, n_components)
         Updated dictionary.
 
+    Notes
+    -----
+    **References:**
+
+    J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning
+    for sparse coding (http://www.di.ens.fr/sierra/pdfs/icml09.pdf)
+
     """
-    n_components = len(code)
-    n_samples = Y.shape[0]
+    n_features, n_components = B.shape
     random_state = check_random_state(random_state)
     # Residuals, computed 'in-place' for efficiency
-    R = -np.dot(dictionary, code)
-    R += Y
+    R = -np.dot(dictionary, A)
+    R += B
     R = np.asfortranarray(R)
-    ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
+    ger, = linalg.get_blas_funcs(('ger',), (dictionary, A))
     for k in range(n_components):
         # R <- 1.0 * U_k * V_k^T + R
-        R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-        dictionary[:, k] = np.dot(R, code[k, :].T)
+        R = ger(1.0, dictionary[:, k], A[k, :], a=R, overwrite_a=True)
         # Scale k'th atom
-        atom_norm_square = np.dot(dictionary[:, k], dictionary[:, k])
-        if atom_norm_square < 1e-20:
+        if A[k, k] < 1e-20:
             if verbose == 1:
                 sys.stdout.write("+")
                 sys.stdout.flush()
             elif verbose:
                 print("Adding new random atom")
-            dictionary[:, k] = random_state.randn(n_samples)
+            dictionary[:, k] = random_state.randn(n_features)
             # Setting corresponding coefs to 0
-            code[k, :] = 0.0
-            dictionary[:, k] /= sqrt(np.dot(dictionary[:, k],
-                                            dictionary[:, k]))
+            A[k, :] = 0.
         else:
-            dictionary[:, k] /= sqrt(atom_norm_square)
-            # R <- -1.0 * U_k * V_k^T + R
-            R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
+            dictionary[:, k] = R[:, k] / A[k, k]
+            dictionary[:, k] = _project_l2(dictionary[:, k])
+            # R <- -1.0 * U_k * V_k^T + R
+            R = ger(-1.0, dictionary[:, k], A[k, :], a=R, overwrite_a=True)
     if return_r2:
         R **= 2
        # R is fortran-ordered. For numpy version < 1.6, sum does not
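
The loop above leans on the BLAS `ger` routine for its rank-1 residual updates. A small standalone sketch (array sizes are arbitrary, ours) showing that `ger(alpha, u, v, a=R, overwrite_a=True)` computes `R + alpha * outer(u, v)` in place when `R` is Fortran-ordered:

import numpy as np
from scipy import linalg

rng = np.random.RandomState(0)
u = rng.randn(5)
v = rng.randn(3)
R = np.asfortranarray(rng.randn(5, 3))  # Fortran order allows a true in-place update

ger, = linalg.get_blas_funcs(('ger',), (R,))
expected = R + 1.0 * np.outer(u, v)     # reference result, computed out of place
R = ger(1.0, u, v, a=R, overwrite_a=True)
assert np.allclose(R, expected)
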
@@ -434,7 +463,7 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
     code_init : array of shape (n_samples, n_components),
         Initial value for the sparse code for warm restart scenarios.
 
-    callback :
+    callback : optional (default=None)
         Callable that gets invoked every five iterations.
 
     verbose :
@@ -472,20 +501,7 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
     SparsePCA
     MiniBatchSparsePCA
     """
-    if method not in ('lars', 'cd'):
-        raise ValueError('Coding method %r not supported as a fit algorithm.'
-                         % method)
-    method = 'lasso_' + method
-
-    t0 = time.time()
-    # Avoid integer division problems
-    alpha = float(alpha)
-    random_state = check_random_state(random_state)
-
-    if n_jobs == -1:
-        n_jobs = cpu_count()
-
-    # Init the code and the dictionary with SVD of Y
+    # Init the code and the dictionary with SVD of X
     if code_init is not None and dict_init is not None:
         code = np.array(code_init, order='F')
         # Don't copy V, it will happen below
@@ -494,67 +510,57 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
         code, S, dictionary = linalg.svd(X, full_matrices=False)
         dictionary = S[:, np.newaxis] * dictionary
     r = len(dictionary)
-    if n_components <= r:  # True even if n_components=None
+    if n_components <= r:  # True even if n_components is None
         code = code[:, :n_components]
         dictionary = dictionary[:n_components, :]
     else:
         code = np.c_[code, np.zeros((len(code), n_components - r))]
         dictionary = np.r_[dictionary,
                            np.zeros((n_components - r, dictionary.shape[1]))]
 
-    # Fortran-order dict, as we are going to access its row vectors
-    dictionary = np.array(dictionary, order='F')
-
-    residuals = 0
-
     errors = []
-    current_cost = np.nan
-
-    if verbose == 1:
-        print('[dict_learning]', end=' ')
-
-    # If max_iter is 0, number of iterations returned should be zero
-    ii = -1
-
-    for ii in range(max_iter):
-        dt = (time.time() - t0)
-        if verbose == 1:
-            sys.stdout.write(".")
-            sys.stdout.flush()
-        elif verbose:
-            print("Iteration % 3i "
-                  "(elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)"
-                  % (ii, dt, dt / 60, current_cost))
-
-        # Update code
-        code = sparse_encode(X, dictionary, algorithm=method, alpha=alpha,
-                             init=code, n_jobs=n_jobs)
-        # Update dictionary
-        dictionary, residuals = _update_dict(dictionary.T, X.T, code.T,
-                                             verbose=verbose, return_r2=True,
-                                             random_state=random_state)
-        dictionary = dictionary.T
 
-        # Cost function
-        current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code))
-        errors.append(current_cost)
-
-        if ii > 0:
+    def _callback(env):
+        """Callback for checking convergence.
+        """
+        residuals = env["dictionary"].dot(env["this_code"])
+        residuals -= env["this_X"].T
+        residuals **= 2
+        residuals = np.sum(residuals)
+        err = .5 * residuals + alpha * np.sum(np.abs(env["this_code"]))
+        errors.append(err)
+        if len(errors) > 1:
             dE = errors[-2] - errors[-1]
             # assert(dE >= -tol * errors[-1])
-            if dE < tol * errors[-1]:
+            if np.abs(dE) < tol * errors[-1]:
                 if verbose == 1:
                     # A line return
                     print("")
                 elif verbose:
-                    print("--- Convergence reached after %d iterations" % ii)
-                break
-        if ii % 5 == 0 and callback is not None:
-            callback(locals())
-
+                    print(
+                        "--- Convergence reached after %d iterations" % (
+                            env["ii"]))
+                return False
+        return True
+
+    # call the unified dict-learning API in batch mode
+    if min(max_iter, len(X)) < 1:
+        # nothing to do
+        if return_n_iter:
+            return code, dictionary, errors, 0
+        else:
+            return code, dictionary, errors
+    out = dict_learning_online(
+        X, n_components=n_components, alpha=alpha, n_iter=max_iter,
+        return_code=True, dict_init=dictionary, callback=_callback,
+        batch_size=len(X), verbose=verbose, shuffle=False, n_jobs=n_jobs,
+        method=method, random_state=random_state,
+        return_n_iter=return_n_iter)
     if return_n_iter:
-        return code, dictionary, errors, ii + 1
+        code, dictionary, n_iter = out
+        return code, dictionary, errors, n_iter
     else:
+        code, dictionary = out
         return code, dictionary, errors
 
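The control flow of the refactored `dict_learning` is easy to miss in the hunk above: all iteration now happens inside `dict_learning_online`, and `_callback` both records the cost and signals convergence by returning `False`. A toy sketch of that contract (the driver and the cost sequence are ours, standing in for the real loop):

errors = []

def make_convergence_callback(tol):
    def _callback(env):
        errors.append(env["err"])
        if len(errors) > 1:
            dE = errors[-2] - errors[-1]
            if abs(dE) < tol * errors[-1]:  # relative improvement below tol
                return False                # tell the driver to stop
        return True
    return _callback

def driver(n_iter, callback):
    for ii in range(n_iter):
        err = 1.0 + 100.0 * 0.5 ** ii      # stand-in for the true cost sequence
        if callback is not None and not callback(dict(err=err, ii=ii)):
            return ii + 1                  # converged early
    return n_iter

print(driver(1000, make_convergence_callback(tol=1e-8)))  # stops after ~35 iterations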

@@ -571,12 +577,19 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
 
         (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1
                      (U,V)
-                     with || V_k ||_2 = 1 for all 0 <= k < n_components
+                     with || V_k ||_2 <= 1 for all 0 <= k < n_components
 
     where V is the dictionary and U is the sparse code. This is
     accomplished by repeatedly iterating over mini-batches by slicing
     the input data.
 
+    This function has two modes:
+
+    1. Batch mode, activated when batch_size is None or
+       batch_size >= n_samples. This is the default.
+
+    2. Online mode, activated when batch_size < n_samples.
+
     Read more in the :ref:`User Guide <DictionaryLearning>`.
 
     Parameters
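
Concretely, the mode is selected by `batch_size` alone. A usage sketch, assuming the refactored semantics of this commit (shapes follow the docstring: `code` is (n_samples, n_components), the dictionary is (n_components, n_features)):

import numpy as np
from sklearn.decomposition import dict_learning_online

rng = np.random.RandomState(0)
X = rng.randn(100, 20)  # n_samples=100, n_features=20

# Batch mode: batch_size >= n_samples, so every iteration sees all of X.
code_b, D_b = dict_learning_online(X, n_components=10, alpha=1.,
                                   batch_size=len(X), random_state=0)

# Online mode: iterations cycle over mini-batches of 10 samples.
code_o, D_o = dict_learning_online(X, n_components=10, alpha=1.,
                                   batch_size=10, random_state=0)

print(code_o.shape, D_o.shape)  # (100, 10) (10, 20)
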
@@ -600,7 +613,8 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
         Initial value for the dictionary for warm restart scenarios.
 
     callback :
-        Callable that gets invoked every five iterations.
+        Callable that gets invoked every five iterations. If it returns
+        a non-True value, the main loop (iteration over the data) is aborted.
 
     batch_size : int,
         The number of samples to take in each batch.
@@ -668,15 +682,18 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
     MiniBatchSparsePCA
 
     """
+    n_samples, n_features = X.shape
     if n_components is None:
-        n_components = X.shape[1]
+        n_components = n_features
+    if batch_size is None:
+        batch_size = n_samples
+    online = batch_size < n_samples
 
     if method not in ('lars', 'cd'):
         raise ValueError('Coding method not supported as a fit algorithm.')
     method = 'lasso_' + method
 
     t0 = time.time()
-    n_samples, n_features = X.shape
     # Avoid integer division problems
     alpha = float(alpha)
     random_state = check_random_state(random_state)
@@ -688,8 +705,11 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
     if dict_init is not None:
         dictionary = dict_init
     else:
-        _, S, dictionary = randomized_svd(X, n_components,
-                                          random_state=random_state)
+        if online:
+            code, S, dictionary = randomized_svd(X, n_components,
+                                                 random_state=random_state)
+        else:
+            code, S, dictionary = linalg.svd(X, full_matrices=False)
         dictionary = S[:, np.newaxis] * dictionary
     r = len(dictionary)
     if n_components <= r:
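
The initialization branch above can be exercised on its own: an exact thin SVD when every batch is the full data set, `randomized_svd` (truncated, cheaper for large X) when cycling over mini-batches. A quick sketch (matrix sizes are arbitrary, ours):

import numpy as np
from scipy import linalg
from sklearn.utils.extmath import randomized_svd

rng = np.random.RandomState(0)
X = rng.randn(50, 30)
n_components = 5

# Batch mode: exact (thin) SVD of the full data matrix.
code, S, dictionary = linalg.svd(X, full_matrices=False)
D_exact = (S[:, np.newaxis] * dictionary)[:n_components]

# Online mode: truncated randomized SVD.
code, S, dictionary = randomized_svd(X, n_components, random_state=rng)
D_approx = S[:, np.newaxis] * dictionary

print(D_exact.shape, D_approx.shape)  # (5, 30) (5, 30)
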
@@ -741,26 +761,28 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                                   alpha=alpha, n_jobs=n_jobs).T
 
         # Update the auxiliary variables
-        if ii < batch_size - 1:
-            theta = float((ii + 1) * batch_size)
+        if online:
+            if ii < batch_size - 1:
+                theta = float((ii + 1) * batch_size)
+            else:
+                theta = float(batch_size ** 2 + ii + 1 - batch_size)
+            beta = (theta + 1 - batch_size) / (theta + 1)
+            A *= beta
+            A += np.dot(this_code, this_code.T)
+            B *= beta
+            B += np.dot(this_X.T, this_code.T)
         else:
-            theta = float(batch_size ** 2 + ii + 1 - batch_size)
-        beta = (theta + 1 - batch_size) / (theta + 1)
-
-        A *= beta
-        A += np.dot(this_code, this_code.T)
-        B *= beta
-        B += np.dot(this_X.T, this_code.T)
+            A = np.dot(this_code, this_code.T)
+            B = np.dot(this_X.T, this_code.T)
 
         # Update dictionary
         dictionary = _update_dict(dictionary, B, A, verbose=verbose,
-                                  random_state=random_state)
-        # XXX: Can the residuals be of any use?
+                                  random_state=random_state,
+                                  return_r2=False)
 
-        # Maybe we need a stopping criteria based on the amount of
-        # modification in the dictionary
-        if callback is not None:
-            callback(locals())
+        # Check convergence
+        if callback is not None and not callback(locals()):
+            break
 
     if return_inner_stats:
         if return_n_iter:
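The `online` branch above maintains the sufficient statistics A = C C^t and B = X^t C^t as exponentially weighted running sums, with the forgetting factor `beta` from the Mairal et al. (2009) paper cited in `_update_dict`. A standalone sketch of just that bookkeeping (variable names follow the diff; random codes stand in for real sparse codes):

import numpy as np

rng = np.random.RandomState(0)
n_features, n_components, batch_size = 8, 4, 5

A = np.zeros((n_components, n_components))  # running code Gram matrix, C C^t
B = np.zeros((n_features, n_components))    # running cross-product, X^t C^t

for ii in range(10):
    this_X = rng.randn(batch_size, n_features)       # one mini-batch
    this_code = rng.randn(n_components, batch_size)  # stand-in for sparse codes

    if ii < batch_size - 1:
        theta = float((ii + 1) * batch_size)
    else:
        theta = float(batch_size ** 2 + ii + 1 - batch_size)
    beta = (theta + 1 - batch_size) / (theta + 1)    # forgetting factor in (0, 1)

    A = beta * A + np.dot(this_code, this_code.T)
    B = beta * B + np.dot(this_X.T, this_code.T)

In batch mode every iteration already sees the whole data set, so the history is redundant and A and B are simply recomputed from scratch, which is what the `else` branch in the hunk does.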
