-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
ENH: support for empty matrices in linalg.lstsq #11594
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
cb0fc23
45d8c5d
2c05f70
3abfc05
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2115,7 +2115,6 @@ def lstsq(a, b, rcond="warn"): | |
if is_1d: | ||
b = b[:, newaxis] | ||
_assertRank2(a, b) | ||
_assertNoEmpty2d(a, b) # TODO: relax this constraint | ||
m, n = a.shape[-2:] | ||
m2, n_rhs = b.shape[-2:] | ||
if m != m2: | ||
|
@@ -2146,7 +2145,16 @@ def lstsq(a, b, rcond="warn"): | |
|
||
signature = 'DDd->Ddid' if isComplexType(t) else 'ddd->ddid' | ||
extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq) | ||
if n_rhs == 0: | ||
# lapack can't handle n_rhs = 0 - so allocate the array one larger in that axis | ||
b = zeros(b.shape[:-2] + (m, n_rhs + 1), dtype=b.dtype) | ||
x, resids, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj) | ||
if m == 0: | ||
x[...] = 0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One of the tests should be checking for this case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. adding that... |
||
if n_rhs == 0: | ||
# remove the item we added | ||
x = x[..., :n_rhs] | ||
resids = resids[..., :n_rhs] | ||
|
||
# remove the axis we added | ||
if is_1d: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -875,14 +875,12 @@ def test_0_size(self): | |
class LstsqCases(LinalgSquareTestCase, LinalgNonsquareTestCase): | ||
|
||
def do(self, a, b, tags): | ||
if 'size-0' in tags: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @eric-wieser: The entire early return has been removed. The tests below that bit will run. |
||
assert_raises(LinAlgError, linalg.lstsq, a, b) | ||
return | ||
|
||
arr = np.asarray(a) | ||
m, n = arr.shape | ||
u, s, vt = linalg.svd(a, 0) | ||
x, residuals, rank, sv = linalg.lstsq(a, b, rcond=-1) | ||
if m == 0: | ||
assert_((x == 0).all()) | ||
if m <= n: | ||
assert_almost_equal(b, dot(a, x)) | ||
assert_equal(rank, m) | ||
|
@@ -923,6 +921,30 @@ def test_future_rcond(self): | |
# Warning should be raised exactly once (first command) | ||
assert_(len(w) == 1) | ||
|
||
@pytest.mark.parametrize(["m", "n", "n_rhs"], [ | ||
(4, 2, 2), | ||
(0, 4, 1), | ||
(0, 4, 2), | ||
(4, 0, 1), | ||
(4, 0, 2), | ||
(4, 2, 0), | ||
(0, 0, 0) | ||
]) | ||
def test_empty_a_b(self, m, n, n_rhs): | ||
a = np.arange(m * n).reshape(m, n) | ||
b = np.ones((m, n_rhs)) | ||
x, residuals, rank, s = linalg.lstsq(a, b, rcond=None) | ||
if m == 0: | ||
assert_((x == 0).all()) | ||
assert_equal(x.shape, (n, n_rhs)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test for the contents of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
in the case where We could fill it with zeros, of course. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ... but the test for the residuals below are enough. I'm adding a regular (non-empty) case to make sure things are correct. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Worse, it might be filled with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To be clear, I'm not advocating that we fill with nan / inf - I'm saying that some invariants you'd expect for uninitialized nonsense, such as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Of course not everyone hates nan's... but when a is 0 × n, post multiplying will have a first dimension of 0. =P (Actually, they are pretty good for debugging numerical PDE solvers.) |
||
assert_equal(residuals.shape, ((n_rhs,) if m > n else (0,))) | ||
if m > n and n_rhs > 0: | ||
# residuals are exactly the squared norms of b's columns | ||
r = b - np.dot(a, x) | ||
assert_almost_equal(residuals, (r * r).sum(axis=-2)) | ||
assert_equal(rank, min(m, n)) | ||
assert_equal(s.shape, (min(m, n),)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How much of this block could be replaced with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that part can be left to another enhancement. Because I would have to add test cases that apply to everything. That's fine, and it would be great.
Currently, the Once the two are in, we won't have any more There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait... that's the norm... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then I'm conflicted because there is no inverse of an empty matrix. Maybe that bit with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The whole There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd leave |
||
|
||
|
||
class TestMatrixPower(object): | ||
R90 = array([[0, 1], [-1, 0]]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eric-wieser: This will have to be replaced...
Possibly with:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to me. You should merge / rebase on master so that you can combine the messages