10000 ENH allow shrunk_covariance to handle multiple matrices at once (#25275) · punndcoder28/scikit-learn@2d4197d · GitHub
[go: up one dir, main page]

Skip to content

Commit 2d4197d

Browse files
qbarthelemyagramfortglemaitrejeremiedbb
authored
ENH allow shrunk_covariance to handle multiple matrices at once (scikit-learn#25275)
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent a88a33a commit 2d4197d

File tree

3 files changed

+35
-9
lines changed

3 files changed

+35
-9
lines changed

doc/whats_new/v1.4.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,13 @@ Changelog
320320
with `np.int64` indices are not supported.
321321
:pr:`27240` by :user:`Yao Xiao <Charlie-XIAO>`.
322322

323+
:mod:`sklearn.covariance`
324+
.........................
325+
326+
- |Enhancement| Allow :func:`covariance.shrunk_covariance` to process
327+
multiple covariance matrices at once by handling nd-arrays.
328+
:pr:`25275` by :user:`Quentin Barthélemy <qbarthelemy>`.
329+
323330
- |API| |FIX| :class:`~compose.ColumnTransformer` now replaces `"passthrough"`
324331
with a corresponding :class:`~preprocessing.FunctionTransformer` in the
325332
fitted ``transformers_`` attribute.

sklearn/covariance/_shrunk_covariance.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,23 +109,23 @@ def _oas(X, *, assume_centered=False):
109109
prefer_skip_nested_validation=True,
110110
)
111111
def shrunk_covariance(emp_cov, shrinkage=0.1):
112-
"""Calculate a covariance matrix shrunk on the diagonal.
112+
"""Calculate covariance matrices shrunk on the diagonal.
113113
114114
Read more in the :ref:`User Guide <shrunk_covariance>`.
115115
116116
Parameters
117117
----------
118-
emp_cov : array-like of shape (n_features, n_features)
119-
Covariance matrix to be shrunk.
118+
emp_cov : array-like of shape (..., n_features, n_features)
119+
Covariance matrices to be shrunk, at least 2D ndarray.
120120
121121
shrinkage : float, default=0.1
122122
Coefficient in the convex combination used for the computation
123123
of the shrunk estimate. Range is [0, 1].
124124
125125
Returns
126126
-------
127-
shrunk_cov : ndarray of shape (n_features, n_features)
128-
Shrunk covariance.
127+
shrunk_cov : ndarray of shape (..., n_features, n_features)
128+
Shrunk covariance matrices.
129129
130130
Notes
131131
-----
@@ -135,12 +135,13 @@ def shrunk_covariance(emp_cov, shrinkage=0.1):
135135
136136
where `mu = trace(cov) / n_features`.
137137
"""
138-
emp_cov = check_array(emp_cov)
139-
n_features = emp_cov.shape[0]
138+
emp_cov = check_array(emp_cov, allow_nd=True)
139+
n_features = emp_cov.shape[-1]
140140

141-
mu = np.trace(emp_cov) / n_features
142141
shrunk_cov = (1.0 - shrinkage) * emp_cov
143-
shrunk_cov.flat[:: n_features + 1] += shrinkage * mu
142+
mu = np.trace(emp_cov, axis1=-2, axis2=-1) / n_features
143+
mu = np.expand_dims(mu, axis=tuple(range(mu.ndim, emp_cov.ndim)))
144+
shrunk_cov += shrinkage * mu * np.eye(n_features)
144145

145146
return shrunk_cov
146147

sklearn/covariance/tests/test_covariance.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,25 @@ def test_covariance():
8181
assert_array_equal(cov.location_, np.zeros(X.shape[1]))
8282

8383

84+
@pytest.mark.parametrize("n_matrices", [1, 3])
85+
def test_shrunk_covariance_func(n_matrices):
86+
"""Check `shrunk_covariance` function."""
87+
88+
n_features = 2
89+
cov = np.ones((n_features, n_features))
90+
cov_target = np.array([[1, 0.5], [0.5, 1]])
91+
92+
if n_matrices > 1:
93+
cov = np.repeat(cov[np.newaxis, ...], n_matrices, axis=0)
94+
cov_target = np.repeat(cov_target[np.newaxis, ...], n_matrices, axis=0)
95+
96+
cov_shrunk = shrunk_covariance(cov, 0.5)
97+
assert_allclose(cov_shrunk, cov_target)
98+
99+
84100
def test_shrunk_covariance():
101+
"""Check consistency between `ShrunkCovariance` and `shrunk_covariance`."""
102+
85103
# Tests ShrunkCovariance module on a simple dataset.
86104
# compare shrunk covariance obtained from data and from MLE estimate
87105
cov = ShrunkCovariance(shrinkage=0.5)

0 commit comments

Comments
 (0)
0