@@ -250,6 +250,8 @@ def _make_generalized_cases():
250
250
a = np .array ([case .a , 2 * case .a , 3 * case .a ])
251
251
if case .b is None :
252
252
b = None
253
+ elif case .b .ndim == 1 :
254
+ b = case .b
253
255
else :
254
256
b = np .array ([case .b , 7 * case .b , 6 * case .b ])
255
257
new_case = LinalgCase (case .name + "_tile3" , a , b ,
@@ -259,6 +261,9 @@ def _make_generalized_cases():
259
261
a = np .array ([case .a ] * 2 * 3 ).reshape ((3 , 2 ) + case .a .shape )
260
262
if case .b is None :
261
263
b = None
264
+ elif case .b .ndim == 1 :
265
+ b = np .array ([case .b ] * 2 * 3 * a .shape [- 1 ])\
266
+ .reshape ((3 , 2 ) + case .a .shape [- 2 :])
262
267
else :
263
268
b = np .array ([case .b ] * 2 * 3 ).reshape ((3 , 2 ) + case .b .shape )
264
269
new_case = LinalgCase (case .name + "_tile213" , a , b ,
@@ -432,25 +437,6 @@ def test_generalized_empty_herm_cases(self):
432
437
exclude = {'none' })
433
438
434
439
435
- def dot_generalized (a , b ):
436
- a = asarray (a )
437
- if a .ndim >= 3 :
438
- if a .ndim == b .ndim :
439
- # matrix x matrix
440
- new_shape = a .shape [:- 1 ] + b .shape [- 1 :]
441
- elif a .ndim == b .ndim + 1 :
442
- # matrix x vector
443
- new_shape = a .shape [:- 1 ]
444
- else :
445
- raise ValueError ("Not implemented..." )
446
- r = np .empty (new_shape , dtype = np .common_type (a , b ))
447
- for c in itertools .product (* map (range , a .shape [:- 2 ])):
448
- r [c ] = dot (a [c ], b [c ])
449
- return r
450
- else :
451
- return dot (a , b )
452
-
453
-
454
440
def identity_like_generalized (a ):
455
441
a = asarray (a )
456
442
if a .ndim >= 3 :
@@ -465,7 +451,14 @@ class SolveCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
465
451
# kept apart from TestSolve for use for testing with matrices.
466
452
def do (self , a , b , tags ):
467
453
x = linalg .solve (a , b )
468
- assert_almost_equal (b , dot_generalized (a , x ))
454
+ if np .array (b ).ndim == 1 :
455
+ # When a is (..., M, M) and b is (M,), it is the same as when b is
456
+ # (M, 1), except the result has shape (..., M)
457
+ adotx = matmul (a , x [..., None ])[..., 0 ]
458
+ assert_almost_equal (np .broadcast_to (b , adotx .shape ), adotx )
459
+ else :
460
+ adotx = matmul (a , x )
461
+ assert_almost_equal (b , adotx )
469
462
assert_ (consistent_subclass (x , b ))
470
463
471
464
@@ -475,6 +468,23 @@ def test_types(self, dtype):
475
468
x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
476
469
assert_equal (linalg .solve (x , x ).dtype , dtype )
477
470
471
+ def test_1_d (self ):
472
+ class ArraySubclass (np .ndarray ):
473
+ pass
474
+ a = np .arange (8 ).reshape (2 , 2 , 2 )
475
+ b = np .arange (2 ).view (ArraySubclass )
476
+ result = linalg .solve (a , b )
477
+ assert result .shape == (2 , 2 )
478
+
479
+ # If b is anything other than 1-D it should be treated as a stack of
480
+ # matrices
481
+ b = np .arange (4 ).reshape (2 , 2 ).view (ArraySubclass )
482
+ result = linalg .solve (a , b )
483
+ assert result .shape == (2 , 2 , 2 )
484
+
485
+ b = np .arange (2 ).reshape (1 , 2 ).view (ArraySubclass )
486
+ assert_raises (ValueError , linalg .solve , a , b )
487
+
478
488
def test_0_size (self ):
479
489
class ArraySubclass (np .ndarray ):
480
490
pass
@@ -497,9 +507,9 @@ class ArraySubclass(np.ndarray):
497
507
assert_raises (ValueError , linalg .solve , a [0 :0 ], b [0 :0 ])
498
508
499
509
# Test zero "single equations" with 0x0 matrices.
500
- b = np .arange (2 ).reshape ( 1 , 2 ). view (ArraySubclass )
510
+ b = np .arange (2 ).view (ArraySubclass )
501
511
expected = linalg .solve (a , b )[:, 0 :0 ]
502
- result = linalg .solve (a [:, 0 :0 , 0 :0 ], b [:, 0 :0 ])
512
+ result = linalg .solve (a [:, 0 :0 , 0 :0 ], b [0 :0 ])
503
513
assert_array_equal (result , expected )
504
514
assert_ (isinstance (result , ArraySubclass ))
505
515
@@ -531,7 +541,7 @@ class InvCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
531
541
532
542
def do (self , a , b , tags ):
533
543
a_inv = linalg .inv (a )
534
- assert_almost_equal (dot_generalized (a , a_inv ),
544
+ assert_almost_equal (matmul (a , a_inv ),
535
545
identity_like_generalized (a ))
536
546
assert_ (consistent_subclass (a_inv , a ))
537
547
@@ -599,7 +609,7 @@ class EigCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
599
609
def do (self , a , b , tags ):
600
610
res = linalg .eig (a )
601
611
eigenvalues , eigenvectors = res .eigenvalues , res .eigenvectors
602
- assert_allclose (dot_generalized (a , eigenvectors ),
612
+ assert_allclose (matmul (a , eigenvectors ),
603
613
np .asarray (eigenvectors ) * np .asarray (eigenvalues )[..., None , :],
604
614
rtol = get_rtol (eigenvalues .dtype ))
605
615
assert_ (consistent_subclass (eigenvectors , a ))
@@ -660,7 +670,7 @@ class SVDCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase):
660
670
661
671
def do (self , a , b , tags ):
662
672
u , s , vt = linalg .svd (a , False )
663
- assert_allclose (a , dot_generalized (np .asarray (u ) * np .asarray (s )[..., None , :],
673
+ assert_allclose (a , matmul (np .asarray (u ) * np .asarray (s )[..., None , :],
664
674
np .asarray (vt )),
665
675
rtol = get_rtol (u .dtype ))
666
676
assert_ (consistent_subclass (u , a ))
@@ -693,7 +703,7 @@ class SVDHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
693
703
694
704
def do (self , a , b , tags ):
695
705
u , s , vt = linalg .svd (a , False , hermitian = True )
696
- assert_allclose (a , dot_generalized (np .asarray (u ) * np .asarray (s )[..., None , :],
706
+ assert_allclose (a , matmul (np .asarray (u ) * np .asarray (s )[..., None , :],
697
707
np .asarray (vt )),
698
708
rtol = get_rtol (u .dtype ))
699
709
def hermitian (mat ):
@@ -833,7 +843,7 @@ class PinvCases(LinalgSquareTestCase,
833
843
def do (self , a , b , tags ):
834
844
a_ginv = linalg .pinv (a )
835
845
# `a @ a_ginv == I` does not hold if a is singular
836
- dot = dot_generalized
846
+ dot = matmul
837
847
assert_almost_equal (dot (dot (a , a_ginv ), a ), a , single_decimal = 5 , double_decimal = 11 )
838
848
assert_ (consistent_subclass (a_ginv , a ))
839
849
@@ -847,7 +857,7 @@ class PinvHermitianCases(HermitianTestCase, HermitianGeneralizedTestCase):
847
857
def do (self , a , b , tags ):
848
858
a_ginv = linalg .pinv (a , hermitian = True )
849
859
# `a @ a_ginv == I` does not hold if a is singular
850
- dot = dot_generalized
860
+ dot = matmul
851
861
assert_almost_equal (dot (dot (a , a_ginv ), a ), a , single_decimal = 5 , double_decimal = 11 )
852
862
assert_ (consistent_subclass (a_ginv , a ))
853
863
@@ -1178,14 +1188,14 @@ def do(self, a, b, tags):
1178
1188
evalues .sort (axis = - 1 )
1179
1189
assert_almost_equal (ev , evalues )
1180
1190
1181
- assert_allclose (dot_generalized (a , evc ),
1191
+ assert_allclose (matmul (a , evc ),
1182
1192
np .asarray (ev )[..., None , :] * np .asarray (evc ),
1183
1193
rtol = get_rtol (ev .dtype ))
1184
1194
1185
1195
ev2 , evc2 = linalg .eigh (a , 'U' )
1186
1196
assert_almost_equal (ev2 , evalues )
1187
1197
1188
- assert_allclose (dot_generalized (a , evc2 ),
1198
+ assert_allclose (matmul (a , evc2 ),
1189
1199
np .asarray (ev2 )[..., None , :] * np .asarray (evc2 ),
1190
1200
rtol = get_rtol (ev .dtype ), err_msg = repr (a ))
1191
1201
0 commit comments