8000 MAINT/BUG: Remove special-casing for 0d arrays, now that indexing with a single boolean is ok by eric-wieser · Pull Request #9900 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

MAINT/BUG: Remove special-casing for 0d arrays, now that indexing with a single boolean is ok #9900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 24, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions numpy/lib/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,12 +1254,12 @@ def piecewise(x, condlist, funclist, *args, **kw):

The length of `condlist` must correspond to that of `funclist`.
If one extra function is given, i.e. if
``len(funclist) - len(condlist) == 1``, then that extra function
``len(funclist) == len(condlist) + 1``, then that extra function
is the default value, used wherever all conditions are false.
funclist : list of callables, f(x,*args,**kw), or scalars
Each function is evaluated over `x` wherever its corresponding
condition is True. It should take an array as input and give an array
or a scalar value as output. If, instead of a callable,
condition is True. It should take a 1d array as input and give an 1d
array or a scalar value as output. If, instead of a callable,
a scalar is provided then a constant function (``lambda x: scalar``) is
assumed.
args : tuple, optional
Expand Down Expand Up @@ -1323,25 +1323,24 @@ def piecewise(x, condlist, funclist, *args, **kw):
"""
x = asanyarray(x)
n2 = len(funclist)
if (isscalar(condlist) or not (isinstance(condlist[0], list) or
isinstance(condlist[0], ndarray))):
if not isscalar(condlist) and x.size == 1 and x.ndim == 0:
condlist = [[c] for c in condlist]
Copy link
Member Author
@eric-wieser eric-wieser Oct 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line inserts an extra dim into cond to match what we to did to 0d x below - so is no longer needed.

Note that this would misfire on non-0d x sometimes - which is why a test below needed changing

Essentially, this line is replaced with pass, and then the nested ifs are collapsed

else:
condlist = [condlist]

# undocumented: single condition is promoted to a list of one condition
if isscalar(condlist) or (
not isinstance(condlist[0], (list, ndarray)) and x.ndim != 0):
condlist = [condlist]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, was the idea to deprecate this? An alternative might be to just add ndmin=1 in the conversion to boolean array below.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to deprecate this in future, but this patch is orthogonal to that

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ndim=1 won't help - this is intended to also promote 1d arrays to a list of length 1 containing a single 1d array.

condlist = array(condlist, dtype=bool)
n = len(condlist)
# This is a hack to work around problems with NumPy's
# handling of 0-d arrays and boolean indexing with
# numpy.bool_ scalars
zerod = False
if x.ndim == 0:
x = x[None]
zerod = True
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this worked around arr_0d[bool_0d] not being supported in the past


if n == n2 - 1: # compute the "otherwise" condition.
condelse = ~np.any(condlist, axis=0, keepdims=True)
condlist = np.concatenate([condlist, condelse], axis=0)
n += 1
elif n != n2:
raise ValueError(
"with {} condition(s), either {} or {} functions are expected"
.format(n, n, n+1)
)

y = zeros(x.shape, x.dtype)
for k in range(n):
Expand All @@ -1352,8 +1351,7 @@ def piecewise(x, condlist, funclist, *args, **kw):
vals = x[condlist[k]]
if vals.size > 0:
y[condlist[k]] = item(vals, *args, **kw)
if zerod:
y = y.squeeze()

return y


Expand Down
18 changes: 17 additions & 1 deletion numpy/lib/tests/test_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,11 @@ def test_simple(self):
x = piecewise([0, 0], [[False, True]], [lambda x:-1])
assert_array_equal(x, [0, -1])

assert_raises_regex(ValueError, '1 or 2 functions are expected',
piecewise, [0, 0], [[False, True]], [])
assert_raises_regex(ValueError, '1 or 2 functions are expected',
piecewise, [0, 0], [[False, True]], [1, 2, 3])

def test_two_conditions(self):
x = piecewise([1, 2], [[True, False], [False, True]], [3, 4])
assert_array_equal(x, [3, 4])
Expand All @@ -2538,7 +2543,7 @@ def test_0d(self):
assert_(y == 0)

x = 5
y = piecewise(x, [[True], [False]], [1, 0])
Copy link
Member Author
@eric-wieser eric-wieser Oct 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This didn't make any sense, and should have failed. The conditions should be the same shape as the x. We got away with it because the 0d special-casing made the shapes match.

y = piecewise(x, [True, False], [1, 0])
assert_(y.ndim == 0)
assert_(y == 1)

Expand All @@ -2556,6 +2561,17 @@ def test_0d_comparison(self):
y = piecewise(x, [x <= 3, (x > 3) * (x <= 5), x > 5], [1, 2, 3])
assert_array_equal(y, 2)

assert_raises_regex(ValueError, '2 or 3 functions are expected',
piecewise, x, [x <= 3, x > 3], [1])
assert_raises_regex(ValueError, '2 or 3 functions are expected',
piecewise, x, [x <= 3, x > 3], [1, 1, 1, 1])

def test_0d_0d_condition(self):
x = np.array(3)
c = np.array(x > 3)
y = piecewise(x, [c], [1, 2])
assert_equal(y, 2)

def test_multidimensional_extrafunc(self):
x = np.array([[-2.5, -1.5, -0.5],
[0.5, 1.5, 2.5]])
Expand Down
0