@@ -43,6 +43,31 @@ def _ledoit_wolf(X, *, assume_centered, block_size):
43
43
return shrunk_cov , shrinkage
44
44
45
45
46
+ def _oas (X , * , assume_centered = False ):
47
+ """Estimate covariance with the Oracle Approximating Shrinkage algorithm."""
48
+ # for only one feature, the result is the same whatever the shrinkage
49
+ if len (X .shape ) == 2 and X .shape [1 ] == 1 :
50
+ if not assume_centered :
51
+ X = X - X .mean ()
52
+ return np .atleast_2d ((X ** 2 ).mean ()), 0.0
53
+
54
+ n_samples , n_features = X .shape
55
+
56
+ emp_cov = empirical_covariance (X , assume_centered = assume_centered )
57
+ mu = np .trace (emp_cov ) / n_features
58
+
59
+ # formula from Chen et al.'s **implementation**
60
+ alpha = np .mean (emp_cov ** 2 )
61
+ num = alpha + mu ** 2
62
+ den = (n_samples + 1.0 ) * (alpha - (mu ** 2 ) / n_features )
63
+
64
+ shrinkage = 1.0 if den == 0 else min (num / den , 1.0 )
65
+ shrunk_cov = (1.0 - shrinkage ) * emp_cov
66
+ shrunk_cov .flat [:: n_features + 1 ] += shrinkage * mu
67
+
68
+ return shrunk_cov , shrinkage
69
+
70
+
46
71
###############################################################################
47
72
# Public API
48
73
# ShrunkCovariance estimator
@@ -503,6 +528,7 @@ def fit(self, X, y=None):
503
528
504
529
505
530
# OAS estimator
531
+ @validate_params ({"X" : ["array-like" ]})
506
532
def oas (X , * , assume_centered = False ):
507
533
"""Estimate covariance with the Oracle Approximating Shrinkage algorithm.
508
534
@@ -537,35 +563,10 @@ def oas(X, *, assume_centered=False):
537
563
The formula we used to implement the OAS is slightly modified compared
538
564
to the one given in the article. See :class:`OAS` for more details.
539
565
"""
540
- X = np .asarray (X )
541
- # for only one feature, the result is the same whatever the shrinkage
542
- if len (X .shape ) == 2 and X .shape [1 ] == 1 :
543
- if not assume_centered :
544
- X = X - X .mean ()
545
- return np .atleast_2d ((X ** 2 ).mean ()), 0.0
546
- if X .ndim == 1 :
547
- X = np .reshape (X , (1 , - 1 ))
548
- warnings .warn (
549
- "Only one sample available. You may want to reshape your data array"
550
- )
551
- n_samples = 1
552
- n_features = X .size
553
- else :
554
- n_samples , n_features = X .shape
555
-
556
- emp_cov = empirical_covariance (X , assume_centered = assume_centered )
557
- mu = np .trace (emp_cov ) / n_features
558
-
559
- # formula from Chen et al.'s **implementation**
560
- alpha = np .mean (emp_cov ** 2 )
561
- num = alpha + mu ** 2
562
- den = (n_samples + 1.0 ) * (alpha - (mu ** 2 ) / n_features )
563
-
564
- shrinkage = 1.0 if den == 0 else min (num / den , 1.0 )
565
- shrunk_cov = (1.0 - shrinkage ) * emp_cov
566
- shrunk_cov .flat [:: n_features + 1 ] += shrinkage * mu
567
-
568
- return shrunk_cov , shrinkage
566
+ estimator = OAS (
567
+ assume_centered = assume_centered ,
568
+ ).fit (X )
569
+ return estimator .covariance_ , estimator .shrinkage_
569
570
570
571
571
572
class OAS (EmpiricalCovariance ):
@@ -697,7 +698,7 @@ def fit(self, X, y=None):
697
698
else :
698
699
self .location_ = X .mean (0 )
699
700
700
- covariance , shrinkage = oas (X - self .location_ , assume_centered = True )
701
+ covariance , shrinkage = _oas (X - self .location_ , assume_centered = True )
701
702
self .shrinkage_ = shrinkage
702
703
self ._set_covariance (covariance )
703
704
0 commit comments