8000 Merge pull request #25914 from asmeurer/solve-broadcasting-change · numpy/numpy@7fc3d0f · GitHub
[go: up one dir, main page]

Skip to content

Commit 7fc3d0f

Browse files
authored
Merge pull request #25914 from asmeurer/solve-broadcasting-change
API: Remove broadcasting ambiguity from np.linalg.solve
2 parents ae8f784 + 064b55c commit 7fc3d0f

File tree

2 files changed

+52
-33
lines changed

2 files changed

+52
-33
lines changed

numpy/linalg/_linalg.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,13 +327,15 @@ def solve(a, b):
327327
----------
328328
a : (..., M, M) array_like
329329
Coefficient matrix.
330-
b : {(..., M,), (..., M, K)}, array_like
330+
b : {(M,), (..., M, K)}, array_like
331331
Ordinate or "dependent variable" values.
332332
333333
Returns
334334
-------
335335
x : {(..., M,), (..., M, K)} ndarray
336-
Solution to the system a x = b. Returned shape is identical to `b`.
336+
Solution to the system a x = b. Returned shape is (..., M) if b is
337+
shape (M,) and (..., M, K) if b is (..., M, K), where the "..." part is
338+
broadcasted between a and b.
337339
338340
Raises
339341
------
@@ -359,6 +361,13 @@ def solve(a, b):
359361
`lstsq` for the least-squares best "solution" of the
360362
system/equation.
361363
364+
.. versionchanged:: 2.0
365+
366+
The b array is only treated as a shape (M,) column vector if it is
367+
exactly 1-dimensional. In all other instances it is treated as a stack
368+
of (M, K) matrices. Previously b would be treated as a stack of (M,)
369+
vectors if b.ndim was equal to a.ndim - 1.
370+
362371
References
363372
----------
364373
.. [1] G. Strang, *Linear Algebra and Its Applications*, 2nd Ed., Orlando,
@@ -390,7 +399,7 @@ def solve(a, b):
390399

391400
# We use the b = (..., M,) logic, only if the number of extra dimensions
392401
# match exactly
393-
if b.ndim == a.ndim - 1:
402+
if b.ndim == 1:
394403
gufunc = _umath_linalg.solve1
395404
else:
396405
gufunc = _umath_linalg.solve

