10000 ENH Calculate normed stress (Stress-1) in `manifold.MDS` (#22562) · scikit-learn/scikit-learn@ae51c13 · GitHub
[go: up one dir, main page]

Skip to content

Commit ae51c13

Browse files
Micky774cmarmorotheconradglemaitrethomasjpfan
authored
ENH Calculate normed stress (Stress-1) in manifold.MDS (#22562)
Co-authored-by: Chiara Marmo <cmarmo@users.noreply.github.com> Co-authored-by: Roth E Conrad <rotheconrad@gatech.edu> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 1a3cbd0 commit ae51c13

File tree

4 files changed

+131
-27
lines changed

4 files changed

+131
-27
lines changed

doc/modules/manifold.rst

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -462,14 +462,28 @@ Nonmetric MDS
462462
-------------
463463

464464
Non metric :class:`MDS` focuses on the ordination of the data. If
465-
:math:`S_{ij} < S_{jk}`, then the embedding should enforce :math:`d_{ij} <
466-
d_{jk}`. A simple algorithm to enforce that is to use a monotonic regression
467-
of :math:`d_{ij}` on :math:`S_{ij}`, yielding disparities :math:`\hat{d}_{ij}`
468-
in the same order as :math:`S_{ij}`.
465+
:math:`S_{ij} > S_{jk}`, then the embedding should enforce :math:`d_{ij} <
466+
d_{jk}`. For this reason, we discuss it in terms of dissimilarities
467+
(:math:`\delta_{ij}`) instead of similarities (:math:`S_{ij}`). Note that
468+
dissimilarities can easily be obtained from similarities through a simple
469+
transform, e.g. :math:`\delta_{ij}=c_1-c_2 S_{ij}` for some real constants
470+
:math:`c_1, c_2`. A simple algorithm to enforce proper ordination is to use a
471+
monotonic regression of :math:`d_{ij}` on :math:`\delta_{ij}`, yielding
472+
disparities :math:`\hat{d}_{ij}` in the same order as :math:`\delta_{ij}`.
469473

470474
A trivial solution to this problem is to set all the points on the origin. In
471-
order to avoid that, the disparities :math:`\hat{d}_{ij}` are normalized.
475+
order to avoid that, the disparities :math:`\hat{d}_{ij}` are normalized. Note
476+
that since we only care about relative ordering, our objective should be
477+
invariant to simple translation and scaling, however the stress used in metric
478+
MDS is sensitive to scaling. To address this, non-metric MDS may use a
479+
normalized stress, known as Stress-1 defined as
472480

481+
.. math::
482+
\sqrt{\frac{\sum_{i < j} (d_{ij} - \hat{d}_{ij})^2}{\sum_{i < j} d_{ij}^2}}.
483+
484+
The use of normalized Stress-1 can be enabled by setting `normalized_stress=True`,
485+
however it is only compatible with the non-metric MDS problem and will be ignored
486+
in the metric case.
473487

474488
.. figure:: ../auto_examples/manifold/images/sphx_glr_plot_mds_001.png
475489
:target: ../auto_examples/manifold/plot_mds.html
@@ -484,11 +498,11 @@ order to avoid that, the disparities :math:`\hat{d}_{ij}` are normalized.
484498
Borg, I.; Groenen P. Springer Series in Statistics (1997)
485499

486500
* `"Nonmetric multidimensional scaling: a numerical method"
487-
<https://link.springer.com/article/10.1007%2FBF02289694>`_
501+
<http://cda.psych.uiuc.edu/psychometrika_highly_cited_articles/kruskal_1964b.pdf>`_
488502
Kruskal, J. Psychometrika, 29 (1964)
489503

490504
* `"Multidimensional scaling by optimizing goodness of fit to a nonmetric hypothesis"
491-
<https://link.springer.com/article/10.1007%2FBF02289565>`_
505+
<http://cda.psych.uiuc.edu/psychometrika_highly_cited_articles/kruskal_1964a.pdf>`_
492506
Kruskal, J. Psychometrika, 29, (1964)
493507

494508
.. _t_sne:

doc/whats_new/v1.2.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,14 @@ Changelog
233233
:mod:`sklearn.manifold`
234234
.......................
235235

236+
- |Feature| Adds option to use the normalized stress in `manifold.MDS`. This is
237+
enabled by setting the new `normalize` parameter to `True`.
238+
:pr:`10168` by :user:`Łukasz Borchmann <Borchmann>`,
239+
:pr:`12285` by :user:`Matthias Miltenberger <mattmilten>`,
240+
:pr:`13042` by :user:`Matthieu Parizy <matthieu-pa>`,
241+
:pr:`18094` by :user:`Roth E Conrad <rotheconrad>` and
242+
:pr:`22562` by :user:`Meekail Zain <micky774>`.
243+
236244
- |Enhancement| Adds `eigen_tol` parameter to
237245
:class:`manifold.SpectralEmbedding`. Both :func:`manifold.spectral_embedding`
238246
and :class:`manifold.SpectralEmbedding` now propogate `eigen_tol` to all

sklearn/manifold/_mds.py

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def _smacof_single(
2929
verbose=0,
3030
eps=1e-3,
3131
random_state=None,
32+
normalized_stress=False,
3233
):
3334
"""Computes multidimensional scaling using SMACOF algorithm.
3435
@@ -58,13 +59,21 @@ def _smacof_single(
5859
5960
eps : float, default=1e-3
6061
Relative tolerance with respect to stress at which to declare
61-
convergence.
62+
convergence. The value of `eps` should be tuned separately depending
63+
on whether or not `normalized_stress` is being used.
6264
6365
random_state : int, RandomState instance or None, default=None
6466
Determines the random number generator used to initialize the centers.
6567
Pass an int for reproducible results across multiple function calls.
6668
See :term:`Glossary <random_state>`.
6769
70+
normalized_stress : bool, default=False
71+
Whether use and return normed stress value (Stress-1) instead of raw
72+
stress calculated by default. Only supported in non-metric MDS. The
73+
caller must ensure that if `normalized_stress=True` then `metric=False`
74+
75+
.. versionadded:: 1.2
76+
6877
Returns
6978
-------
7079
X : ndarray of shape (n_samples, n_components)
@@ -73,9 +82,23 @@ def _smacof_single(
7382
stress : float
7483
The final value of the stress (sum of squared distance of the
7584
disparities and the distances for all constrained points).
85+
If `normalized_stress=True`, and `metric=False` returns Stress-1.
86+
A value of 0 indicates "perfect" fit, 0.025 excellent, 0.05 good,
87+
0.1 fair, and 0.2 poor [1]_.
7688
7789
n_iter : int
7890
The number of iterations corresponding to the best stress.
91+
92+
References
93+
----------
94+
.. [1] "Nonmetric multidimensional scaling: a numerical method" Kruskal, J.
95+
Psychometrika, 29 (1964)
96+
97+
.. [2] "Multidimensional scaling by optimizing goodness of fit to a nonmetric
98+
hypothesis" Kruskal, J. Psychometrika, 29, (1964)
99+
100+
.. [3] "Modern Multidimensional Scaling - Theory and Applications" Borg, I.;
101+
Groenen P. Springer Series in Statistics (1997)
79102
"""
80103
dissimilarities = check_symmetric(dissimilarities, raise_exception=True)
81104

@@ -121,7 +144,8 @@ def _smacof_single(
121144

122145
# Compute stress
123146
stress = ((dis.ravel() - disparities.ravel()) ** 2).sum() / 2
124-
147+
if normalized_stress:
148+
stress = np.sqrt(stress / ((disparities.ravel() ** 2).sum() / 2))
125149
# Update X using the Guttman transform
126150
dis[dis == 0] = 1e-5
127151
ratio = disparities / dis
@@ -155,6 +179,7 @@ def smacof(
155179
eps=1e-3,
156180
random_state=None,
157181
return_n_iter=False,
182+
normalized_stress=False,
158183
):
159184
"""Compute multidimensional scaling using the SMACOF algorithm.
160185
@@ -217,7 +242,8 @@ def smacof(
217242
218243
eps : float, default=1e-3
219244
Relative tolerance with respect to stress at which to declare
220-
convergence.
245+
convergence. The value of `eps` should be tuned separately depending
246+
on whether or not `normalized_stress` is being used.
221247
222248
random_state : int, RandomState instance or None, default=None
223249
Determines the random number generator used to initialize the centers.
@@ -227,6 +253,12 @@ def smacof(
227253
return_n_iter : bool, default=False
228254
Whether or not to return the number of iterations.
229255
256+
normalized_stress : bool, default=False
257+
Whether use and return normed stress value (Stress-1) instead of raw
258+
stress calculated by default. Only supported in non-metric MDS.
259+
260+
.. versionadded:: 1.2
261+
230262
Returns
231263
-------
232264
X : ndarray of shape (n_samples, n_components)
@@ -235,26 +267,33 @@ def smacof(
235267
stress : float
236268
The final value of the stress (sum of squared distance of the
237269
disparities and the distances for all constrained points).
270+
If `normalized_stress=True`, and `metric=False` returns Stress-1.
271+
A value of 0 indicates "perfect" fit, 0.025 excellent, 0.05 good,
272+
0.1 fair, and 0.2 poor [1]_.
238273
239274
n_iter : int
240275
The number of iterations corresponding to the best stress. Returned
241276
only if ``return_n_iter`` is set to ``True``.
242277
243-
Notes
244-
-----
245-
"Modern Multidimensional Scaling - Theory and Applications" Borg, I.;
246-
Groenen P. Springer Series in Statistics (1997)
278+
References
279+
----------
280+
.. [1] "Nonmetric multidimensional scaling: a numerical method" Kruskal, J.
281+
Psychometrika, 29 (1964)
247282
248-
"Nonmetric multidimensional scaling: a numerical method" Kruskal, J.
249-
Psychometrika, 29 (1964)
283+
.. [2] "Multidimensional scaling by optimizing goodness of fit to a nonmetric
284+
hypothesis" Kruskal, J. Psychometrika, 29, (1964)
250285
251-
"Multidimensional scaling by optimizing goodness of fit to a nonmetric
252-
hypothesis" Kruskal, J. Psychometrika, 29, (1964)
286+
.. [3] "Modern Multidimensional Scaling - Theory and Applications" Borg, I.;
287+
Groenen P. Springer Series in Statistics (1997)
253288
"""
254289

255290
dissimilarities = check_array(dissimilarities)
256291
random_state = check_random_state(random_state)
257-
292+
if normalized_stress and metric:
293+
raise ValueError(
294+
"Normalized stress is not supported for metric MDS. Either set"
295+
" `normalized_stress=False` or use `metric=False`."
296+
)
258297
if hasattr(init, "__array__"):
259298
init = np.asarray(init).copy()
260299
if not n_init == 1:
@@ -277,6 +316,7 @@ def smacof(
277316
verbose=verbose,
278317
eps=eps,
279318
random_state=random_state,
319+
normalized_stress=normalized_stress,
280320
)
281321
if best_stress is None or stress < best_stress:
282322
best_stress = stress
@@ -294,6 +334,7 @@ def smacof(
294334
verbose=verbose,
295335
eps=eps,
296336
random_state=seed,
337+
normalized_stress=normalized_stress,
297338
)
298339
for seed in seeds
299340
)
@@ -335,7 +376,8 @@ class MDS(BaseEstimator):
335376
336377
eps : float, default=1e-3
337378
Relative tolerance with respect to stress at which to declare
338-
convergence.
379+
convergence. The value of `eps` should be tuned separately depending
380+
on whether or not `normalized_stress` is being used.
339381
340382
n_jobs : int, default=None
341383
The number of jobs to use for the computation. If multiple
@@ -361,6 +403,12 @@ class MDS(BaseEstimator):
361403
Pre-computed dissimilarities are passed directly to ``fit`` and
362404
``fit_transform``.
363405
406+
normalized_stress : bool, default=False
407+
Whether use and return normed stress value (Stress-1) instead of raw
408+
stress calculated by default. Only supported in non-metric MDS.
409+
410+
.. versionadded:: 1.2
411+
364412
Attributes
365413
----------
366414
embedding_ : ndarray of shape (n_samples, n_components)
@@ -369,6 +417,9 @@ class MDS(BaseEstimator):
369417
stress_ : float
370418
The final value of the stress (sum of squared distance of the
371419
disparities and the distances for all constrained points).
420+
If `normalized_stress=True`, and `metric=False` returns Stress-1.
421+
A value of 0 indicates "perfect" fit, 0.025 excellent, 0.05 good,
422+
0.1 fair, and 0.2 poor [1]_.
372423
373424
dissimilarity_matrix_ : ndarray of shape (n_samples, n_samples)
374425
Pairwise dissimilarities between the points. Symmetric matrix that:
@@ -405,14 +456,14 @@ class MDS(BaseEstimator):
405456
406457
References
407458
----------
408-
"Modern Multidimensional Scaling - Theory and Applications" Borg, I.;
409-
Groenen P. Springer Series in Statistics (1997)
459+
.. [1] "Nonmetric multidimensional scaling: a numerical method" Kruskal, J.
460+
Psychometrika, 29 (1964)
410461
411-
"Nonmetric multidimensional scaling: a numerical method" Kruskal, J.
412-
Psychometrika, 29 (1964)
462+
.. [2] "Multidimensional scaling by optimizing goodness of fit to a nonmetric
463+
hypothesis" Kruskal, J. Psychometrika, 29, (1964)
413464
414-
"Multidimensional scaling by optimizing goodness of fit to a nonmetric
415-
hypothesis" Kruskal, J. Psychometrika, 29, (1964)
465+
.. [3] "Modern M D73A ultidimensional Scaling - Theory and Applications" Borg, I.;
466+
Groenen P. Springer Series in Statistics (1997)
416467
417468
Examples
418469
--------
@@ -437,6 +488,7 @@ class MDS(BaseEstimator):
437488
"n_jobs": [None, Integral],
438489
"random_state": ["random_state"],
439490
"dissimilarity": [StrOptions({"euclidean", "precomputed"})],
491+
"normalized_stress": ["boolean"],
440492
}
441493

442494
def __init__(
@@ -451,6 +503,7 @@ def __init__(
451503
n_jobs=None,
452504
random_state=None,
453505
dissimilarity="euclidean",
506+
normalized_stress=False,
454507
):
455508
self.n_components = n_components
456509
self.dissimilarity = dissimilarity
@@ -461,6 +514,7 @@ def __init__(
461514
self.verbose = verbose
462515
self.n_jobs = n_jobs
463516
self.random_state = random_state
517+
self.normalized_stress = normalized_stress
464518

465519
def _more_tags(self):
466520
return {"pairwise": self.dissimilarity == "precomputed"}
@@ -544,6 +598,7 @@ def fit_transform(self, X, y=None, init=None):
544598
eps=self.eps,
545599
random_state=self.random_state,
546600
return_n_iter=True,
601+
normalized_stress=self.normalized_stress,
547602
)
548603

549604
return self.embedding_

sklearn/manifold/tests/test_mds.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from numpy.testing import assert_array_almost_equal
2+
from numpy.testing import assert_array_almost_equal, assert_allclose
33
import pytest
44

55
from sklearn.manifold import _mds as mds
@@ -42,3 +42,30 @@ def test_MDS():
4242
sim = np.array([[0, 5, 3, 4], [5, 0, 2, 2], [3, 2, 0, 1], [4, 2, 1, 0]])
4343
mds_clf = mds.MDS(metric=False, n_jobs=3, dissimilarity="precomputed")
4444
mds_clf.fit(sim)
45+
46+
47+
@pytest.mark.parametrize("k", [0.5, 1.5, 2])
48+
def test_normed_stress(k):
49+
"""Test that non-metric MDS normalized stress is scale-invariant."""
50+
sim = np.array([[0, 5, 3, 4], [5, 0, 2, 2], [3, 2, 0, 1], [4, 2, 1, 0]])
51+
52+
X1, stress1 = mds.smacof(
53+
sim, metric=False, normalized_stress=True, max_iter=5, random_state=0
54+
)
55+
X2, stress2 = mds.smacof(
56+
k * sim, metric=False, normalized_stress=True, max_iter=5, random_state=0
57+
)
58+
59+
assert_allclose(stress1, stress2, rtol=1e-5)
60+
assert_allclose(X1, X2, rtol=1e-5)
61+
62+
63+
def test_normalize_metric_warning():
64+
"""
65+
Test that a UserWarning is emitted when using normalized stress with
66+
metric-MDS.
67+
"""
68+
msg = "Normalized stress is not supported"
69+
sim = np.array([[0, 5, 3, 4], [5, 0, 2, 2], [3, 2, 0, 1], [4, 2, 1, 0]])
70+
with pytest.raises(ValueError, match=msg):
71+
mds.smacof(sim, metric=True, normalized_stress=True)

0 commit comments

Comments
 (0)
0