8000 Fixed numpy.piecewise() and tests · numpy/numpy@e2ad1a7 · GitHub
[go: up one dir, main page]

Skip to content

Commit e2ad1a7

Browse files
committed
Fixed numpy.piecewise() and tests
Corrected tests for numpy.piecewise(), made them more systematic, and altered numpy.piecewise fix to hit all test cases.
1 parent e553e1b commit e2ad1a7

File tree

2 files changed

+134
-31
lines changed

2 files changed

+134
-31
lines changed

numpy/lib/function_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,9 @@ def piecewise(x, condlist, funclist, *args, **kw):
673673
"""
674674
x = asanyarray(x)
675675
n2 = len(funclist)
676-
if isscalar(condlist):
676+
if isscalar(condlist) or \
677+
(isinstance(condlist, np.ndarray) and condlist.ndim == 0) or \
678+
(x.ndim > 0 and condlist[0].ndim == 0):
677679
condlist = [condlist]
678680
condlist = [asarray(c, dtype=bool) for c in condlist]
679681
n = len(condlist)

numpy/lib/tests/test_function_base.py

Lines changed: 131 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,49 +1179,150 @@ def test_simple(self):
11791179

11801180

11811181
class TestPiecewise(TestCase):
1182+
def test_0d(self):
1183+
# Input: scalar
1184+
x = 5
1185+
1186+
# Condition: scalar bool
1187+
y = piecewise(x, x < 7, [1])
1188+
assert(y.ndim == 0)
1189+
assert(y == 1)
1190+
1191+
# Condition: singleton list of scalar bool
1192+
y = piecewise(x, [x < 7], [1])
1193+
assert(y == 1)
1194+
1195+
# Condition: 0-d array of bool
1196+
y = piecewise(x, np.array(x < 7), [1])
1197+
assert(y == 1)
1198+
1199+
# Condition: 1-d array of bool
1200+
y = piecewise(x, np.array([x < 7]), [1])
1201+
assert(y == 1)
1202+
1203+
# Condition: singleton list of 0-d array of bool
1204+
y = piecewise(x, [np.array(x < 7)], [1])
1205+
assert(y == 1)
1206+
1207+
# Condition: singleton list of 1-d array of bool
1208+
y = piecewise(x, [np.array([x < 7])], [1])
1209+
assert(y == 1)
1210+
1211+
# Condition: scalar int
1212+
y = piecewise(x, 1, [1])
1213+
assert(y == 1)
1214+
1215+
# Condition: singleton list of int
1216+
y = piecewise(x, [1], [1])
1217+
assert(y == 1)
1218+
1219+
# Condition: 1-d list of bools
1220+
y = piecewise(x, [x < 7, x >= 7], [1, 2])
1221+
assert(y == 1)
1222+
1223+
# Condition: 1-d list of bools (test alternative)
1224+
y = piecewise(x, [x >= 7, x < 7], [1, 2])
1225+
assert(y == 2)
1226+
1227+
# Condition: 1-d list of 0-d arrays of bools
1228+
y = piecewise(x, [np.array(x < 7), np.array(x >= 7)], [1, 2])
1229+
assert(y == 1)
1230+
1231+
# Input: 0-d array
1232+
x = np.array(5)
1233+
1234+
y = piecewise(x, x < 7, [1])
1235+
assert(y.ndim == 0)
1236+
assert(y == 1)
1237+
11821238
def test_simple(self):
1183-
# Condition is single bool list
1184-
x = piecewise([0, 0], [True, False], [1])
1185-
assert_array_equal(x, [1, 0])
1239+
# Input: 1-d array
1240+
x = np.array([3,5])
1241+
1242+
# Condition: bare array of bool
1243+
y = piecewise(x, x < 7, [1])
1244+
assert_array_equal(y, [1, 1])
11861245

1187-
# List of conditions: single bool list
1188-
x = piecewise([0, 0], [[True, False]], [1])
1189-
assert_array_equal(x, [1, 0])
1246+
# Make sure callables are called
1247+
y = piecewise(x, x < 7, [(lambda x: -x)])
1248+
assert_array_equal(y, [-3, -5])
11901249

1191-
# Conditions is single bool array
1192-
x = piecewise([0, 0], np.array([True, False]), [1])
1193-
assert_array_equal(x, [1, 0])
1250+
# Condition: singleton list of array of bool
1251+
y = piecewise(x, [ A3DB x < 7], [1])
1252+
assert_array_equal(y, [1, 1])
11941253

1195-
# Condition is single int array
1196-
x = piecewise([0, 0], np.array([1, 0]), [1])
1197-
assert_array_equal(x, [1, 0])
1254+
# Condition: (1,2) array of bool
1255+
y = piecewise(x, np.array([x < 7]), [1])
1256+
assert_array_equal(y, [1, 1])
11981257

1199-
# List of conditions: int array
1200-
x = piecewise([0, 0], [np.array([1, 0])], [1])
1201-
assert_array_equal(x, [1, 0])
1258+
# Condition: list of array of bool
1259+
y = piecewise(x, [x >= 4, x < 4], [1, 2])
1260+
assert_array_equal(y, [2, 1])
12021261

1262+
y = piecewise(x, [x > 7, x <= 7], [1, 2])
1263+
assert_array_equal(y, [2, 2])
12031264

1204-
x = piecewise([0, 0], [[False, True]], [lambda x:-1])
1205-
assert_array_equal(x, [0, -1])
1265+
y = piecewise(x, [x < 4, x >= 4], [1, 2])
1266+
assert_array_equal(y, [1, 2])
12061267

1207-
x = piecewise([1, 2], [[True, False], [False, True]], [3, 4])
1208-
assert_array_equal(x, [3, 4])
1268+
y = piecewise(x, np.array([x < 4, x >= 4]), [1, 2])
1269+
assert_array_equal(y, [1, 2])
12091270

12101271
def test_default(self):
1211-
# No value specified for x[1], should be 0
1212-
x = piecewise([1, 2], [True, False], [2])
1213-
assert_array_equal(x, [2, 0])
1272+
# Input: scalar
1273+
x = 5
12141274

1215-
# Should set x[1] to 3
1216-
x = piecewise([1, 2], [True, False], [2, 3])
1217-
assert_array_equal(x, [2, 3])
1275+
# built-in no-match: 0
12181276

1219-
def test_0d(self):
1220-
x = np.array(3)
1221-
y = piecewise(x, x > 3, [4, 0])
1222-
assert_(y.ndim == 0)
1223-
assert_(y == 0)
1277+
# Condition: scalar bool
1278+
y = piecewise(x, x > 7, [1])
1279+
assert_array_equal(y, 0)
1280+
1281+
# Condition: scalar int
1282+
y = piecewise(x, 0, [1])
1283+
assert_array_equal(y, 0)
1284+
1285+
# custom no-match
1286+
1287+
y = piecewise(x, x < 7, [1, 2])
1288+
assert_array_equal(y, [1])
1289+
1290+
y = piecewise(x, x > 7, [1, 2])
1291+
assert_array_equal(y, [2])
1292+
1293+
# Condition: scalar int
1294+
y = piecewise(x, 0, [1, 2])
1295+
assert_array_equal(y, [2])
1296+
1297+
# Input: 1-d array
1298+
x = np.array([3,5])
1299+
1300+
# built-in no-match: 0
1301+
1302+
y = piecewise(x, x > 7, [1])
1303+
assert_array_equal(y, [0,0])
1304+
1305+
y = piecewise(x, x < 4, [1])
1306+
assert_array_equal(y, [1, 0])
1307+
1308+
# custom no-match
1309+
1310+
y = piecewise(x, x < 7, [1, 2])
1311+
assert_array_equal(y, [1, 1])
1312+
1313+
y = piecewise(x, x > 7, [1, 2])
1314+
assert_array_equal(y, [2, 2])
1315+
1316+
y = piecewise(x, x < 4, [1, 2])
1317+
assert_array_equal(y, [1, 2])
1318+
1319+
# Condition: list of array of bool
1320+
y = piecewise(x, [x < 4], [1, 2])
1321+
assert_array_equal(y, [1, 2])
12241322

1323+
# Condition: (1,2) array of bool
1324+
y = piecewise(x, np.array([x < 4]), [1, 2])
1325+
assert_array_equal(y, [1, 2])
12251326

12261327
class TestBincount(TestCase):
12271328
def test_simple(self):

0 commit comments

Comments
 (0)
0