8000 ENH: add dtype option to cov and corrcoef (#17456) · numpy/numpy@156cd05 · GitHub
[go: up one dir, main page]

Skip to content

Commit 156cd05

Browse files
lschwetlickeric-wieserrossbar
authored
ENH: add dtype option to cov and corrcoef (#17456)
Adds a keyword-only dtype parameter to correlate and coerrcoef to allow user to specify the dtype of the output. Co-authored-by: Eric Wieser <wieser.eric@gmail.com> Co-authored-by: Ross Barnowski <rossbar@berkeley.edu>
1 parent 84a4fcb commit 156cd05

File tree

3 files changed

+41
-9
lines changed

3 files changed

+41
-9
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
``dtype`` option for `cov` and `corrcoef`
2+
----------------------------------------------------
3+
The ``dtype`` option is now available for `numpy.cov` and `numpy.corrcoef`.
4+
It specifies which data-type the returned result should have.
5+
By default the functions still return a `numpy.float64` result.

numpy/lib/function_base.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2268,13 +2268,13 @@ def _vectorize_call_with_signature(self, func, args):
22682268

22692269

22702270
def _cov_dispatcher(m, y=None, rowvar=None, bias=None, ddof=None,
2271-
fweights=None, aweights=None):
2271+
fweights=None, aweights=None, *, dtype=None):
22722272
return (m, y, fweights, aweights)
22732273

22742274

22752275
@array_function_dispatch(_cov_dispatcher)
22762276
def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
2277-
aweights=None):
2277+
aweights=None, *, dtype=None):
22782278
"""
22792279
Estimate a covariance matrix, given data and weights.
22802280
@@ -2325,6 +2325,11 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
23252325
weights can be used to assign probabilities to observation vectors.
23262326
23272327
.. versionadded:: 1.10
2328+
dtype : data-type, optional
2329+
Data-type of the result. By default, the return data-type will have
2330+
at least `numpy.float64` precision.
2331+
2332+
.. versionadded:: 1.20
23282333
23292334
Returns
23302335
-------
@@ -2400,13 +2405,16 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
24002405
if m.ndim > 2:
24012406
raise ValueError("m has more than 2 dimensions")
24022407

2403-
if y is None:
2404-
dtype = np.result_type(m, np.float64)
2405-
else:
2408+
if y is not None:
24062409
y = np.asarray(y)
24072410
if y.ndim > 2:
24082411
raise ValueError("y has more than 2 dimensions")
2409-
dtype = np.result_type(m, y, np.float64)
2412+
2413+
if dtype is None:
2414+
if y is None:
2415+
dtype = np.result_type(m, np.float64)
2416+
else:
2417+
dtype = np.result_type(m, y, np.float64)
24102418

24112419
X = array(m, ndmin=2, dtype=dtype)
24122420
if not rowvar and X.shape[0] != 1:
@@ -2486,12 +2494,14 @@ def cov(m, y=None, rowvar=True, bias=False, ddof=None, fweights=None,
24862494
return c.squeeze()
24872495

24882496

2489-
def _corrcoef_dispatcher(x, y=None, rowvar=None, bias=None, ddof=None):
2497+
def _corrcoef_dispatcher(x, y=None, rowvar=None, bias=None, ddof=None, *,
2498+
dtype=None):
24902499
return (x, y)
24912500

24922501

24932502
@array_function_dispatch(_corrcoef_dispatcher)
2494-
def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue):
2503+
def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue, *,
2504+
dtype=None):
24952505
"""
24962506
Return Pearson product-moment correlation coefficients.
24972507
@@ -2525,6 +2535,11 @@ def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue):
25252535
Has no effect, do not use.
25262536
25272537
.. deprecated:: 1.10.0
2538+
dtype : data-type, optional
2539+
Data-type of the result. By default, the return data-type will have
2540+
at least `numpy.float64` precision.
2541+
2542+
.. versionadded:: 1.20
25282543
25292544
Returns
25302545
-------
@@ -2616,7 +2631,7 @@ def corrcoef(x, y=None, rowvar=True, bias=np._NoValue, ddof=np._NoValue):
26162631
# 2015-03-15, 1.10
26172632
warnings.warn('bias and ddof have no effect and are deprecated',
26182633
DeprecationWarning, stacklevel=3)
2619-
c = cov(x, y, rowvar)
2634+
c = cov(x, y, rowvar, dtype=dtype)
26202635
try:
26212636
d = diag(c)
26222637
except ValueError:

numpy/lib/tests/test_function_base.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2023,6 +2023,12 @@ def test_extreme(self):
20232023
assert_array_almost_equal(c, np.array([[1., -1.], [-1., 1.]]))
20242024
assert_(np.all(np.abs(c) <= 1.0))
20252025

2026+
@pytest.mark.parametrize("test_type", [np.half, np.single, np.double, np.longdouble])
2027+
def test_corrcoef_dtype(self, test_type):
2028+
cast_A = self.A.astype(test_type)
2029+
res = corrcoef(cast_A, dtype=test_type)
2030+
assert test_type == res.dtype
2031+
20262032

20272033
class TestCov:
20282034
x1 = np.array([[0, 2], [1, 1], [2, 0]]).T
@@ -2123,6 +2129,12 @@ def test_unit_fweights_and_aweights(self):
21232129
aweights=self.unit_weights),
21242130
self.res1)
21252131

2132+
@pytest.mark.parametrize("test_type", [np.half, np. 6621 single, np.double, np.longdouble])
2133+
def test_cov_dtype(self, test_type):
2134+
cast_x1 = self.x1.astype(test_type)
2135+
res = cov(cast_x1, dtype=test_type)
2136+
assert test_type == res.dtype
2137+
21262138

21272139
class Test_I0:
21282140

0 commit comments

Comments
 (0)
0