numpy/linalg/tests/test_linalg.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,8 @@ def _make_generalized_cases():
250250
a = np.array([case.a, 2 * case.a, 3 * case.a])
251251
if case.b is None:
252252
b = None
253+
elif case.b.ndim == 1:
254+
b = case.b
253255
else:
254256
b = np.array([case.b, 7 * case.b, 6 * case.b])
255257
new_case = LinalgCase(case.name + "_tile3", a, b,
@@ -259,6 +261,9 @@ def _make_generalized_cases():
259261
a = np.array([case.a] * 2 * 3).reshape((3, 2) + case.a.shape)
260262
if case.b is None:
261263
b = None
264+
elif case.b.ndim == 1:
265+
b = np.array([case.b] * 2 * 3 * a.shape[-1])\
266+
.reshape((3, 2) + case.a.shape[-2:])
262267
else:
263268
b = np.array([case.b] * 2 * 3).reshape((3, 2) + case.b.shape)
264269
new_case = LinalgCase(case.name + "_tile213", a, b,
@@ -432,25 +437,6 @@ def test_generalized_empty_herm_cases(self):
432437
exclude={'none'})
433438

434439

435-
def dot_generalized(a, b):
436-
a = asarray(a)
437-
if a.ndim >= 3:
438-
if a.ndim == b.ndim:
439-
# matrix x matrix
440-
new_shape = a.shape[:-1] + b.shape[-1:]
441-
elif a.ndim == b.ndim + 1:
442-
# matrix x vector
443-
new_shape = a.shape[:-1]
444-
else:
445-
raise ValueError("Not implemented...")
446-
r = np.empty(new_shape, dtype=np.common_type(a, b))
447-
for c in itertools.product(*map(range, a.shape[:-2])):
448-
r[c] = dot(a[c], b[c])
449-
return r
450-
else:
451-
return dot(a, b)
452-
453-
454440
def identity_like_generalized(a):
455441
a = asarray(a)
456442
if a.ndim >= 3:
@@ -465,7 +451,14 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
465451
# kept apart from TestSolve for use for testing with matrices.
466452
def do(self, a, b, tags):
467453
x = linalg.solve(a, b)
468-
assert_almost_equal(b, dot_generalized(a, x))
454+
if np.array(b).ndim == 1:
455+
# When a is (..., M, M) and b is (M,), it is the same as when b is
456+
# (M, 1), except the result has shape (..., M)
457+
adotx = matmul(a, x[..., None])[..., 0]
458+
assert_almost_equal(np.broadcast_to(b, adotx.shape), adotx)
459+
else:
460+
adotx = matmul(a, x)
461+
assert_almost_equal(b, adotx)
469462
assert_(consistent_subclass(x, b))
470463

471464

@@ -475,6 +468,23 @@ def test_types(self, dtype):
475468
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
476469
assert_equal(linalg.solve(x, x).dtype, dtype)
477470

471+
def test_1_d(self):
472+
class ArraySubclass(np.ndarray):
473+
pass
474+
a = np.arange(8).reshape(2, 2, 2)
475+
b = np.arange(2).view(ArraySubclass)
476+
result = linalg.solve(a, b)
477+
assert result.shape == (2, 2)
478+
479+
# If b is anything other than 1-D it should be treated as a stack of
480+
# matrices
481+
b = np.arange(4).reshape(2, 2).view(ArraySubclass)
482+
result = linalg.solve(a, b)
483+
assert result.shape == (2, 2, 2)
484+
485+
b = np.arange(2).reshape(1, 2).view(ArraySubclass)
486+
assert_raises(ValueError, linalg.solve, a, b)
487+
478488
def test_0_size(self):
479489
class ArraySubclass(np.ndarray):
480490
pass
@@ -497,9 +507,9 @@ class ArraySubclass(np.ndarray):
497507
assert_raises(ValueError, linalg.solve, a[0:0], b[0:0])
498508

499509
# Test zero "single equations" with 0x0 matrices.
500-
b = np.arange(2).reshape(1, 2).view(ArraySubclass)
510+
b = np.arange(2).view(ArraySubclass)
501511
expected = linalg.solve(a, b)[:, 0:0]
502-
result = linalg.solve(a[:, 0:0, 0:0], b[:, 0:0])
512+
result = linalg.solve(a[:, 0:0, 0:0], b[0:0])
503513
assert_array_equal(result, expected)
504514
assert_(isinstance(result, ArraySubclass))
505515

@@ -531,7 +541,7 @@ class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
531541

532542
def do(self, a, b, tags):
533543
a_inv = linalg.inv(a)
534-
assert_almost_equal(dot_generalized(a, a_inv),
544+
assert_almost_equal(matmul(a, a_inv),
535545
identity_like_generalized(a))
536546
assert_(consistent_subclass(a_inv, a))
537547

@@ -599,7 +609,7 @@ class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
599609
def do(self, a, b, tags):
600610
res = linalg.eig(a)
601611
eigenvalues, eigenvectors = res.eigenvalues, res.eigenvectors
602-
assert_allclose(dot_generalized(a, eigenvectors),
612+
assert_allclose(matmul(a, eigenvectors),
603613
np.asarray(eigenvectors) * np.asarray(eigenvalues)[..., None, :],
604614
rtol=get_rtol(eigenvalues.dtype))
605615
assert_(consistent_subclass(eigenvectors, a))
@@ -660,7 +670,7 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
660670

661671
def do(self, a, b, tags):
662672
u, s, vt = linalg.svd(a, False)
663-
assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
673+
assert_allclose(a, matmul(np.asarray(u) * np.asarray(s)[..., None, :],
664674
np.asarray(vt)),
665675
rtol=get_rtol(u.dtype))
666676
assert_(consistent_subclass(u, a))
@@ -693,7 +703,7 @@ class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
693703

694704
def do(self, a, b, tags):
695705
u, s, vt = linalg.svd(a, False, hermitian=True)
696-
assert_allclose(a, dot_generalized(np.asarray(u) * np.asarray(s)[..., None, :],
706+
assert_allclose(a, matmul(np.asarray(u) * np.asarray(s)[..., None, :],
697707
np.asarray(vt)),
698708
rtol=get_rtol(u.dtype))
699709
def hermitian(mat):
@@ -833,7 +843,7 @@ class PinvCases(LinalgSquareTestCase,
833843
def do(self, a, b, tags):
834844
a_ginv = linalg.pinv(a)
835845
# `a @ a_ginv == I` does not hold if a is singular
836-
dot = dot_generalized
846+
dot = matmul
837847
assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
838848
assert_(consistent_subclass(a_ginv, a))
839849

@@ -847,7 +857,7 @@ class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
847857
def do(self, a, b, tags):
848858
a_ginv = linalg.pinv(a, hermitian=True)
849859
# `a @ a_ginv == I` does not hold if a is singular
850-
dot = dot_generalized
860+
dot = matmul
851861
assert_almost_equal(dot(dot(a, a_ginv), a), a, single_decimal=5, double_decimal=11)
852862
assert_(consistent_subclass(a_ginv, a))
853863

@@ -1178,14 +1188,14 @@ def do(self, a, b, tags):
11781188
evalues.sort(axis=-1)
11791189
assert_almost_equal(ev, evalues)
11801190

1181-
assert_allclose(dot_generalized(a, evc),
1191+
assert_allclose(matmul(a, evc),
11821192
np.asarray(ev)[..., None, :] * np.asarray(evc),
11831193
rtol=get_rtol(ev.dtype))
11841194

11851195
ev2, evc2 = linalg.eigh(a, 'U')
11861196
assert_almost_equal(ev2, evalues)
11871197

1188-
assert_allclose(dot_generalized(a, evc2),
1198+
assert_allclose(matmul(a, evc2),
11891199
np.asarray(ev2)[..., None, :] * np.asarray(evc2),
11901200
rtol=get_rtol(ev.dtype), err_msg=repr(a))
11911201

0 commit comments

Comments
 (0)
0