FIX online updates in MiniBatchDictionaryLearning (#25354) · dolfly/scikit-learn@cfd428a · GitHub
Commit cfd428a

jeremiedbb and ogrisel authored
FIX online updates in MiniBatchDictionaryLearning (scikit-learn#25354)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent d431d7e commit cfd428a

File tree

2 files changed: +23 -5 lines changed

doc/whats_new/v1.2.rst

Lines changed: 18 additions & 0 deletions
@@ -9,6 +9,19 @@ Version 1.2.1
 
 **In Development**
 
+Changed models
+--------------
+
+The following estimators and functions, when fit with the same data and
+parameters, may produce different models from the previous version. This often
+occurs due to changes in the modelling logic (bug fixes or enhancements), or in
+random sampling procedures.
+
+- |Fix| The fitted components in :class:`MiniBatchDictionaryLearning` might differ. The
+  online updates of the sufficient statistics now properly take the sizes of the batches
+  into account.
+  :pr:`25354` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
+
 Changelog
 ---------
 
@@ -33,6 +46,11 @@ Changelog
 :mod:`sklearn.decomposition`
 ............................
 
+- |Fix| Fixed a bug in :class:`decomposition.MiniBatchDictionaryLearning` where the
+  online updates of the sufficient statistics were not correct when calling
+  `partial_fit` on batches of different sizes.
+  :pr:`25354` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
+
 - |Fix| :class:`decomposition.DictionaryLearning` better supports readonly NumPy
   arrays. In particular, it better supports large datasets which are memory-mapped
   when it is used with coordinate descent algorithms (i.e. when `fit_algorithm='cd'`).
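The changelog entry above concerns calling `partial_fit` with batches of different sizes, which is the scenario the fix addresses. A short usage sketch of that pattern (the data and hyperparameters here are illustrative, not from the commit):

```python
import numpy as np
from sklearn.decomposition import MiniBatchDictionaryLearning

rng = np.random.RandomState(0)
X = rng.randn(100, 8)

est = MiniBatchDictionaryLearning(n_components=5, random_state=0)

# Batches of different sizes: before the fix, larger batches were
# implicitly over-weighted in the running sufficient statistics.
for batch in (X[:10], X[10:40], X[40:]):
    est.partial_fit(batch)

codes = est.transform(X)  # sparse codes of shape (n_samples, n_components)
```

With the fix applied, the fitted `components_` no longer depend on how the same stream of samples happens to be chunked into batches.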

sklearn/decomposition/_dict_learning.py

Lines changed: 5 additions & 5 deletions
@@ -2053,16 +2053,16 @@ class MiniBatchDictionaryLearning(_BaseSparseCoding, BaseEstimator):
 
     We can check the level of sparsity of `X_transformed`:
 
-    >>> np.mean(X_transformed == 0)
-    0.38...
+    >>> np.mean(X_transformed == 0) < 0.5
+    True
 
     We can compare the average squared euclidean norm of the reconstruction
     error of the sparse coded signal relative to the squared euclidean norm of
     the original signal:
 
     >>> X_hat = X_transformed @ dict_learner.components_
     >>> np.mean(np.sum((X_hat - X) ** 2, axis=1) / np.sum(X ** 2, axis=1))
-    0.059...
+    0.057...
     """
 
     _parameter_constraints: dict = {
 
@@ -2196,9 +2196,9 @@ def _update_inner_stats(self, X, code, batch_size, step):
         beta = (theta + 1 - batch_size) / (theta + 1)
 
         self._A *= beta
-        self._A += code.T @ code
+        self._A += code.T @ code / batch_size
         self._B *= beta
-        self._B += X.T @ code
+        self._B += X.T @ code / batch_size
 
     def _minibatch_step(self, X, dictionary, random_state, step):
         """Perform the update on the dictionary for one minibatch."""

0 commit comments

Comments
 (0)
0