-
-
Notifications
You must be signed in to change notification settings - Fork 11k
MAINT/BUG: Remove special-casing for 0d arrays, now that indexing with a single boolean is ok #9900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4712875
4729550
303941c
c875b13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1254,12 +1254,12 @@ def piecewise(x, condlist, funclist, *args, **kw): | |
|
||
The length of `condlist` must correspond to that of `funclist`. | ||
If one extra function is given, i.e. if | ||
``len(funclist) - len(condlist) == 1``, then that extra function | ||
``len(funclist) == len(condlist) + 1``, then that extra function | ||
is the default value, used wherever all conditions are false. | ||
funclist : list of callables, f(x,*args,**kw), or scalars | ||
Each function is evaluated over `x` wherever its corresponding | ||
condition is True. It should take an array as input and give an array | ||
or a scalar value as output. If, instead of a callable, | ||
condition is True. It should take a 1d array as input and give an 1d | ||
array or a scalar value as output. If, instead of a callable, | ||
a scalar is provided then a constant function (``lambda x: scalar``) is | ||
assumed. | ||
args : tuple, optional | ||
|
@@ -1323,25 +1323,24 @@ def piecewise(x, condlist, funclist, *args, **kw): | |
""" | ||
x = asanyarray(x) | ||
n2 = len(funclist) | ||
if (isscalar(condlist) or not (isinstance(condlist[0], list) or | ||
isinstance(condlist[0], ndarray))): | ||
if not isscalar(condlist) and x.size == 1 and x.ndim == 0: | ||
condlist = [[c] for c in condlist] | ||
else: | ||
condlist = [condlist] | ||
|
||
# undocumented: single condition is promoted to a list of one condition | ||
if isscalar(condlist) or ( | ||
not isinstance(condlist[0], (list, ndarray)) and x.ndim != 0): | ||
condlist = [condlist] | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So, was the idea to deprecate this? An alternative might be to just add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd like to deprecate this in future, but this patch is orthogonal to that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
condlist = array(condlist, dtype=bool) | ||
n = len(condlist) | ||
# This is a hack to work around problems with NumPy's | ||
# handling of 0-d arrays and boolean indexing with | ||
# numpy.bool_ scalars | ||
zerod = False | ||
if x.ndim == 0: | ||
x = x[None] | ||
zerod = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this worked around |
||
|
||
if n == n2 - 1: # compute the "otherwise" condition. | ||
condelse = ~np.any(condlist, axis=0, keepdims=True) | ||
condlist = np.concatenate([condlist, condelse], axis=0) | ||
n += 1 | ||
elif n != n2: | ||
raise ValueError( | ||
"with {} condition(s), either {} or {} functions are expected" | ||
.format(n, n, n+1) | ||
) | ||
|
||
y = zeros(x.shape, x.dtype) | ||
for k in range(n): | ||
|
@@ -1352,8 +1351,7 @@ def piecewise(x, condlist, funclist, *args, **kw): | |
vals = x[condlist[k]] | ||
if vals.size > 0: | ||
y[condlist[k]] = item(vals, *args, **kw) | ||
if zerod: | ||
y = y.squeeze() | ||
|
||
return y | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2514,6 +2514,11 @@ def test_simple(self): | |
x = piecewise([0, 0], [[False, True]], [lambda x:-1]) | ||
assert_array_equal(x, [0, -1]) | ||
|
||
assert_raises_regex(ValueError, '1 or 2 functions are expected', | ||
piecewise, [0, 0], [[False, True]], []) | ||
assert_raises_regex(ValueError, '1 or 2 functions are expected', | ||
piecewise, [0, 0], [[False, True]], [1, 2, 3]) | ||
|
||
def test_two_conditions(self): | ||
x = piecewise([1, 2], [[True, False], [False, True]], [3, 4]) | ||
assert_array_equal(x, [3, 4]) | ||
|
@@ -2538,7 +2543,7 @@ def test_0d(self): | |
assert_(y == 0) | ||
|
||
x = 5 | ||
y = piecewise(x, [[True], [False]], [1, 0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This didn't make any sense, and should have failed. The conditions should be the same shape as the |
||
y = piecewise(x, [True, False], [1, 0]) | ||
assert_(y.ndim == 0) | ||
assert_(y == 1) | ||
|
||
|
@@ -2556,6 +2561,17 @@ def test_0d_comparison(self): | |
y = piecewise(x, [x <= 3, (x > 3) * (x <= 5), x > 5], [1, 2, 3]) | ||
assert_array_equal(y, 2) | ||
|
||
assert_raises_regex(ValueError, '2 or 3 functions are expected', | ||
piecewise, x, [x <= 3, x > 3], [1]) | ||
assert_raises_regex(ValueError, '2 or 3 functions are expected', | ||
piecewise, x, [x <= 3, x > 3], [1, 1, 1, 1]) | ||
|
||
def test_0d_0d_condition(self): | ||
x = np.array(3) | ||
c = np.array(x > 3) | ||
y = piecewise(x, [c], [1, 2]) | ||
assert_equal(y, 2) | ||
|
||
def test_multidimensional_extrafunc(self): | ||
x = np.array([[-2.5, -1.5, -0.5], | ||
[0.5, 1.5, 2.5]]) | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line inserts an extra dim into cond to match what we to did to 0d x below - so is no longer needed.
Note that this would misfire on non-0d x sometimes - which is why a test below needed changing
Essentially, this line is replaced with
pass
, and then the nestedif
s are collapsed