8000 added tests for ufuncs in gufuncs_linalg (the ones based on pdl). Add… · numpy/numpy@a75fb9e · GitHub
[go: up one dir, main page]

Skip to content

Commit a75fb9e

Browse files
ovillellaspv
ovillellas
authored andcommitted
added tests for ufuncs in gufuncs_linalg (the ones based on pdl). Added multiply4 in the wrapper, as it was missing
1 parent 2276eaa commit a75fb9e

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

numpy/core/src/umath/gufuncs_linalg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
multiply3_add = _impl.multiply3_add
1616
multiply_add = _impl.multiply_add
1717
multiply_add2 = _impl.multiply_add2
18+
multiply4 = _impl.multiply4
1819
multiply4_add = _impl.multiply4_add
1920
eig = _impl.eig
2021
eigvals = _impl.eigvals

numpy/core/tests/test_gufuncs_linalg.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,5 +334,87 @@ def do(self, a, b):
334334
pass
335335
"""
336336

337+
################################################################################
338+
# ufuncs inspired by pdl
339+
# - add3
340+
# - multiply3
341+
# - multiply3_add
342+
# - multiply_add
343+
# - multiply_add2
344+
# - multiply4
345+
# - multiply4_add
346+
347+
class UfuncTestCase(object):
348+
parameter = range(0,10)
349+
350+
def _check_for_type(self, typ):
351+
a = np.array(self.__class__.parameter, dtype=typ)
352+
self.do(a)
353+
354+
def _check_for_type_vector(self, typ):
355+
parameter = self.__class__.parameter
356+
a = np.array([parameter, parameter], dtype=typ)
357+
self.do(a)
358+
359+
def test_single(self):
360+
self._check_for_type(single)
361+
362+
def test_double(self):
363+
self._check_for_type(double)
364+
365+
def test_csingle(self):
366+
self._check_for_type(csingle)
367+
368+
def test_cdouble(self):
369+
self._check_for_type(cdouble)
370+
371+
def test_single_vector(self):
372+
self._check_for_type_vector(single)
373+
374+
def test_double_vector(self):
375+
self._check_for_type_vector(double)
376+
377+
def test_csingle_vector(self):
378+
self._check_for_type_vector(csingle)
379+
380+
def test_cdouble_vector(self):
381+
self._check_for_type_vector(cdouble)
382+
383+
384+
class TestAdd3(UfuncTestCase, TestCase):
385+
def do(self, a):
386+
r = gula.add3(a,a,a)
387+
assert_almost_equal(r, a+a+a)
388+
389+
class TestMultiply3(UfuncTestCase, TestCase):
390+
def do(self, a):
391+
r = gula.multiply3(a,a,a)
392+
assert_almost_equal(r, a*a*a)
393+
394+
class TestMultiply3Add(UfuncTestCase, TestCase):
395+
def do(self, a):
396+
r = gula.multiply3_add(a,a,a,a)
397+
assert_almost_equal(r, a*a*a+a)
398+
399+
class TestMultiplyAdd(UfuncTestCase, TestCase):
400+
def do(self, a):
401+
r = gula.multiply_add(a,a,a)
402+
assert_almost_equal(r, a*a+a)
403+
404+
class TestMultiplyAdd2(UfuncTestCase, TestCase):
405+
def do(self, a):
406+
r = gula.multiply_add2(a,a,a,a)
407+
assert_almost_equal(r, a*a+a+a)
408+
409+
class TestMultiply4(UfuncTestCase, TestCase):
410+
def do(self, a):
411+
r = gula.multiply4(a,a,a,a)
412+
assert_almost_equal(r, a*a*a*a)
413+
414+
class TestMultiply4_add(UfuncTestCase, TestCase):
415+
def do(self, a):
416+
r = gula.multiply4_add(a,a,a,a,a)
417+
assert_almost_equal(r, a*a*a*a+a)
418+
337419
if __name__ == "__main__":
338420
run_module_suite()

0 commit comments

Comments
 (0)
0