MAINT: Remove similar branches from linalg.lstsq by eric-wieser · Pull Request #9986 · numpy/numpy · GitHub

MAINT: Remove similar branches from linalg.lstsq #9986

Merged: 5 commits, Nov 9, 2017
Changes from 2 commits
53 changes: 31 additions & 22 deletions numpy/linalg/linalg.py
@@ -1982,7 +1982,11 @@ def lstsq(a, b, rcond="warn"):
ldb = max(n, m)
if m != b.shape[0]:
raise LinAlgError('Incompatible dimensions')

t, result_t = _commonType(a, b)
real_t = _linalgRealType(t)
result_real_t = _realType(result_t)

# Determine default rcond value
if rcond == "warn":
# 2017-08-19, 1.14.0
@@ -1997,8 +2001,6 @@ def lstsq(a, b, rcond="warn"):
if rcond is None:
rcond = finfo(t).eps * ldb

result_real_t = _realType(result_t)
real_t = _linalgRealType(t)
bstar = zeros((ldb, n_rhs), t)
bstar[:b.shape[0], :n_rhs] = b.copy()
Contributor:

Not this PR, but when I looked at this before, I wondered what the point of the .copy() would be; it is not like a view gets taken, and it cannot be of much speed benefit for the whole routine.

Member Author:

Yeah, the copy stuff here is weird.
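
A quick sketch of the point above (not part of the diff): slice assignment already copies the data out of b, so the explicit .copy() on the right-hand side buys nothing.

import numpy as np

b = np.arange(6.).reshape(3, 2)
bstar = np.zeros((5, 2))
bstar[:b.shape[0], :] = b      # the assignment itself copies b's values into bstar
bstar[0, 0] = -99.0            # mutating bstar afterwards...
print(b[0, 0])                 # ...leaves b untouched: prints 0.0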

a, bstar = _fastCopyAndTranspose(t, a, bstar)
@@ -2039,28 +2041,35 @@
0, work, lwork, iwork, 0)
if results['info'] > 0:
raise LinAlgError('SVD did not converge in Linear Least Squares')
resids = array([], result_real_t)
if is_1d:
x = array(ravel(bstar)[:n], dtype=result_t, copy=True)
if results['rank'] == n and m > n:
if isComplexType(t):
resids = array([sum(abs(ravel(bstar)[n:])**2)],
dtype=result_real_t)
else:
resids = array([sum((ravel(bstar)[n:])**2)],
dtype=result_real_t)
Member Author:

In what makes no sense at all, this branch produces the same effect as the one that follows it.
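
A hedged illustration of the claim above, at the public API level rather than the internal branches: a 1-D b and the equivalent (m, 1) column give the same solution and residuals, which is why a separate is_1d residual computation is redundant.

import numpy as np

a = np.random.rand(5, 3)
b = np.random.rand(5)
x1, r1, rank1, s1 = np.linalg.lstsq(a, b, rcond=None)           # 1-D right-hand side
x2, r2, rank2, s2 = np.linalg.lstsq(a, b[:, None], rcond=None)   # same data as a column
assert np.allclose(x1, x2.ravel()) and np.allclose(r1, r2)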


# undo transpose imposed by fortran-order arrays
b_out = bstar.T

# b_out contains both the solution and the components of the residuals
x = b_out[:n,:]
Member:

PEP8, no alignment like this.

Member:

Actually, PEP8 shows a bunch of other whitespace violations in linalg.py, so we could probably use a style PR to clean those up at some point.

r_parts = b_out[n:,:]
if isComplexType(t):
resids = sum(abs(r_parts)**2, axis=-2)
Contributor:

I wish we had a sensible power or so, but to avoid a needless square root, one can do
sum(r_parts.real**2 + r_parts.imag**2, axis=-2)

Member Author (@eric-wieser, Nov 8, 2017):

r_parts * r_parts.conj() is probably a little faster, and also removes the branching. I'd rather leave this untouched though for now, since that would probably change results by a ULP.

Contributor:

It's not (at least on my machine), but fine to let this be.

Member Author (@eric-wieser, Nov 9, 2017):

It's not faster, or it's not guilty of introducing the ULP error?

Seems to me that there must be some value for which abs(x)**2 != x * x.conj(). Of course, the x * x.conj() value is closer to the true result.
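
A small check of that point (an assumed example, not from the PR): abs(z)**2 goes through a square root and then a square, while (z * z.conj()).real does not, so the two can disagree at the last bit for some inputs.

import numpy as np

rng = np.random.RandomState(0)
z = rng.normal(size=100000) + 1j * rng.normal(size=100000)
via_abs = np.abs(z)**2            # hypot, then square
via_conj = (z * z.conj()).real    # real**2 + imag**2, effectively
print(np.count_nonzero(via_abs != via_conj), "of", z.size, "values differ at the last bit")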

Contributor:

I just meant that x*x.conj() is slower than x.real**2 + x.imag**2 (which makes sense, as the former does a few useless multiplications that cancel). I do agree that there must be values for which abs(x)**2 is slightly less correct, given the square root and square that follow computing x.real**2 + x.imag**2.

Anyway, fine to not worry about it here!
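
For completeness, a rough way to check the timing claim on one's own machine (illustrative only; results vary with array size and hardware):

import timeit
import numpy as np

rs = np.random.RandomState(0)
r = rs.normal(size=(1000, 50)) + 1j * rs.normal(size=(1000, 50))
ns = {"np": np, "r": r}
for expr in ("np.abs(r)**2", "r.real**2 + r.imag**2", "(r * r.conj()).real"):
    # time each formulation of the squared magnitude
    print(expr, timeit.timeit(expr, globals=ns, number=2000))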

Member:

I've contemplated adding a ufunc for the squared absolute value, the main problem seems to be the name.

else:
x = array(bstar.T[:n,:], dtype=result_t, copy=True)
if results['rank'] == n and m > n:
if isComplexType(t):
resids = sum(abs(bstar.T[n:,:])**2, axis=0).astype(
result_real_t, copy=False)
else:
resids = sum((bstar.T[n:,:])**2, axis=0).astype(
result_real_t, copy=False)
resids = sum(r_parts**2, axis=-2)

rank = results['rank']

st = s[:min(n, m)].astype(result_real_t, copy=True)
Member Author:

This slice was pointless, because len(s) == min(n,m)
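
A quick check of the claim (illustrative, using the public SVD rather than the internal LAPACK call): an (m, n) matrix has exactly min(m, n) singular values, so the slice is a no-op.

import numpy as np

for shape in [(5, 3), (3, 5), (4, 4)]:
    s = np.linalg.svd(np.random.rand(*shape), compute_uv=False)
    assert len(s) == min(shape)   # always min(m, n) singular values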

return wrap(x), wrap(resids), results['rank'], st
# remove the axis we added
if is_1d:
x = x.squeeze(axis=-1)
# we probably should squeeze resids too, but we can't
# without breaking compatibility.

# as documented
if rank != n or m <= n:
resids = array([], result_real_t)
Member Author (@eric-wieser, Nov 8, 2017):

This is a bizarre interface, and resids already contains 0 in the m <= n case, which is a more meaningful way to say "no residual" than []. But we're stuck with it, because that's how it's documented.
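
The documented behaviour being referred to, as a small example (not part of the diff): residuals come back as an empty array whenever the system is not overdetermined with full rank.

import numpy as np

a = np.random.rand(3, 5)                       # m <= n: underdetermined system
b = np.random.rand(3)
x, resids, rank, s = np.linalg.lstsq(a, b, rcond=None)
print(resids)                                   # array([], dtype=float64)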


# coerce output arrays
s = s.astype(result_real_t, copy=True)
Contributor:

Why the copy=True here; s is created in this routine, so no need to copy, it would seem.

Member Author:

I agree - kept only because it was there before.

resids = resids.astype(result_real_t, copy=False) # array is temporary
x = x.astype(result_t, copy=True)
Contributor:

x is a view, so I guess it makes sense to copy. Maybe note that? (Also, copy=True is the default.)
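
A sketch of the view/copy distinction at play here (hypothetical shapes, not the routine's actual arrays): slicing returns a view into the work array, and astype(..., copy=True) is what detaches the returned x from it.

import numpy as np

b_out = np.zeros((5, 2))
x = b_out[:3, :]                     # a view into b_out
print(x.base is b_out)               # True
x = x.astype(np.float64, copy=True)
print(x.base is None)                # True: now an independent copy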

Member Author:

The copy=Trues confuse me, since as you note, they're the default. In fact, before #9888 there was a reasonable amount of code devoted to passing that argument.

Maybe this is trying to deal with a subclass that has a different default for copy?

Adding a comment seems reasonable here.

Contributor:

Yes, that's fine. If one were to design this from scratch, one would do the coercion only if an output array was given...

Member Author:

Looking back, the copy=False arguments were introduced in #5909, and the =True is deliberate and for clarity.

Member Author:

> If one were to design this from scratch, one would do the coercion only if an output array was given

Or maybe just work with the dtype passed in, rather than always promoting to double before handing off to the ufunc.

Contributor:

That works only if one also uses different LAPACK routines (which is fine, of course), and would be less precise. But seems more logical in any case; just a different rcond.

return wrap(x), wrap(resids), rank, s


def _multi_svd_norm(x, row_axis, col_axis, op):