8000 Merge pull request #24 from ogrisel/fix-nan-1d-regularized-covariance · scikit-learn/scikit-learn@021bf74 · GitHub
[go: up one dir, main page]

Skip to content

Commit 021bf74

Browse files
committed
Merge pull request #24 from ogrisel/fix-nan-1d-regularized-covariance
FIX regularized covariance on 1D data
2 parents 8ef0b9a + 9f50699 commit 021bf74

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

sklearn/covariance/shrunk_covariance_.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,7 @@ def ledoit_wolf_shrinkage(X, assume_centered=False, block_size=1000):
228228
# get final beta as the min between beta and delta
229229
beta = min(beta, delta)
230230
# finally get shrinkage
231-
shrinkage = beta / delta
232-
231+
shrinkage = 0 if beta == 0 else beta / delta
233232
return shrinkage
234233

235234

@@ -461,7 +460,7 @@ def oas(X, assume_centered=False):
461460
num = alpha + mu ** 2
462461
den = (n_samples + 1.) * (alpha - (mu ** 2) / n_features)
463462

464-
shrinkage = min(num / den, 1.)
463+
shrinkage = 1. if den == 0 else min(num / den, 1.)
465464
shrunk_cov = (1. - shrinkage) * emp_cov
466465
shrunk_cov.flat[::n_features + 1] += shrinkage * mu
467466

sklearn/covariance/tests/test_covariance.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def test_covariance():
6161
X_1sample = np.arange(5)
6262
cov = EmpiricalCovariance()
6363
assert_warns(UserWarning, cov.fit, X_1sample)
64+
assert_array_almost_equal(cov.covariance_,
65+
np.zeros(shape=(5, 5), dtype=np.float64))
6466

6567
# test integer type
6668
X_integer = np.asarray([[0, 1], [1, 0]])
@@ -181,9 +183,11 @@ def test_ledoit_wolf():
181183

182184
# test with one sample
183185
# FIXME I don't know what this test does
184-
#X_1sample = np.arange(5)
185-
#lw = LedoitWolf()
186-
#assert_warns(UserWarning, lw.fit, X_1sample)
186+
X_1sample = np.arange(5)
187+
lw = LedoitWolf()
188+
assert_warns(UserWarning, lw.fit, X_1sample)
189+
assert_array_almost_equal(lw.covariance_,
190+
np.zeros(shape=(5, 5), dtype=np.float64))
187191

188192
# test shrinkage coeff on a simple data set (without saving precision)
189193
lw = LedoitWolf(store_precision=False)
@@ -253,9 +257,11 @@ def test_oas():
253257

254258
# test with one sample
255259
# FIXME I don't know what this test does
256-
#X_1sample = np.arange(5)
257-
#oa = OAS()
258-
#assert_warns(UserWarning, oa.fit, X_1sample)
260+
X_1sample = np.arange(5)
261+
oa = OAS()
262+
assert_warns(UserWarning, oa.fit, X_1sample)
263+
assert_array_almost_equal(oa.covariance_,
264+
np.zeros(shape=(5, 5), dtype=np.float64))
259265

260266
# test shrinkage coeff on a simple data set (without saving precision)
261267
oa = OAS(store_precision=False)

0 commit comments

Comments
 (0)
0