8000 Merge pull request #5909 from argriffing/linalg-astype-no-copy · numpy/numpy@fd94000 · GitHub
[go: up one dir, main page]

Skip to content

Commit fd94000

Browse files
committed
Merge pull request #5909 from argriffing/linalg-astype-no-copy
MAINT: use copy=False in a few astype() calls
2 parents 9dba7a4 + 1ff2be1 commit fd94000

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

numpy/linalg/linalg.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
2424
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
2525
finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
26-
broadcast, atleast_2d, intp, asanyarray
26+
broadcast, atleast_2d, intp, asanyarray, isscalar
2727
)
2828
from numpy.lib import triu, asfarray
2929
from numpy.linalg import lapack_lite, _umath_linalg
@@ -382,7 +382,7 @@ def solve(a, b):
382382
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
383383
r = gufunc(a, b, signature=signature, extobj=extobj)
384384

385-
return wrap(r.astype(result_t))
385+
return wrap(r.astype(result_t, copy=False))
386386

387387

388388
def tensorinv(a, ind=2):
@@ -522,7 +522,7 @@ def inv(a):
522522
signature = 'D->D' if isComplexType(t) else 'd->d'
523523
extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
524524
ainv = _umath_linalg.inv(a, signature=signature, extobj=extobj)
525-
return wrap(ainv.astype(result_t))
525+
return wrap(ainv.astype(result_t, copy=False))
526526

527527

528528
# Cholesky decomposition
@@ -606,7 +606,8 @@ def cholesky(a):
606606
_assertNdSquareness(a)
607607
t, result_t = _commonType(a)
608608
signature = 'D->D' if isComplexType(t) else 'd->d'
609-
return wrap(gufunc(a, signature=signature, extobj=extobj).astype(result_t))
609+
r = gufunc(a, signature=signature, extobj=extobj)
610+
return wrap(r.astype(result_t, copy=False))
610611

611612
# QR decompostion
612613

@@ -781,7 +782,7 @@ def qr(a, mode='reduced'):
781782

782783
if mode == 'economic':
783784
if t != result_t :
784-
a = a.astype(result_t)
785+
a = a.astype(result_t, copy=False)
785786
return wrap(a.T)
786787

787788
# generate q from a
@@ -908,7 +909,7 @@ def eigvals(a):
908909
else:
909910
result_t = _complexType(result_t)
910911

911-
return w.astype(result_t)
912+
return w.astype(result_t, copy=False)
912913

913914
def eigvalsh(a, UPLO='L'):
914915
"""
@@ -978,7 +979,7 @@ def eigvalsh(a, UPLO='L'):
978979
t, result_t = _commonType(a)
979980
signature = 'D->d' if isComplexType(t) else 'd->d'
980981
w = gufunc(a, signature=signature, extobj=extobj)
981-
return w.astype(_realType(result_t))
982+
return w.astype(_realType(result_t), copy=False)
982983

983984
def _convertarray(a):
984985
t, result_t = _commonType(a)
@@ -1124,8 +1125,8 @@ def eig(a):
11241125
else:
11251126
result_t = _complexType(result_t)
11261127

1127-
vt = vt.astype(result_t)
1128-
return w.astype(result_t), wrap(vt)
1128+
vt = vt.astype(result_t, copy=False)
1129+
return w.astype(result_t, copy=False), wrap(vt)
11291130

11301131

11311132
def eigh(a, UPLO='L'):
@@ -1232,8 +1233,8 @@ def eigh(a, UPLO='L'):
12321233

12331234
signature = 'D->dD' if isComplexType(t) else 'd->dd'
12341235
w, vt = gufunc(a, signature=signature, extobj=extobj)
1235-
w = w.astype(_realType(result_t))
1236-
vt = vt.astype(result_t)
1236+
w = w.astype(_realType(result_t), copy=False)
1237+
vt = vt.astype(result_t, copy=False)
12371238
return w, wrap(vt)
12381239

12391240

@@ -1344,9 +1345,9 @@ def svd(a, full_matrices=1, compute_uv=1):
13441345

13451346
signature = 'D->DdD' if isComplexType(t) else 'd->ddd'
13461347
u, s, vt = gufunc(a, signature=signature, extobj=extobj)
1347-
u = u.astype(result_t)
1348-
s = s.astype(_realType(result_t))
1349-
vt = vt.astype(result_t)
1348+
u = u.astype(result_t, copy=False)
1349+
s = s.astype(_realType(result_t), copy=False)
1350+
vt = vt.astype(result_t, copy=False)
13501351
return wrap(u), s, wrap(vt)
13511352
else:
13521353
if m < n:
@@ -1356,7 +1357,7 @@ def svd(a, full_matrices=1, compute_uv=1):
13561357

13571358
signature = 'D->d' if isComplexType(t) else 'd->d'
13581359
s = gufunc(a, signature=signature, extobj=extobj)
1359-
s = s.astype(_realType(result_t))
1360+
s = s.astype(_realType(result_t), copy=False)
13601361
return s
13611362

13621363
def cond(x, p=None):
@@ -1695,7 +1696,15 @@ def slogdet(a):
16951696
real_t = _realType(result_t)
16961697
signature = 'D->Dd' if isComplexType(t) else 'd->dd'
16971698
sign, logdet = _umath_linalg.slogdet(a, signature=signature)
1698-
return sign.astype(result_t), logdet.astype(real_t)
1699+
if isscalar(sign):
1700+
sign = sign.astype(result_t)
1701+
else:
1702+
sign = sign.astype(result_t, copy=False)
1703+
if isscalar(logdet):
1704+
logdet = logdet.astype(real_t)
1705+
else:
1706+
logdet = logdet.astype(real_t, copy=False)
1707+
return sign, logdet
16991708

17001709
def det(a):
17011710
"""
@@ -1749,7 +1758,12 @@ def det(a):
17491758
_assertNdSquareness(a)
17501759
t, result_t = _commonType(a)
17511760
signature = 'D->D' if isComplexType(t) else 'd->d'
1752-
return _umath_linalg.det(a, signature=signature).astype(result_t)
1761+
r = _umath_linalg.det(a, signature=signature)
1762+
if isscalar(r):
1763+
r = r.astype(result_t)
1764+
else:
1765+
r = r.astype(result_t, copy=False)
1766+
return r
17531767

17541768
# Linear Least Squares
17551769

@@ -1905,12 +1919,12 @@ def lstsq(a, b, rcond=-1):
19051919
if results['rank'] == n and m > n:
19061920
if isComplexType(t):
19071921
resids = sum(abs(transpose(bstar)[n:,:])**2, axis=0).astype(
1908-
result_real_t)
1922+
result_real_t, copy=False)
19091923
else:
19101924
resids = sum((transpose(bstar)[n:,:])**2, axis=0).astype(
1911-
result_real_t)
1925+
result_real_t, copy=False)
19121926

1913-
st = s[:min(n, m)].copy().astype(result_real_t)
1927+
st = s[:min(n, m)].astype(result_real_t, copy=True)
19141928
return wrap(x), wrap(resids), results['rank'], st
19151929

19161930

0 commit comments

Comments
 (0)
0