23
23
from ..utils ._param_validation import Interval
24
24
25
25
26
+ def _oas (X , * , assume_centered = False ):
27
+ """Estimate covariance with the Oracle Approximating Shrinkage algorithm.
28
+
29
+ The formulation is based on [1]_.
30
+ [1] "Shrinkage algorithms for MMSE covariance estimation.",
31
+ Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O.
32
+ IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010.
33
+ https://arxiv.org/pdf/0907.4698.pdf
34
+ """
35
+ if len (X .shape ) == 2 and X .shape [1 ] == 1 :
36
+ # for only one feature, the result is the same whatever the shrinkage
37
+ if not assume_centered :
38
+ X = X - X .mean ()
39
+ return np .atleast_2d ((X ** 2 ).mean ()), 0.0
40
+
41
+ n_samples , n_features = X .shape
42
+
43
+ emp_cov = empirical_covariance (X , assume_centered = assume_centered )
44
+
45
+ # The shrinkage is defined as:
46
+ # shrinkage = min(
47
+ # trace(S @ S.T) + trace(S)**2) / ((n + 1) (trace(S @ S.T) - trace(S)**2 / p), 1
48
+ # )
49
+ # where n and p are n_samples and n_features, respectively (cf. Eq. 23 in [1]).
50
+ # The factor 2 / p is omitted since it does not impact the value of the estimator
51
+ # for large p.
52
+
53
+ # Instead of computing trace(S)**2, we can compute the average of the squared
54
+ # elements of S that is equal to trace(S)**2 / p**2.
55
+ # See the definition of the Frobenius norm:
56
+ # https://en.wikipedia.org/wiki/Matrix_norm#Frobenius_norm
57
+ alpha = np .mean (emp_cov ** 2 )
58
+ mu = np .trace (emp_cov ) / n_features
59
+ mu_squared = mu ** 2
60
+
61
+ # The factor 1 / p**2 will cancel out since it is in both the numerator and
62
+ # denominator
63
+ num = alpha + mu_squared
64
+ den = (n_samples + 1 ) * (alpha - mu_squared / n_features )
65
+ shrinkage = 1.0 if den == 0 else min (num / den , 1.0 )
66
+
67
+ # The shrunk covariance is defined as:
68
+ # (1 - shrinkage) * S + shrinkage * F (cf. Eq. 4 in [1])
69
+ # where S is the empirical covariance and F is the shrinkage target defined as
70
+ # F = trace(S) / n_features * np.identity(n_features) (cf. Eq. 3 in [1])
71
+ shrunk_cov = (1.0 - shrinkage ) * emp_cov
72
+ shrunk_cov .flat [:: n_features + 1 ] += shrinkage * mu
73
+
74
+ return shrunk_cov , shrinkage
75
+
76
+
77
+ ###############################################################################
78
+ # Public API
26
79
# ShrunkCovariance estimator
27
80
28
81
@@ -500,7 +553,9 @@ def fit(self, X, y=None):
500
553
501
554
# OAS estimator
502
555
def oas (X , * , assume_centered = False ):
503
- """Estimate covariance with the Oracle Approximating Shrinkage algorithm.
556
+ """Estimate covariance with the Oracle Approximating Shrinkage as proposed in [1]_.
557
+
558
+ Read more in the :ref:`User Guide <shrunk_covariance>`.
504
559
505
560
Parameters
506
561
----------
@@ -524,14 +579,25 @@ def oas(X, *, assume_centered=False):
524
579
525
580
Notes
526
581
-----
527
- The regularised (shrunk) covariance is:
582
+ The regularised covariance is:
528
583
529
- (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features)
584
+ (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features),
530
585
531
- where mu = trace(cov) / n_features
586
+ where mu = trace(cov) / n_features and shrinkage is given by the OAS formula
587
+ (see [1]_).
588
+
589
+ The shrinkage formulation implemented here differs from Eq. 23 in [1]_. In
590
+ the original article, formula (23) states that 2/p (p being the number of
591
+ features) is multiplied by Trace(cov*cov) in both the numerator and
592
+ denominator, but this operation is omitted because for a large p, the value
593
+ of 2/p is so small that it doesn't affect the value of the estimator.
532
594
533
- The formula we used to implement the OAS is slightly modified compared
534
- to the one given in the article. See :class:`OAS` for more details.
595
+ References
596
+ ----------
597
+ .. [1] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.",
598
+ Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O.
599
+ IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010.
600
+ <0907.4698>`
535
601
"""
536
602
X = np .asarray (X )
537
603
# for only one feature, the result is the same whatever the shrinkage
@@ -565,20 +631,10 @@ def oas(X, *, assume_centered=False):
565
631
566
632
567
633
class OAS (EmpiricalCovariance ):
568
- """Oracle Approximating Shrinkage Estimator.
634
+ """Oracle Approximating Shrinkage Estimator as proposed in [1]_ .
569
635
570
636
Read more in the :ref:`User Guide <shrunk_covariance>`.
571
637
572
- OAS is a particular form of shrinkage described in
573
- "Shrinkage Algorithms for MMSE Covariance Estimation"
574
- Chen et al., IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010.
575
-
576
- The formula used here does not correspond to the one given in the
577
- article. In the original article, formula (23) states that 2/p is
578
- multiplied by Trace(cov*cov) in both the numerator and denominator, but
579
- this operation is omitted because for a large p, the value of 2/p is
580
- so small that it doesn't affect the value of the estimator.
581
-
582
638
Parameters
583
639
----------
584
640
store_precision : bool, default=True
@@ -635,15 +691,23 @@ class OAS(EmpiricalCovariance):
635
691
-----
636
692
The regularised covariance is:
637
693
638
- (1 - shrinkage) * cov + shrinkage * mu * np.identit
BC8A
y(n_features)
694
+ (1 - shrinkage) * cov + shrinkage * mu * np.identity(n_features),
639
695
640
- where mu = trace(cov) / n_features
641
- and shrinkage is given by the OAS formula (see References)
696
+ where mu = trace(cov) / n_features and shrinkage is given by the OAS formula
697
+ (see [1]_).
698
+
699
+ The shrinkage formulation implemented here differs from Eq. 23 in [1]_. In
700
+ the original article, formula (23) states that 2/p (p being the number of
701
+ features) is multiplied by Trace(cov*cov) in both the numerator and
702
+ denominator, but this operation is omitted because for a large p, the value
703
+ of 2/p is so small that it doesn't affect the value of the estimator.
642
704
643
705
References
644
706
----------
645
- "Shrinkage Algorithms for MMSE Covariance Estimation"
646
- Chen et al., IEEE Trans. on Sign. Proc., Volume 58, Issue 10, October 2010.
707
+ .. [1] :arxiv:`"Shrinkage algorithms for MMSE covariance estimation.",
708
+ Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O.
709
+ IEEE Transactions on Signal Processing, 58(10), 5016-5029, 2010.
710
+ <0907.4698>`
647
711
648
712
Examples
649
713
--------
0 commit comments