@@ -462,12 +462,10 @@ def do(self, a, b, tags):
462
462
463
463
464
464
class TestSolve (SolveCases ):
465
- def test_types (self ):
466
- def check (dtype ):
467
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
468
- assert_equal (linalg .solve (x , x ).dtype , dtype )
469
- for dtype in [single , double , csingle , cdouble ]:
470
- check (dtype )
465
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
466
+ def test_types (self , dtype ):
467
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
468
+ assert_equal (linalg .solve (x , x ).dtype , dtype )
471
469
472
470
def test_0_size (self ):
473
471
class ArraySubclass (np .ndarray ):
@@ -531,12 +529,10 @@ def do(self, a, b, tags):
531
529
532
530
533
531 class TestInv (InvCases ):
534
- def test_types (self ):
535
- def check (dtype ):
536
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
537
- assert_equal (linalg .inv (x ).dtype , dtype )
538
- for dtype in [single , double , csingle , cdouble ]:
539
- check (dtype )
532
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
533
+ def test_types (self , dtype ):
534
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
535
+ assert_equal (linalg .inv (x ).dtype , dtype )
540
536
541
537
def test_0_size (self ):
542
538
# Check that all kinds of 0-sized arrays work
@@ -564,14 +560,12 @@ def do(self, a, b, tags):
564
560
565
561
566
562
class TestEigvals (EigvalsCases ):
567
- def test_types (self ):
568
- def check (dtype ):
569
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
570
- assert_equal (linalg .eigvals (x ).dtype , dtype )
571
- x = np .array ([[1 , 0.5 ], [- 1 , 1 ]], dtype = dtype )
572
- assert_equal (linalg .eigvals (x ).dtype , get_complex_dtype (dtype ))
573
- for dtype in [single , double , csingle , cdouble ]:
574
- check (dtype )
563
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
564
+ def test_types (self , dtype ):
565
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
566
+ assert_equal (linalg .eigvals (x ).dtype , dtype )
567
+ x = np .array ([[1 , 0.5 ], [- 1 , 1 ]], dtype = dtype )
568
+ assert_equal (linalg .eigvals (x ).dtype , get_complex_dtype (dtype ))
575
569
576
570
def test_0_size (self ):
577
571
# Check that all kinds of 0-sized arrays work
@@ -603,20 +597,17 @@ def do(self, a, b, tags):
603
597
604
598
605
599
class TestEig (EigCases ):
606
- def test_types (self ):
607
- def check (dtype ):
608
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
609
- w , v = np .linalg .eig (x )
610
- assert_equal (w .dtype , dtype )
611
- assert_equal (v .dtype , dtype )
612
-
613
- x = np .array ([[1 , 0.5 ], [- 1 , 1 ]], dtype = dtype )
614
- w , v = np .linalg .eig (x )
615
- assert_equal (w .dtype , get_complex_dtype (dtype ))
616
- assert_equal (v .dtype , get_complex_dtype (dtype ))
617
-
618
- for dtype in [single , double , csingle , cdouble ]:
619
- check (dtype )
600
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
601
+ def test_types (self , dtype ):
602
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
603
+ w , v = np .linalg .eig (x )
604
+ assert_equal (w .dtype , dtype )
605
+ assert_equal (v .dtype , dtype )
606
+
607
+ x = np .array ([[1 , 0.5 ], [- 1 , 1 ]], dtype = dtype )
608
+ w , v = np .linalg .eig (x )
609
+ assert_equal (w .dtype , get_complex_dtype (dtype ))
610
+ assert_equal (v .dtype , get_complex_dtype (dtype ))
620
611
621
612
def test_0_size (self ):
622
613
# Check that all kinds of 0-sized arrays work
@@ -653,18 +644,15 @@ def do(self, a, b, tags):
653
644
654
645
655
646
class TestSVD (SVDCases ):
656
- def test_types (self ):
657
- def check (dtype ):
658
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
659
- u , s , vh = linalg .svd (x )
660
- assert_equal (u .dtype , dtype )
661
- assert_equal (s .dtype , get_real_dtype (dtype ))
662
- assert_equal (vh .dtype , dtype )
663
- s = linalg .svd (x , compute_uv = False )
664
- assert_equal (s .dtype , get_real_dtype (dtype ))
665
-
666
- for dtype in [single , double , csingle , cdouble ]:
667
- check (dtype )
647
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
648
+ def test_types (self , dtype ):
649
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
650
+ u , s , vh = linalg .svd (x )
651
+ assert_equal (u .dtype , dtype )
652
+ assert_equal (s .dtype , get_real_dtype (dtype ))
653
+ assert_equal (vh .dtype , dtype )
654
+ s = linalg .svd (x , compute_uv = False )
655
+ assert_equal (s .dtype , get_real_dtype (dtype ))
668
656
669
657
def test_empty_identity (self ):
670
658
""" Empty input should put an identity matrix in u or vh """
@@ -842,15 +830,13 @@ def test_zero(self):
842
830
assert_equal (type (linalg .slogdet ([[0.0j ]])[0 ]), cdouble )
843
831
assert_equal (type (linalg .slogdet ([[0.0j ]])[1 ]), double )
844
832
845
- def test_types (self ):
846
- def check (dtype ):
847
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
848
- assert_equal (np .linalg .det (x ).dtype , dtype )
849
- ph , s = np .linalg .slogdet (x )
850
- assert_equal (s .dtype , get_real_dtype (dtype ))
851
- assert_equal (ph .dtype , dtype )
852
- for dtype in [single , double , csingle , cdouble ]:
853
- check (dtype )
833
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
834
+ def test_types (self , dtype ):
835
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
836
+ assert_equal (np .linalg .det (x ).dtype , dtype )
837
+ ph , s = np .linalg .slogdet (x )
838
+ assert_equal (s .dtype , get_real_dtype (dtype ))
839
+ assert_equal (ph .dtype , dtype )
854
840
855
841
def test_0_size (self ):
856
842
a = np .zeros ((0 , 0 ), dtype = np .complex64 )
@@ -1049,13 +1035,11 @@ def do(self, a, b, tags):
1049
1035
1050
1036
1051
1037
class TestEigvalsh (object ):
1052
- def test_types (self ):
1053
- def check (dtype ):
1054
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
1055
- w = np .linalg .eigvalsh (x )
1056
- assert_equal (w .dtype , get_real_dtype (dtype ))
1057
- for dtype in [single , double , csingle , cdouble ]:
1058
- check (dtype )
1038
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
1039
+ def test_types (self , dtype ):
1040
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
1041
+ w = np .linalg .eigvalsh (x )
1042
+ assert_equal (w .dtype , get_real_dtype (dtype ))
1059
1043
1060
1044
def test_invalid (self ):
1061
1045
x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = np .float32 )
@@ -1127,14 +1111,12 @@ def do(self, a, b, tags):
1127
1111
1128
1112
1129
1113
class TestEigh (object ):
1130
- def test_types (self ):
1131
- def check (dtype ):
1132
- x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
1133
- w , v = np .linalg .eigh (x )
1134
- assert_equal (w .dtype , get_real_dtype (dtype ))
1135
- assert_equal (v .dtype , dtype )
1136
- for dtype in [single , double , csingle , cdouble ]:
1137
- check (dtype )
1114
+ @pytest .mark .parametrize ('dtype' , [single , double , csingle , cdouble ])
1115
+ def test_types (self , dtype ):
1116
+ x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = dtype )
1117
+ w , v = np .linalg .eigh (x )
1118
+ assert_equal (w .dtype , get_real_dtype (dtype ))
1119
+ assert_equal (v .dtype , dtype )
1138
1120
1139
1121
def test_invalid (self ):
1140
1122
x = np .array ([[1 , 0.5 ], [0.5 , 1 ]], dtype = np .float32 )
0 commit comments