8000 BUG: `take` casting logic with an `out=` argument · royJackman/numpy@024b453 · GitHub
[go: up one dir, main page]

Skip to content

Commit 024b453

Browse files
committed
BUG: take casting logic with an out= argument
Closes numpy#16319
1 parent 3f11db4 commit 024b453

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

numpy/ma/core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5972,7 +5972,17 @@ def take(self, indices, axis=None, out=None, mode='raise'):
59725972
if out is None:
59735973
out = _data.take(indices, axis=axis, mode=mode)[...].view(cls)
59745974
else:
5975-
np.take(_data, indices, axis=axis, mode=mode, out=out)
5975+
# Check if the numeric input values are within the range of the
5976+
# output dtype, if so, convert type and output, else raise error
5977+
in_dtype = self.dtype
5978+
out_dtype = out.dtype
5979+
if ntypes.issubdtype(in_dtype, np.number) and ntypes.issubdtype(out_dtype, np.number):
5980+
if np.logical_and(_data >= np.iinfo(out_dtype).min, _data <= np.iinfo(out_dtype).max).all():
5981+
np.take(_data.astype(out_dtype), indices, axis=axis, mode=mode, out=out)
5982+
else:
5983+
raise TypeError('Output format does not cover input range')
5984+
else:
5985+
np.take(_data, indices, axis=axis, mode=mode, out=out)
59765986
# Get the mask
59775987
if isinstance(out, MaskedArray):
59785988
if _mask is nomask:

numpy/ma/tests/test_core.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3485,6 +3485,27 @@ def test_take(self):
34853485
assert_equal(take(x, [0, 2], axis=1),
34863486
array([[10, 30], [40, 60]], mask=[[0, 1], [1, 0]]))
34873487

3488+
def test_take_dtype_conversion(self):
3489+
# Take to smaller dtype within range
3490+
a = arange(3, dtype=np.int16)
3491+
b = empty((3,), dtype=np.int8)
3492+
a.take([0,1,2], out=b)
3493+
assert_equal(a,b)
3494+
# Smaller dtype out of range
3495+
a = a + 128
3496+
with assert_raises(TypeError):
3497+
a.take([0,1,2], out=b)
3498+
# Larger dtype
3499+
b = empty((3,), dtype=np.int32)
3500+
a.take([0,1,2], out=b)
3501+
assert_equal(a, b)
3502+
3503+
# Take to smaller dtype within range float to int
3504+
a = np.random.rand(3).astype(np.float32)
3505+
b = empty((3,), dtype=np.int8)
3506+
a.take([0,1,2], out=b)
3507+
assert_equal(a.astype(np.int8), b)
3508+
34883509
def test_take_masked_indices(self):
34893510
# Test take w/ masked indices
34903511
a = np.array((40, 18, 37, 9, 22))

0 commit comments

Comments
 (0)
0