8000 BUG: Return scalars from ufunclike objects · numpy/numpy@05f4228 · GitHub
[go: up one dir, main page]

Skip to content

Commit 05f4228

Browse files
committed
BUG: Return scalars from ufunclike objects
No need to reinvent the wheel here - the ufunc machinery will handle the out arguments Fixes #8993
1 parent 628f7b6 commit 05f4228

File tree

2 files changed

+35
-17
lines changed

2 files changed

+35
-17
lines changed

numpy/lib/tests/test_ufunclike.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import division, absolute_import, print_function
22

3+
import numpy as np
34
import numpy.core as nx
45
import numpy.lib.ufunclike as ufl
56
from numpy.testing import (
@@ -62,11 +63,35 @@ def __array_wrap__(self, obj, context=None):
6263
assert_(isinstance(f, MyArray))
6364
assert_equal(f.metadata, 'foo')
6465

66+
# check 0d arrays don't decay to scalars
67+
m0d = m[0,...]
68+
m0d.metadata = 'bar'
69+
f0d = ufl.fix(m0d)
70+
assert_(isinstance(f0d, MyArray))
71+
assert_equal(f0d.metadata, 'bar')
72+
6573
def test_deprecated(self):
6674
# NumPy 1.13.0, 2017-04-26
6775
assert_warns(DeprecationWarning, ufl.fix, [1, 2], y=nx.empty(2))
6876
assert_warns(DeprecationWarning, ufl.isposinf, [1, 2], y=nx.empty(2))
6977
assert_warns(DeprecationWarning, ufl.isneginf, [1, 2], y=nx.empty(2))
7078

79+
def test_scalar(self):
80+
x = np.inf
81+
actual = np.isposinf(x)
82+
expected = np.True_
83+
assert_equal(actual, expected)
84+
assert_equal(type(actual), type(expected))
85+
86+
x = -3.4
87+
actual = np.fix(x)
88+
expected = np.float64(-3.0)
89+
assert_equal(actual, expected)
90+
assert_equal(type(actual), type(expected))
91+
92+
out = np.array(0.0)
93+
actual = np.fix(x, out=out)
94+
assert_(actual is out)
95+
7196
if __name__ == "__main__":
7297
run_module_suite()

numpy/lib/ufunclike.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,15 @@ def fix(x, out=None):
7171
array([ 2., 2., -2., -2.])
7272
7373
"""
74-
x = nx.asanyarray(x)
75-
y1 = nx.floor(x)
76-
y2 = nx.ceil(x)
77-
if out is None:
78-
out = nx.asanyarray(y1)
79-
out[...] = nx.where(x >= 0, y1, y2)
80-
return out
74+
# promote back to an array if flattened
75+
res = nx.asanyarray(nx.ceil(x, out=out))
76+
res = nx.floor(x, out=res, where=nx.greater_equal(x, 0))
8177

78+
# when no out argument is passed and no subclasses are involved, flatten
79+
# scalars
80+
if out is None and type(res) is nx.ndarray:
81+
res = res[()]
82+
return res
8283

8384
@_deprecate_out_named_y
8485
def isposinf(x, out=None):
@@ -137,11 +138,7 @@ def isposinf(x, out=None):
137138
array([0, 0, 1])
138139
139140
"""
140-
if y is None:
141-
x = nx.asarray(x)
142-
out = nx.empty(x.shape, dtype=nx.bool_)
143-
nx.logical_and(nx.isinf(x), ~nx.signbit(x), out)
144-
return out
141+
return nx.logical_and(nx.isinf(x), ~nx.signbit(x), out)
145142

146143

147144
@_deprecate_out_named_y
@@ -202,8 +199,4 @@ def isneginf(x, out=None):
202199
array([1, 0, 0])
203200
204201
"""
205-
if out is None:
206-
x = nx.asarray(x)
207-
out = nx.empty(x.shape, dtype=nx.bool_)
208-
nx.logical_and(nx.isinf(x), nx.signbit(x), out)
209-
return out
202+
return nx.logical_and(nx.isinf(x), nx.signbit(x), out)

0 commit comments

Comments
 (0)
0