8000 Merge pull request #11895 from QuLogic/linalg-parametrize · eric-wieser/numpy@b8f3be9 · GitHub
[go: up one dir, main page]

Skip to content

Commit b8f3be9

Browse files
authored
Merge pull request numpy#11895 from QuLogic/linalg-parametrize
TST: Parametrize some linalg tests over types.
2 parents f8141ce + 6af54ed commit b8f3be9

File tree

1 file changed

+52
-70
lines changed

1 file changed

+52
-70
lines changed

numpy/linalg/tests/test_linalg.py

Lines changed: 52 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -462,12 +462,10 @@ def do(self, a, b, tags):
462462

463463

464464
class TestSolve(SolveCases):
465-
def test_types(self):
466-
def check(dtype):
467-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
468-
assert_equal(linalg.solve(x, x).dtype, dtype)
469-
for dtype in [single, double, csingle, cdouble]:
470-
check(dtype)
465+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
466+
def test_types(self, dtype):
467+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
468+
assert_equal(linalg.solve(x, x).dtype, dtype)
471469

472470
def test_0_size(self):
473471
class ArraySubclass(np.ndarray):
@@ -531,12 +529,10 @@ def do(self, a, b, tags):
531529

532530

533531
class TestInv(InvCases):
534-
def test_types(self):
535-
def check(dtype):
536-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
537-
assert_equal(linalg.inv(x).dtype, dtype)
538-
for dtype in [single, double, csingle, cdouble]:
539-
check(dtype)
532+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
533+
def test_types(self, dtype):
534+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
535+
assert_equal(linalg.inv(x).dtype, dtype)
540536

541537
def test_0_size(self):
542538
# Check that all kinds of 0-sized arrays work
@@ -564,14 +560,12 @@ def do(self, a, b, tags):
564560

565561

566562
class TestEigvals(EigvalsCases):
567-
def test_types(self):
568-
def check(dtype):
569-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
570-
assert_equal(linalg.eigvals(x).dtype, dtype)
571-
x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
572-
assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
573-
for dtype in [single, double, csingle, cdouble]:
574-
check(dtype)
563+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
564+
def test_types(self, dtype):
565+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
566+
assert_equal(linalg.eigvals(x).dtype, dtype)
567+
x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
568+
assert_equal(linalg.eigvals(x).dtype, get_complex_dtype(dtype))
575569

576570
def test_0_size(self):
577571
# Check that all kinds of 0-sized arrays work
@@ -603,20 +597,17 @@ def do(self, a, b, tags):
603597

604598

605599
class TestEig(EigCases):
606-
def test_types(self):
607-
def check(dtype):
608-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
609-
w, v = np.linalg.eig(x)
610-
assert_equal(w.dtype, dtype)
611-
assert_equal(v.dtype, dtype)
612-
613-
x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
614-
w, v = np.linalg.eig(x)
615-
assert_equal(w.dtype, get_complex_dtype(dtype))
616-
assert_equal(v.dtype, get_complex_dtype(dtype))
617-
618-
for dtype in [single, double, csingle, cdouble]:
619-
check(dtype)
600+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
601+
def test_types(self, dtype):
602+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
603+
w, v = np.linalg.eig(x)
604+
assert_equal(w.dtype, dtype)
605+
assert_equal(v.dtype, dtype)
606+
607+
x = np.array([[1, 0.5], [-1, 1]], dtype=dtype)
608+
w, v = np.linalg.eig(x)
609+
assert_equal(w.dtype, get_complex_dtype(dtype))
610+
assert_equal(v.dtype, get_complex_dtype(dtype))
620611

621612
def test_0_size(self):
622613
# Check that all kinds of 0-sized arrays work
@@ -653,18 +644,15 @@ def do(self, a, b, tags):
653644

654645

655646
class TestSVD(SVDCases):
656-
def test_types(self):
657-
def check(dtype):
658-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
659-
u, s, vh = linalg.svd(x)
660-
assert_equal(u.dtype, dtype)
661-
assert_equal(s.dtype, get_real_dtype(dtype))
662-
assert_equal(vh.dtype, dtype)
663-
s = linalg.svd(x, compute_uv=False)
664-
assert_equal(s.dtype, get_real_dtype(dtype))
665-
666-
for dtype in [single, double, csingle, cdouble]:
667-
check(dtype)
647+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
648+
def test_types(self, dtype):
649+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
650+
u, s, vh = linalg.svd(x)
651+
assert_equal(u.dtype, dtype)
652+
assert_equal(s.dtype, get_real_dtype(dtype))
653+
assert_equal(vh.dtype, dtype)
654+
s = linalg.svd(x, compute_uv=False)
655+
assert_equal(s.dtype, get_real_dtype(dtype))
668656

669657
def test_empty_identity(self):
670658
""" Empty input should put an identity matrix in u or vh """
@@ -842,15 +830,13 @@ def test_zero(self):
842830
assert_equal(type(linalg.slogdet([[0.0j]])[0]), cdouble)
843831
assert_equal(type(linalg.slogdet([[0.0j]])[1]), double)
844832

845-
def test_types(self):
846-
def check(dtype):
847-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
848-
assert_equal(np.linalg.det(x).dtype, dtype)
849-
ph, s = np.linalg.slogdet(x)
850-
assert_equal(s.dtype, get_real_dtype(dtype))
851-
assert_equal(ph.dtype, dtype)
852-
for dtype in [single, double, csingle, cdouble]:
853-
check(dtype)
833+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
834+
def test_types(self, dtype):
835+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
836+
assert_equal(np.linalg.det(x).dtype, dtype)
837+
ph, s = np.linalg.slogdet(x)
838+
assert_equal(s.dtype, get_real_dtype(dtype))
839+
assert_equal(ph.dtype, dtype)
854840

855841
def test_0_size(self):
856842
a = np.zeros((0, 0), dtype=np.complex64)
@@ -1049,13 +1035,11 @@ def do(self, a, b, tags):
10491035

10501036

10511037
class TestEigvalsh(object):
1052-
def test_types(self):
1053-
def check(dtype):
1054-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1055-
w = np.linalg.eigvalsh(x)
1056-
assert_equal(w.dtype, get_real_dtype(dtype))
1057-
for dtype in [single, double, csingle, cdouble]:
1058-
check(dtype)
1038+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
1039+
def test_types(self, dtype):
1040+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1041+
w = np.linalg.eigvalsh(x)
1042+
assert_equal(w.dtype, get_real_dtype(dtype))
10591043

10601044
def test_invalid(self):
10611045
x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)
@@ -1127,14 +1111,12 @@ def do(self, a, b, tags):
11271111

11281112

11291113
class TestEigh(object):
1130-
def test_types(self):
1131-
def check(dtype):
1132-
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1133-
w, v = np.linalg.eigh(x)
1134-
assert_equal(w.dtype, get_real_dtype(dtype))
1135-
assert_equal(v.dtype, dtype)
1136-
for dtype in [single, double, csingle, cdouble]:
1137-
check(dtype)
1114+
@pytest.mark.parametrize('dtype', [single, double, csingle, cdouble])
1115+
def test_types(self, dtype):
1116+
x = np.array([[1, 0.5], [0.5, 1]], dtype=dtype)
1117+
w, v = np.linalg.eigh(x)
1118+
assert_equal(w.dtype, get_real_dtype(dtype))
1119+
assert_equal(v.dtype, dtype)
11381120

11391121
def test_invalid(self):
11401122
x = np.array([[1, 0.5], [0.5, 1]], dtype=np.float32)

0 commit comments

Comments
 (0)
0