Commit 584b955

REFACTOR: dict_learning and dict_learning_online refactored into a single function
1 parent 9525461 commit 584b955

1 file changed: sklearn/decomposition/dict_learning.py (+113, -116 lines)
@@ -309,19 +309,45 @@ def sparse_encode(X, dictionary, gram=None, cov=None, algorithm='lasso_lars',
     return code
 
 
-def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
-                 random_state=None):
-    """Update the dense dictionary factor in place.
+def _project_l2(v, copy=False):
+    """Projects v onto the l2 unit ball."""
+    if copy:
+        v = v.copy()
+    vv = np.dot(v, v)
+    if vv > 1.:
+        v /= sqrt(vv)
+    return v
+
+
+def _update_dict(dictionary, B, A, verbose=False, return_r2=False,
+                 random_state=None, online=False):
+    """Update the dense dictionary factor in place. Does one pass of BCD
+    to minimize
+
+        .5 * ||Xt - DC||_F^2 = .5 * (tr(DtDA) - 2 * tr(DtB) + tr(XXt))
+
+    as a function of the dictionary (D), where
+
+        B = XtCt, A = CCt, R = B - DA, C = matrix of codes.
+
+    Note that the update of the kth atom is given by (see eqn 10 of the
+    reference paper below):
+
+        R = R + D[:, k]A[k, :]         # rank-1 update
+        D[:, k] = R[:, k] / A[k, k]    # = D[:, k] + R[:, k] / A[k, k]
+        D[:, k] = proj(D[:, k])        # eqn 10
+        R = R - D[:, k]A[k, :]         # rank-1 update
 
     Parameters
     ----------
     dictionary : array of shape (n_features, n_components)
         Value of the dictionary at the previous iteration.
 
-    Y : array of shape (n_features, n_samples)
+    B : array of shape (n_features, n_components)
         Data matrix.
 
-    code : array of shape (n_components, n_samples)
+    A : array of shape (n_components, n_components)
         Sparse coding of the data against which to optimize the dictionary.
 
     verbose:
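To make the new atom update concrete, here is a minimal standalone NumPy sketch of the projected block-coordinate step the docstring describes (a toy re-implementation for illustration only, not the library code; `eps` mirrors the 1e-20 degeneracy guard used further down):

    import numpy as np

    def update_dict_bcd(D, B, A, eps=1e-20):
        """One BCD pass over atoms: D[:, k] <- proj_l2(R[:, k] / A[k, k])."""
        R = B - D.dot(A)                      # residual statistic R = B - DA
        for k in range(D.shape[1]):
            R += np.outer(D[:, k], A[k, :])   # rank-1 update: remove atom k
            if A[k, k] > eps:
                D[:, k] = R[:, k] / A[k, k]
                norm_sq = D[:, k].dot(D[:, k])
                if norm_sq > 1.:              # project onto the l2 unit ball
                    D[:, k] /= np.sqrt(norm_sq)
            R -= np.outer(D[:, k], A[k, :])   # rank-1 update: put atom k back
        return D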
@@ -342,36 +368,39 @@ def _update_dict(dictionary, Y, code, verbose=False, return_r2=False,
     dictionary : array of shape (n_features, n_components)
         Updated dictionary.
 
+    Notes
+    -----
+    **References:**
+
+    J. Mairal, F. Bach, J. Ponce, G. Sapiro, 2009: Online dictionary learning
+    for sparse coding (http://www.di.ens.fr/sierra/pdfs/icml09.pdf)
+
     """
-    n_components = len(code)
-    n_samples = Y.shape[0]
+    n_features, n_components = B.shape
     random_state = check_random_state(random_state)
     # Residuals, computed 'in-place' for efficiency
-    R = -np.dot(dictionary, code)
-    R += Y
+    R = -np.dot(dictionary, A)
+    R += B
     R = np.asfortranarray(R)
-    ger, = linalg.get_blas_funcs(('ger',), (dictionary, code))
+    ger, = linalg.get_blas_funcs(('ger',), (dictionary, A))
     for k in range(n_components):
         # R <- 1.0 * U_k * V_k^T + R
-        R = ger(1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
-        dictionary[:, k] = np.dot(R, code[k, :].T)
+        R = ger(1.0, dictionary[:, k], A[k, :], a=R, overwrite_a=True)
         # Scale k'th atom
-        atom_norm_square = np.dot(dictionary[:, k], dictionary[:, k])
-        if atom_norm_square < 1e-20:
+        if A[k, k] < 1e-20:
             if verbose == 1:
                 sys.stdout.write("+")
                 sys.stdout.flush()
             elif verbose:
                 print("Adding new random atom")
-            dictionary[:, k] = random_state.randn(n_samples)
+            dictionary[:, k] = random_state.randn(n_features)
             # Setting corresponding coefs to 0
-            code[k, :] = 0.0
-            dictionary[:, k] /= sqrt(np.dot(dictionary[:, k],
-                                            dictionary[:, k]))
+            A[k, :] = 0.
         else:
-            dictionary[:, k] /= sqrt(atom_norm_square)
-            # R <- -1.0 * U_k * V_k^T + R
-            R = ger(-1.0, dictionary[:, k], code[k, :], a=R, overwrite_a=True)
+            dictionary[:, k] = R[:, k] / A[k, k]
+            dictionary[:, k] = _project_l2(dictionary[:, k])
+            # R <- -1.0 * U_k * V_k^T + R
+            R = ger(-1.0, dictionary[:, k], A[k, :], a=R, overwrite_a=True)
     if return_r2:
         R **= 2
         # R is fortran-ordered. For numpy version < 1.6, sum does not
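The in-place rank-1 updates above go through the BLAS `ger` routine rather than `np.outer`, so no temporary matrix is allocated. A quick illustrative check of the call pattern (not part of the commit): `ger(alpha, x, y, a=R, overwrite_a=True)` adds `alpha * x y^T` to a Fortran-ordered array in place:

    import numpy as np
    from scipy import linalg

    rng = np.random.RandomState(0)
    R = np.asfortranarray(rng.randn(4, 3))  # Fortran order enables in-place update
    x, y = rng.randn(4), rng.randn(3)

    ger, = linalg.get_blas_funcs(('ger',), (R,))
    expected = R + 2.0 * np.outer(x, y)     # R <- 2.0 * x y^T + R
    R = ger(2.0, x, y, a=R, overwrite_a=True)
    assert np.allclose(R, expected)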
@@ -434,7 +463,7 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
     code_init : array of shape (n_samples, n_components),
         Initial value for the sparse code for warm restart scenarios.
 
-    callback :
+    callback : optional (default=None)
         Callable that gets invoked every five iterations.
 
     verbose :
@@ -472,111 +501,71 @@ def dict_learning(X, n_components, alpha, max_iter=100, tol=1e-8,
     SparsePCA
     MiniBatchSparsePCA
     """
-    if method not in ('lars', 'cd'):
-        raise ValueError('Coding method %r not supported as a fit algorithm.'
-                         % method)
-    method = 'lasso_' + method
-
-    t0 = time.time()
-    # Avoid integer division problems
-    alpha = float(alpha)
-    random_state = check_random_state(random_state)
-
-    if n_jobs == -1:
-        n_jobs = cpu_count()
-
-    # Init the code and the dictionary with SVD of Y
-    if code_init is not None and dict_init is not None:
-        code = np.array(code_init, order='F')
-        # Don't copy V, it will happen below
-        dictionary = dict_init
-    else:
-        code, S, dictionary = linalg.svd(X, full_matrices=False)
-        dictionary = S[:, np.newaxis] * dictionary
-    r = len(dictionary)
-    if n_components <= r:  # True even if n_components=None
-        code = code[:, :n_components]
-        dictionary = dictionary[:n_components, :]
-    else:
-        code = np.c_[code, np.zeros((len(code), n_components - r))]
-        dictionary = np.r_[dictionary,
-                           np.zeros((n_components - r, dictionary.shape[1]))]
-
-    # Fortran-order dict, as we are going to access its row vectors
-    dictionary = np.array(dictionary, order='F')
-
-    residuals = 0
-
     errors = []
-    current_cost = np.nan
 
-    if verbose == 1:
-        print('[dict_learning]', end=' ')
-
-    # If max_iter is 0, number of iterations returned should be zero
-    ii = -1
-
-    for ii in range(max_iter):
-        dt = (time.time() - t0)
-        if verbose == 1:
-            sys.stdout.write(".")
-            sys.stdout.flush()
-        elif verbose:
-            print("Iteration % 3i "
-                  "(elapsed time: % 3is, % 4.1fmn, current cost % 7.3f)"
-                  % (ii, dt, dt / 60, current_cost))
-
-        # Update code
-        code = sparse_encode(X, dictionary, algorithm=method, alpha=alpha,
-                             init=code, n_jobs=n_jobs)
-        # Update dictionary
-        dictionary, residuals = _update_dict(dictionary.T, X.T, code.T,
-                                             verbose=verbose, return_r2=True,
-                                             random_state=random_state)
-        dictionary = dictionary.T
-
-        # Cost function
-        current_cost = 0.5 * residuals + alpha * np.sum(np.abs(code))
-        errors.append(current_cost)
-
-        if ii > 0:
+    def _callback(env):
+        # Check convergence
+        residuals = env["dictionary"].dot(env["this_code"])
+        residuals -= env["this_X"].T
+        residuals **= 2
+        residuals = np.sum(residuals)
+        err = .5 * residuals + alpha * np.sum(np.abs(env["this_code"]))
+        errors.append(err)
+        if len(errors) > 1:
             dE = errors[-2] - errors[-1]
             # assert(dE >= -tol * errors[-1])
-            if dE < tol * errors[-1]:
+            if np.abs(dE) < tol * errors[-1]:
                 if verbose == 1:
                     # A line return
                     print("")
                 elif verbose:
-                    print("--- Convergence reached after %d iterations" % ii)
-                break
-        if ii % 5 == 0 and callback is not None:
-            callback(locals())
-
+                    print(
+                        "--- Convergence reached after %d iterations" % (
+                            env["ii"]))
+                return False
+        return True
+
+    # Call the unified dict-learning API in batch mode
+    out = dict_learning_online(
+        X, n_components=n_components, alpha=alpha, n_iter=max_iter,
+        return_code=True, dict_init=dict_init, callback=_callback,
+        batch_size=len(X), verbose=verbose, shuffle=False,
+        return_n_iter=return_n_iter, n_jobs=n_jobs, method=method,
+        return_inner_stats=False, tol=tol)
     if return_n_iter:
-        return code, dictionary, errors, ii + 1
+        code, dictionary, n_iter = out
+        return code, dictionary, errors, n_iter
     else:
+        code, dictionary = out
         return code, dictionary, errors
 
 
 def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                          return_code=True, dict_init=None, callback=None,
-                         batch_size=3, verbose=False, shuffle=True, n_jobs=1,
-                         method='lars', iter_offset=0, random_state=None,
-                         return_inner_stats=False, inner_stats=None,
-                         return_n_iter=False):
+                         batch_size=3, verbose=False, shuffle=True,
+                         n_jobs=1, method='lars', iter_offset=0, tol=0.,
+                         random_state=None, return_inner_stats=False,
+                         inner_stats=None, return_n_iter=False):
     """Solves a dictionary learning matrix factorization problem online.
 
     Finds the best dictionary and the corresponding sparse code for
     approximating the data matrix X by solving::
 
         (U^*, V^*) = argmin 0.5 || X - U V ||_2^2 + alpha * || U ||_1
                      (U,V)
-        with || V_k ||_2 = 1 for all 0 <= k < n_components
+        with || V_k ||_2 <= 1 for all 0 <= k < n_components
 
     where V is the dictionary and U is the sparse code. This is
     accomplished by repeatedly iterating over mini-batches by slicing
     the input data.
 
+    This function has two modes:
+
+    1. Batch mode, activated when batch_size is None or
+       batch_size >= n_samples.
+
+    2. Online mode, activated when batch_size < n_samples, as is the
+       case with the default batch_size of 3.
+
     Read more in the :ref:`User Guide <DictionaryLearning>`.
 
     Parameters
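The refactor makes `dict_learning` a thin wrapper: it builds the `_callback` closure above to track the cost and detect convergence, then delegates to `dict_learning_online` with `batch_size=len(X)`, i.e. batch mode. The same hook is available to callers; below is a hedged sketch of a user-supplied callback under the new protocol, where the dict passed in is the training loop's `locals()` and the names `ii`, `dictionary`, `this_code` and `this_X` are the loop variables visible in this diff:

    import numpy as np

    costs = []

    def convergence_callback(env):
        # env is the training loop's locals(); returning False aborts the loop
        D, C, Xb = env["dictionary"], env["this_code"], env["this_X"]
        costs.append(0.5 * np.sum((Xb.T - D.dot(C)) ** 2))
        if len(costs) > 1 and abs(costs[-2] - costs[-1]) < 1e-8 * costs[-1]:
            print("stopping at iteration %d" % env["ii"])
            return False
        return True

    # Illustrative call (batch mode via batch_size=None):
    # code, dictionary = dict_learning_online(X, n_components=30, alpha=1.,
    #                                         n_iter=200, batch_size=None,
    #                                         callback=convergence_callback)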
@@ -593,14 +582,18 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
     n_iter : int,
         Number of iterations to perform.
 
+    tol : float,
+        Tolerance for the stopping condition. Used in batch mode.
+
     return_code : boolean,
         Whether to also return the code U or just the dictionary V.
 
     dict_init : array of shape (n_components, n_features),
         Initial value for the dictionary for warm restart scenarios.
 
     callback :
-        Callable that gets invoked every five iterations.
+        Callable that gets invoked at each iteration. If it returns
+        non-True, the main loop (iteration over data) is aborted.
 
     batch_size : int,
         The number of samples to take in each batch.
@@ -711,6 +704,9 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                               copy=False)
         X_train = check_array(X_train, order='C', dtype=np.float64, copy=False)
 
+    if batch_size is None:
+        batch_size = n_samples
+    online = batch_size < n_samples
     batches = gen_batches(n_samples, batch_size)
     batches = itertools.cycle(batches)
 
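In isolation, the new mode switch and the existing batch cycling behave as follows (illustrative snippet; `gen_batches` is the `sklearn.utils` helper already used by this module):

    import itertools
    from sklearn.utils import gen_batches

    n_samples = 7
    for batch_size in (None, 3, 7):
        bs = n_samples if batch_size is None else batch_size
        print(batch_size, "online" if bs < n_samples else "batch")
    # prints: None batch / 3 online / 7 batch

    # gen_batches yields slices covering range(n_samples); cycle repeats them
    batches = itertools.cycle(gen_batches(n_samples, 3))
    print([next(batches) for _ in range(4)])
    # [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None), slice(0, 3, None)]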
@@ -741,26 +737,27 @@ def dict_learning_online(X, n_components=2, alpha=1, n_iter=100,
                                   alpha=alpha, n_jobs=n_jobs).T
 
         # Update the auxiliary variables
-        if ii < batch_size - 1:
-            theta = float((ii + 1) * batch_size)
+        if online:
+            if ii < batch_size - 1:
+                theta = float((ii + 1) * batch_size)
+            else:
+                theta = float(batch_size ** 2 + ii + 1 - batch_size)
+            beta = (theta + 1 - batch_size) / (theta + 1)
+            A *= beta
+            A += np.dot(this_code, this_code.T)
+            B *= beta
+            B += np.dot(this_X.T, this_code.T)
         else:
-            theta = float(batch_size ** 2 + ii + 1 - batch_size)
-        beta = (theta + 1 - batch_size) / (theta + 1)
-
-        A *= beta
-        A += np.dot(this_code, this_code.T)
-        B *= beta
-        B += np.dot(this_X.T, this_code.T)
+            A = np.dot(this_code, this_code.T)
+            B = np.dot(this_X.T, this_code.T)
 
         # Update dictionary
         dictionary = _update_dict(dictionary, B, A, verbose=verbose,
-                                  random_state=random_state)
-        # XXX: Can the residuals be of any use?
+                                  random_state=random_state, online=online,
+                                  return_r2=False)
 
-        # Maybe we need a stopping criteria based on the amount of
-        # modification in the dictionary
-        if callback is not None:
-            callback(locals())
+        # Check convergence
+        if callback is not None and not callback(locals()):
+            break
 
     if return_inner_stats:
         if return_n_iter:
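In batch mode the sufficient statistics are recomputed from scratch on the full batch, `A = C C^T` and `B = X^T C^T`, so the quadratic surrogate that `_update_dict` minimizes agrees with the true batch objective up to a constant (the `tr(XXt)` term in the docstring above). A small numerical check with toy shapes (illustration only, not library code):

    import numpy as np

    rng = np.random.RandomState(0)
    X = rng.randn(20, 8)   # (n_samples, n_features)
    D = rng.randn(8, 5)    # (n_features, n_components)
    C = rng.randn(5, 20)   # (n_components, n_samples)

    A = C.dot(C.T)         # (n_components, n_components)
    B = X.T.dot(C.T)       # (n_features, n_components)

    surrogate = 0.5 * np.trace(D.T.dot(D).dot(A)) - np.trace(D.T.dot(B))
    objective = 0.5 * np.sum((X.T - D.dot(C)) ** 2) - 0.5 * np.sum(X ** 2)
    assert np.allclose(surrogate, objective)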
