MAINT Parameters validation for `covariance.oas` (#24904) · jjerphan/scikit-learn@b815369 · GitHub
[go: up one dir, main page]

Skip to content

Commit b815369

Browse files
raghuveerbhat, glemaitre, jeremiedbb
authored and committed
MAINT Parameters validation for covariance.oas (scikit-learn#24904)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>
1 parent ce49fef commit b815369

File tree

3 files changed

+44
-30
lines changed

3 files changed

+44
-30
lines changed

sklearn/covariance/_shrunk_covariance.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,31 @@ def _ledoit_wolf(X, *, assume_centered, block_size):
4343
return shrunk_cov, shrinkage
4444

4545

46+
def _oas(X, *, assume_centered=False):
    """Estimate covariance with the Oracle Approximating Shrinkage algorithm.

    Private helper that works directly on an array: it performs no input
    conversion or validation and reads ``X.shape`` immediately.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Data from which to compute the covariance estimate. The shape is
        unpacked into exactly two dimensions.

    assume_centered : bool, default=False
        In the single-feature shortcut, the mean of ``X`` is subtracted
        before computing the variance unless this flag is True. Otherwise
        the flag is forwarded unchanged to ``empirical_covariance``.

    Returns
    -------
    shrunk_cov : ndarray of shape (n_features, n_features)
        Shrunk covariance: a convex combination of the empirical covariance
        and ``mu * identity``, where ``mu`` is the mean of the diagonal of
        the empirical covariance.

    shrinkage : float
        Coefficient of the convex combination, in the range [0, 1].
    """
    # for only one feature, the result is the same whatever the shrinkage
    if len(X.shape) == 2 and X.shape[1] == 1:
        if not assume_centered:
            X = X - X.mean()
        return np.atleast_2d((X**2).mean()), 0.0

    n_samples, n_features = X.shape

    emp_cov = empirical_covariance(X, assume_centered=assume_centered)
    mu = np.trace(emp_cov) / n_features

    # formula from Chen et al.'s **implementation**
    alpha = np.mean(emp_cov**2)
    num = alpha + mu**2
    den = (n_samples + 1.0) * (alpha - (mu**2) / n_features)

    # In exact arithmetic ``den`` is non-negative (by Cauchy-Schwarz,
    # alpha >= mu**2 / n_features), and it is zero when the empirical
    # covariance is a multiple of the identity. Floating-point rounding can
    # turn that zero into a tiny negative number, which would make
    # ``num / den`` hugely negative and push the shrinkage far outside
    # [0, 1]. Treat every non-positive denominator as full shrinkage.
    shrinkage = 1.0 if den <= 0 else min(num / den, 1.0)

    shrunk_cov = (1.0 - shrinkage) * emp_cov
    # add ``shrinkage * mu`` to the diagonal entries only (stride
    # n_features + 1 through the flattened square matrix)
    shrunk_cov.flat[:: n_features + 1] += shrinkage * mu

    return shrunk_cov, shrinkage
69+
70+
4671
###############################################################################
4772
# Public API
4873
# ShrunkCovariance estimator
@@ -503,6 +528,7 @@ def fit(self, X, y=None):
503528

504529

505530
# OAS estimator
531+
@validate_params({"X": ["array-like"]})
506532
def oas(X, *, assume_centered=False):
507533
"""Estimate covariance with the Oracle Approximating Shrinkage algorithm.
508534
@@ -537,35 +563,10 @@ def oas(X, *, assume_centered=False):
537563
The formula we used to implement the OAS is slightly modified compared
538564
to the one given in the article. See :class:`OAS` for more details.
539565
"""
540-
X = np.asarray(X)
541-
# for only one feature, the result is the same whatever the shrinkage
542-
if len(X.shape) == 2 and X.shape[1] == 1:
543-
if not assume_centered:
544-
X = X - X.mean()
545-
return np.atleast_2d((X**2).mean()), 0.0
546-
if X.ndim == 1:
547-
X = np.reshape(X, (1, -1))
548-
warnings.warn(
549-
"Only one sample available. You may want to reshape your data array"
550-
)
551-
n_samples = 1
552-
n_features = X.size
553-
else:
554-
n_samples, n_features = X.shape
555-
556-
emp_cov = empirical_covariance(X, assume_centered=assume_centered)
557-
mu = np.trace(emp_cov) / n_features
558-
559-
# formula from Chen et al.'s **implementation**
560-
alpha = np.mean(emp_cov**2)
561-
num = alpha + mu**2
562-
den = (n_samples + 1.0) * (alpha - (mu**2) / n_features)
563-
564-
shrinkage = 1.0 if den == 0 else min(num / den, 1.0)
565-
shrunk_cov = (1.0 - shrinkage) * emp_cov
566-
shrunk_cov.flat[:: n_features + 1] += shrinkage * mu
567-
568-
return shrunk_cov, shrinkage
566+
estimator = OAS(
567+
assume_centered=assume_centered,
568+
).fit(X)
569+
return estimator.covariance_, estimator.shrinkage_
569570

570571

571572
class OAS(EmpiricalCovariance):
@@ -697,7 +698,7 @@ def fit(self, X, y=None):
697698
else:
698699
self.location_ = X.mean(0)
699700

700-
covariance, shrinkage = oas(X - self.location_, assume_centered=True)
701+
covariance, shrinkage = _oas(X - self.location_, assume_centered=True)
701702
self.shrinkage_ = shrinkage
702703
self._set_covariance(covariance)
703704

sklearn/covariance/tests/test_covariance.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
)
2727
from sklearn.covariance._shrunk_covariance import _ledoit_wolf
2828

29+
from .._shrunk_covariance import _oas
30+
2931
X, _ = datasets.load_diabetes(return_X_y=True)
3032
X_1d = X[:, 0]
3133
n_samples, n_features = X.shape
@@ -336,6 +338,16 @@ def test_oas():
336338
assert_almost_equal(oa.score(X), score_, 4)
337339
assert oa.precision_ is None
338340

341+
# test function _oas without assuming centered data
342+
X_1f = X[:, 0:1]
343+
oa = OAS()
344+
oa.fit(X_1f)
345+
# compare shrunk covariance obtained from data and from MLE estimate
346+
_oa_cov_from_mle, _oa_shrinkage_from_mle = _oas(X_1f)
347+
assert_array_almost_equal(_oa_cov_from_mle, oa.covariance_, 4)
348+
assert_almost_equal(_oa_shrinkage_from_mle, oa.shrinkage_)
349+
assert_array_almost_equal((X_1f**2).sum() / n_samples, oa.covariance_, 4)
350+
339351

340352
def test_EmpiricalCovariance_validates_mahalanobis():
341353
"""Checks that EmpiricalCovariance validates data with mahalanobis."""

sklearn/tests/test_public_functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def test_function_param_validation(func_module):
130130
PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
131131
("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"),
132132
("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"),
133+
("sklearn.covariance.oas", "sklearn.covariance.OAS"),
133134
]
134135

135136

0 commit comments

Comments
 (0)
0