8000 BUG: Fixed piecewise function for 0d input · numpy/numpy@292b9ff · GitHub
[go: up one dir, main page]

Skip to content

Commit 292b9ff

Browse files
committed
BUG: Fixed piecewise function for 0d input
When `x` has more than one element the condlist `[True, False]` is being made equivalent to `[[True, False]]`, which is correct. However, when `x` is zero dimensional the expected condlist is `[[True], [False]]`: this commit addresses the issue. Besides, the documentation stated that there could be undefined values but actually these are 0 by default: using `nan` would be desirable, but for the moment the docs were corrected. Closes #331.
1 parent db710ce commit 292b9ff

File tree

2 files changed

+22
-20
lines changed

2 files changed

+22
-20
lines changed

numpy/lib/function_base.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
651651
The output is the same shape and type as x and is found by
652652
calling the functions in `funclist` on the appropriate portions of `x`,
653653
as defined by the boolean arrays in `condlist`. Portions not covered
654-
by any condition have undefined values.
654+
by any condition have a default value of 0.
655655
656656
657657
See Also
@@ -693,32 +693,24 @@ def piecewise(x, condlist, funclist, *args, **kw):
693693
if (isscalar(condlist) or not (isinstance(condlist[0], list) or
694694
isinstance(condlist[0], ndarray))):
695695
condlist = [condlist]
696-
condlist = [asarray(c, dtype=bool) for c in condlist]
696+
condlist = array(condlist, dtype=bool)
697697
n = len(condlist)
698-
if n == n2 - 1: # compute the "otherwise" condition.
699-
totlist = condlist[0]
700-
for k in range(1, n):
701-
totlist |= condlist[k]
702-
condlist.append(~totlist)
703-
n += 1
704-
if (n != n2):
705-
raise ValueError(
706-
"function list and condition list must be the same")
707-
zerod = False
708698
# This is a hack to work around problems with NumPy's
709699
# handling of 0-d arrays and boolean indexing with
710700
# numpy.bool_ scalars
701+
zerod = False
711702
if x.ndim == 0:
712703
x = x[None]
713704
zerod = True
714-
newcondlist = []
715-
for k in range(n):
716-
if condlist[k].ndim == 0:
717-
condition = condlist[k][None]
718-
else:
719-
condition = condlist[k]
720-
newcondlist.append(condition)
721-
condlist = newcondlist
705+
if condlist.shape[-1] != 1:
706+
condlist = condlist.T
707+
if n == n2 - 1: # compute the "otherwise" condition.
708+
totlist = np.logical_or.reduce(condlist, axis=0)
709+
condlist = np.vstack([condlist, ~totlist])
710+
n += 1
711+
if (n != n2):
712+
raise ValueError(
713+
"function list and condition list must be the same")
722714

723715
y = zeros(x.shape, x.dtype)
724716
for k in range(n):

numpy/lib/tests/test_function_base.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,7 @@ def test_simple(self):
14871487
x = piecewise([0, 0], [[False, True]], [lambda x:-1])
14881488
assert_array_equal(x, [0, -1])
14891489

1490+
def test_two_conditions(self):
14901491
x = piecewise([1, 2], [[True, False], [False, True]], [3, 4])
14911492
assert_array_equal(x, [3, 4])
14921493

@@ -1505,6 +1506,15 @@ def test_0d(self):
15051506
assert_(y.ndim == 0)
15061507
assert_(y == 0)
15071508

1509+
x = 5
1510+
y = piecewise(x, [[True], [False]], [1, 0])
1511+
assert_(y.ndim == 0)
1512+
assert_(y == 1)
1513+
1514+
def test_0d_comparison(self):
1515+
x = 3
1516+
y = piecewise(x, [x <= 3, x > 3], [4, 0])
1517+
15081518

15091519
class TestBincount(TestCase):
15101520
def test_simple(self):

0 commit comments

Comments
 (0)
0