8000 MAINT: Implement `lstsq` as a `gufunc` by eric-wieser · Pull Request #9980 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

MAINT: Implement lstsq as a gufunc #9980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 11, 2018
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to 8000
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
MAINT: Move lstsq to umath_linalg
This does not yet enable any broadcasting, but makes doing so in future far
easier.
  • Loading branch information
eric-wieser committed Apr 11, 2018
commit 3ef55be846bc8e6e21515787b29608fec7e1fad0
64 changes: 14 additions & 50 deletions numpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def _raise_linalgerror_eigenvalues_nonconvergence(err, flag):
def _raise_linalgerror_svd_nonconvergence(err, flag):
raise LinAlgError("SVD did not converge")

def _raise_linalgerror_lstsq(err, flag):
raise LinAlgError("SVD did not converge in Linear Least Squares")

def get_linalg_error_extobj(callback):
extobj = list(_linalg_error_extobj) # make a copy
extobj[2] = callback
Expand Down Expand Up @@ -1997,7 +2000,6 @@ def lstsq(a, b, rcond="warn"):
>>> plt.show()

"""
import math
a, _ = _makearray(a)
b, wrap = _makearray(b)
is_1d = b.ndim == 1
Expand All @@ -2008,7 +2010,6 @@ def lstsq(a, b, rcond="warn"):
m = a.shape[0]
n = a.shape[1]
n_rhs = b.shape[1]
ldb = max(n, m)
if m != b.shape[0]:
raise LinAlgError('Incompatible dimensions')

Expand All @@ -2028,62 +2029,25 @@ def lstsq(a, b, rcond="warn"):
FutureWarning, stacklevel=2)
rcond = -1
if rcond is None:
rcond = finfo(t).eps * ldb

bstar = zeros((ldb, n_rhs), t)
bstar[:m, :n_rhs] = b
a, bstar = _fastCopyAndTranspose(t, a, bstar)
a, bstar = _to_native_byte_order(a, bstar)
s = zeros((min(m, n),), real_t)
# This line:
# * is incorrect, according to the LAPACK documentation
# * raises a ValueError if min(m,n) == 0
# * should not be calculated here anyway, as LAPACK should calculate
# `liwork` for us. But that only works if our version of lapack does
# not have this bug:
# http://icl.cs.utk.edu/lapack-forum/archives/lapack/msg00899.html
# Lapack_lite does have that bug...
nlvl = max( 0, int( math.log( float(min(m, n))/2. ) ) + 1 )
iwork = zeros((3*min(m, n)*nlvl+11*min(m, n),), fortran_int)
if isComplexType(t):
lapack_routine = lapack_lite.zgelsd
lwork = 1
rwork = zeros((lwork,), real_t)
work = zeros((lwork,), t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, -1, rwork, iwork, 0)
lrwork = int(rwork[0])
lwork = int(work[0].real)
work = zeros((lwork,), t)
rwork = zeros((lrwork,), real_t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, lwork, rwork, iwork, 0)
rcond = finfo(t).eps * max(n, m)

if m <= n:
gufunc = _umath_linalg.lstsq_m
else:
lapack_routine = lapack_lite.dgelsd
lwork = 1
work = zeros((lwork,), t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, -1, iwork, 0)
lwork = int(work[0])
work = zeros((lwork,), t)
results = lapack_routine(m, n, n_rhs, a, m, bstar, ldb, s, rcond,
0, work, lwork, iwork, 0)
if results['info'] > 0:
raise LinAlgError('SVD did not converge in Linear Least Squares')

# undo transpose imposed by fortran-order arrays
b_out = bstar.T
gufunc = _umath_linalg.lstsq_n

signature = 'DDd->Did' if isComplexType(t) else 'ddd->did'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary given that the inputs are coerced already?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The inputs don't look coerced already to me

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, above one just gets the types for later conversion. Should have noticed that.

extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
b_out, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj)

# b_out contains both the solution and the components of the residuals
x = b_out[:n,:]
r_parts = b_out[n:,:]
x = b_out[...,:n,:]
r_parts = b_out[...,n:,:]
if isComplexType(t):
resids = sum(abs(r_parts)**2, axis=-2)
else:
resids = sum(r_parts**2, axis=-2)

rank = results['rank']

# remove the axis we added
if is_1d:
x = x.squeeze(axis=-1)
Expand Down
0