@@ -334,5 +334,87 @@ def do(self, a, b):
334
334
pass
335
335
"""
336
336
337
+ ################################################################################
338
+ # ufuncs inspired by pdl
339
+ # - add3
340
+ # - multiply3
341
+ # - multiply3_add
342
+ # - multiply_add
343
+ # - multiply_add2
344
+ # - multiply4
345
+ # - multiply4_add
346
+
347
+ class UfuncTestCase (object ):
348
+ parameter = range (0 ,10 )
349
+
350
+ def _check_for_type (self , typ ):
351
+ a = np .array (self .__class__ .parameter , dtype = typ )
352
+ self .do (a )
353
+
354
+ def _check_for_type_vector (self , typ ):
355
+ parameter = self .__class__ .parameter
356
+ a = np .array ([parameter , parameter ], dtype = typ )
357
+ self .do (a )
358
+
359
+ def test_single (self ):
360
+ self ._check_for_type (single )
361
+
362
+ def test_double (self ):
363
+ self ._check_for_type (double )
364
+
365
+ def test_csingle (self ):
366
+ self ._check_for_type (csingle )
367
+
368
+ def test_cdouble (self ):
369
+ self ._check_for_type (cdouble )
370
+
371
+ def test_single_vector (self ):
372
+ self ._check_for_type_vector (single )
373
+
374
+ def test_double_vector (self ):
375
+ self ._check_for_type_vector (double )
376
+
377
+ def test_csingle_vector (self ):
378
+ self ._check_for_type_vector (csingle )
379
+
380
+ def test_cdouble_vector (self ):
381
+ self ._check_for_type_vector (cdouble )
382
+
383
+
384
+ class TestAdd3 (UfuncTestCase , TestCase ):
385
+ def do (self , a ):
386
+ r = gula .add3 (a ,a ,a )
387
+ assert_almost_equal (r , a + a + a )
388
+
389
+ class TestMultiply3 (UfuncTestCase , TestCase ):
390
+ def do (self , a ):
391
+ r = gula .multiply3 (a ,a ,a )
392
+ assert_almost_equal (r , a * a * a )
393
+
394
+ class TestMultiply3Add (UfuncTestCase , TestCase ):
395
+ def do (self , a ):
396
+ r = gula .multiply3_add (a ,a ,a ,a )
397
+ assert_almost_equal (r , a * a * a + a )
398
+
399
+ class TestMultiplyAdd (UfuncTestCase , TestCase ):
400
+ def do (self , a ):
401
+ r = gula .multiply_add (a ,a ,a )
402
+ assert_almost_equal (r , a * a + a )
403
+
404
+ class TestMultiplyAdd2 (UfuncTestCase , TestCase ):
405
+ def do (self , a ):
406
+ r = gula .multiply_add2 (a ,a ,a ,a )
407
+ assert_almost_equal (r , a * a + a + a )
408
+
409
+ class TestMultiply4 (UfuncTestCase , TestCase ):
410
+ def do (self , a ):
411
+ r = gula .multiply4 (a ,a ,a ,a )
412
+ assert_almost_equal (r , a * a * a * a )
413
+
414
+ class TestMultiply4_add (UfuncTestCase , TestCase ):
415
+ def do (self , a ):
416
+ r = gula .multiply4_add (a ,a ,a ,a ,a )
417
+ assert_almost_equal (r , a * a * a * a + a )
418
+
337
419
if __name__ == "__main__" :
338
420
run_module_suite ()
0 commit comments