MAINT add more intuition on OAS computation based on literature (#23867) · jeremiedbb/scikit-learn@eceae05 · GitHub

Commit eceae05

Micky774 authored and glemaitre committed
MAINT add more intuition on OAS computation based on literature (scikit-learn#23867)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 9a5e673 commit eceae05

2 files changed

+90
-24
lines changed

doc/modules/covariance.rst

Lines changed: 4 additions & 2 deletions
@@ -160,8 +160,10 @@ object to the same sample.
160160

161161
.. topic:: References:
162162

163-
.. [2] Chen et al., "Shrinkage Algorithms for MMSE Covariance Estimation",
164-
IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010.
163+
.. [2] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.",
164+
Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O.
165+
IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010.
166+
<0907.4698>`
165167
166168
.. topic:: Examples:
167169

sklearn/covariance/_shrunk_covariance.py

Lines changed: 86 additions & 22 deletions
@@ -23,6 +23,59 @@
2323
from ..utils._param_validation import Interval
2424

2525

26+
def _oas(X, *, assume_centered=False):
27+
"""Estimate covariance with the Oracle Approximating Shrinkage algorithm.
28+
29+
The formulation is based on [1]_.
30+
[1] "Shrinkage algorithms for MMSE covariance estimation.",
31+
Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O.
32+
IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010.
33+
https://arxiv.org/pdf/0907.4698.pdf
34+
"""
35+
if len(X.shape) == 2 and X.shape[1] == 1:
36+
# for only one feature, the result is the same whatever the shrinkage
37+
if not assume_centered:
38+
X = X - X.mean()
39+
return np.atleast_2d((X**2).mean()), 0.0
40+
41+
n_samples, n_features = X.shape
42+
43+
emp_cov = empirical_covariance(X, assume_centered=assume_centered)
44+
45+
# The shrinkage is defined as:
46+
# shrinkage = min(
47+
# (trace(S @ S.T) + trace(S)**2) / ((n + 1) * (trace(S @ S.T) - trace(S)**2 / p)), 1
48+
# )
49+
# where n and p are n_samples and n_features, respectively (cf. Eq. 23 in [1]).
50+
# The factor 2 / p is omitted since it does not impact the value of the estimator
51+
# for large p.
52+
53+
# Instead of computing trace(S @ S.T), we can compute the average of the squared
54+
# elements of S, which is equal to trace(S @ S.T) / p**2.
55+
# See the definition of the Frobenius norm:
56+
# https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
57+
alpha = np.mean(emp_cov**2)
58+
mu = np.trace(emp_cov) / n_features
59+
mu_squared = mu**2
60+
61+
# The factor 1 / p**2 will cancel out since it is in both the numerator and
62+
# denominator
63+
num = alpha + mu_squared
64+
den = (n_samples + 1) * (alpha - mu_squared / n_features)
65+
shrinkage = 1.0 if den == 0 else min(num / den, 1.0)
66+
67+
# The shrunk covariance is defined as:
68+
# (1 - shrinkage) * S + shrinkage * F (cf. Eq. 4 in [1])
69+
# where S is the empirical covariance and F is the shrinkage target defined as
70+
# F = trace(S) / n_features * np.identity(n_features) (cf. Eq. 3 in [1])
71+
shrunk_cov = (1.0 - shrinkage) * emp_cov
72+
shrunk_cov.flat[:: n_features + 1] += shrinkage * mu
73+
74+
return shrunk_cov, shrinkage
75+
76+
77+
###############################################################################
78+
# Public API
2679
# ShrunkCovariance estimator
2780

2881

@@ -500,7 +553,9 @@ def fit(self, X, y=None):
500553

501554
# OAS estimator
502555
def oas(X, *, assume_centered=False):
503-
"""Estimate covariance with the Oracle Approximating Shrinkage algorithm.
556+
"""Estimate covariance with the Oracle Approximating Shrinkage as proposed in [1]_.
557+
558+
Read more in the :ref:`User Guide <shrunk_covariance>`.
504559
505560
Parameters
506561
----------
@@ -524,14 +579,25 @@ def oas(X, *, assume_centered=False):
524579
525580
Notes
526581
-----
527-
The regularised (shrunk) covariance is:
582+
The regularised covariance is:
528583
529-
(1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features)
584+
(1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features),
530585
531-
where mu = trace(cov) / n_features
586+
where mu = trace(cov) / n_features and shrinkage is given by the OAS formula
587+
(see [1]_).
588+
589+
The shrinkage formulation implemented here differs from Eq. 23 in [1]_. In
590+
the original article, formula (23) states that 2/p (p being the number of
591+
features) is multiplied by Trace(cov*cov) in both the numerator and
592+
denominator, but this operation is omitted because for a large p, the value
593+
of 2/p is so small that it doesn't affect the value of the estimator.
532594
533-
The formula we used to implement the OAS is slightly modified compared
534-
to the one given in the article. See :class:`OAS` for more details.
595+
References
596+
----------
597+
.. [1] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.",
598+
Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O.
599+
IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010.
600+
<0907.4698>`
535601
"""
536602
X = np.asarray(X)
537603
# for only one feature, the result is the same whatever the shrinkage
@@ -565,20 +631,10 @@ def oas(X, *, assume_centered=False):
565631

566632

567633
class OAS(EmpiricalCovariance):
568-
"""Oracle Approximating Shrinkage Estimator.
634+
"""Oracle Approximating Shrinkage Estimator as proposed in [1]_.
569635
570636
Read more in the :ref:`User Guide <shrunk_covariance>`.
571637
572-
OAS is a particular form of shrinkage described in
573-
"Shrinkage Algorithms for MMSE Covariance Estimation"
574-
Chen et al., IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010.
575-
576-
The formula used here does not correspond to the one given in the
577-
article. In the original article, formula (23) states that 2/p is
578-
multiplied by Trace(cov*cov) in both the numerator and denominator, but
579-
this operation is omitted because for a large p, the value of 2/p is
580-
so small that it doesn't affect the value of the estimator.
581-
582638
Parameters
583639
----------
584640
store_precision : bool, default=True
@@ -635,15 +691,23 @@ class OAS(EmpiricalCovariance):
635691
-----
636692
The regularised covariance is:
637693
638-
(1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features)
694+
(1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features),
639695
640-
where mu = trace(cov) / n_features
641-
and shrinkage is given by the OAS formula (see References)
696+
where mu = trace(cov) / n_features and shrinkage is given by the OAS formula
697+
(see [1]_).
698+
699+
The shrinkage formulation implemented here differs from Eq. 23 in [1]_. In
700+
the original article, formula (23) states that 2/p (p being the number of
701+
features) is multiplied by Trace(cov*cov) in both the numerator and
702+
denominator, but this operation is omitted because for a large p, the value
703+
of 2/p is so small that it doesn't affect the value of the estimator.
642704
643705
References
644706
----------
645-
"Shrinkage Algorithms for MMSE Covariance Estimation"
646-
Chen et al., IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010.
707+
.. [1] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.",
708+
Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O.
709+
IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010.
710+
<0907.4698>`
647711
648712
Examples
649713
--------
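
The shrinkage computation described in the comments of `_oas` can be checked numerically. The following is a minimal sketch, not the library code: the helper name `oas_shrinkage_sketch` and the random test data are assumptions made for illustration. It computes the shrinkage from the explicit traces, without the 1 / p**2 rescaling used in the diff, and then checks two identities the comments rely on: the mean of the squared entries of S equals trace(S @ S.T) / p**2, and shrinking toward mu * I leaves the trace unchanged.

```python
import numpy as np


def oas_shrinkage_sketch(X):
    """Illustrative OAS computation (hypothetical helper, not scikit-learn's)."""
    X = np.asarray(X, dtype=float)
    n_samples, n_features = X.shape

    # Biased (maximum-likelihood) empirical covariance of the centered data.
    X_centered = X - X.mean(axis=0)
    S = X_centered.T @ X_centered / n_samples

    trace_S2 = np.trace(S @ S.T)      # squared Frobenius norm of S
    trace_S_sq = np.trace(S) ** 2     # squared trace of S

    # Simplified Eq. 23 with the 2 / p terms omitted, as in the diff.
    num = trace_S2 + trace_S_sq
    den = (n_samples + 1) * (trace_S2 - trace_S_sq / n_features)
    shrinkage = 1.0 if den == 0 else min(num / den, 1.0)

    # Convex combination of S and the target F = mu * I.
    mu = np.trace(S) / n_features
    shrunk = (1.0 - shrinkage) * S + shrinkage * mu * np.identity(n_features)
    return shrunk, shrinkage, S


rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))  # arbitrary illustrative data
shrunk, shrinkage, S = oas_shrinkage_sketch(X)

p = S.shape[0]
# The average of the squared entries is the squared Frobenius norm over p**2.
frobenius_identity_holds = np.isclose(np.mean(S**2), np.trace(S @ S.T) / p**2)
# Shrinking toward mu * I preserves the trace of the covariance.
trace_preserved = np.isclose(np.trace(shrunk), np.trace(S))
```

Because the 1 / p**2 factor appears in both `num` and `den`, working with `np.mean(emp_cov**2)` and `mu**2` as the diff does should give the same ratio as the explicit traces used here.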
