@@ -405,6 +405,10 @@ def check_promotion_cases(self, promote_func):
405
405
assert_equal (promote_func (cld ,f64 ), np .dtype (clongdouble ))
406
406
407
407
# coercion between scalars and 1-D arrays
408
+ assert_equal (promote_func (array ([b ]),i8 ), np .dtype (int8 ))
409
+ assert_equal (promote_func (array ([b ]),u8 ), np .dtype (uint8 ))
410
+ assert_equal (promote_func (array ([b ]),i32 ), np .dtype (int32 ))
411
+ assert_equal (promote_func (array ([b ]),u32 ), np .dtype (uint32 ))
408
412
assert_equal (promote_func (array ([i8 ]),i64 ), np .dtype (int8 ))
409
413
assert_equal (promote_func (u64 ,array ([i32 ])), np .dtype (int32 ))
410
414
assert_equal (promote_func (i64 ,array ([u32 ])), np .dtype (uint32 ))
@@ -429,7 +433,7 @@ def check_promotion_cases(self, promote_func):
429
433
assert_equal (promote_func (array ([u16 ]), i32 ), np .dtype (uint16 ))
430
434
# float and complex are treated as the same "kind" for
431
435
# the purposes of array-scalar promotion, so that you can do
432
- # (1j * float32array) to get a complex64 array instead of
436
+ # (0j + float32array) to get a complex64 array instead of
433
437
# a complex128 array.
434
438
assert_equal (promote_func (array ([f32 ]),c128 ), np .dtype (complex64 ))
435
439
@@ -438,6 +442,51 @@ def res_type(a, b):
438
442
return np .add (a , b ).dtype
439
443
self .check_promotion_cases (res_type )
440
444
445
+ # Use-case: float/complex scalar * bool/int8 array
446
+ # shouldn't narrow the float/complex type
447
+ for a in [np .array ([True ,False ]), np .array ([- 3 ,12 ], dtype = np .int8 )]:
448
+ b = 1.234 * a
449
+ assert_equal (b .dtype , np .dtype ('f8' ), "array type %s" % a .dtype )
450
+ b = np .longdouble (1.234 ) * a
451
+ assert_equal (b .dtype , np .dtype (np .longdouble ),
452
+ "array type %s" % a .dtype )
453
+ b = np .float64 (1.234 ) * a
454
+ assert_equal (b .dtype , np .dtype ('f8' ), "array type %s" % a .dtype )
455
+ b = np .float32 (1.234 ) * a
456
+ assert_equal (b .dtype , np .dtype ('f4' ), "array type %s" % a .dtype )
457
+ b = np .float16 (1.234 ) * a
458
+ assert_equal (b .dtype , np .dtype ('f2' ), "array type %s" % a .dtype )
459
+
460
+ b = 1.234j * a
461
+ assert_equal (b .dtype , np .dtype ('c16' ), "array type %s" % a .dtype )
462
+ b = np .clongdouble (1.234j ) * a
463
+ assert_equal (b .dtype , np .dtype (np .clongdouble ),
464
+ "array type %s" % a .dtype )
465
+ b = np .complex128 (1.234j ) * a
466
+ assert_equal (b .dtype , np .dtype ('c16' ), "array type %s" % a .dtype )
467
+ b = np .complex64 (1.234j ) * a
468
+ assert_equal (b .dtype , np .dtype ('c8' ), "array type %s" % a .dtype )
469
+
470
+ # The following use-case is problematic, and to resolve its
471
+ # tricky side-effects requires more changes.
472
+ #
473
+ ## Use-case: (1-t)*a, where 't' is a boolean array and 'a' is
474
+ ## a float32, shouldn't promote to float64
475
+ #a = np.array([1.0, 1.5], dtype=np.float32)
476
+ #t = np.array([True, False])
477
+ #b = t*a
478
+ #assert_equal(b, [1.0, 0.0])
479
+ #assert_equal(b.dtype, np.dtype('f4'))
480
+ #b = (1-t)*a
481
+ #assert_equal(b, [0.0, 1.5])
482
+ #assert_equal(b.dtype, np.dtype('f4'))
483
+ ## Probably ~t (bitwise negation) is more proper to use here,
484
+ ## but this is arguably less intuitive to understand at a glance, and
485
+ ## would fail if 't' is actually an integer array instead of boolean:
486
+ #b = (~t)*a
487
+ #assert_equal(b, [0.0, 1.5])
488
+ #assert_equal(b.dtype, np.dtype('f4'))
489
+
441
490
def test_result_type (self ):
442
491
self .check_promotion_cases (np .result_type )
443
492
0 commit comments