@@ -1179,49 +1179,150 @@ def test_simple(self):
1179
1179
1180
1180
1181
1181
class TestPiecewise (TestCase ):
1182
+ def test_0d (self ):
1183
+ # Input: scalar
1184
+ x = 5
1185
+
1186
+ # Condition: scalar bool
1187
+ y = piecewise (x , x < 7 , [1 ])
1188
+ assert (y .ndim == 0 )
1189
+ assert (y == 1 )
1190
+
1191
+ # Condition: singleton list of scalar bool
1192
+ y = piecewise (x , [x < 7 ], [1 ])
1193
+ assert (y == 1 )
1194
+
1195
+ # Condition: 0-d array of bool
1196
+ y = piecewise (x , np .array (x < 7 ), [1 ])
1197
+ assert (y == 1 )
1198
+
1199
+ # Condition: 1-d array of bool
1200
+ y = piecewise (x , np .array ([x < 7 ]), [1 ])
1201
+ assert (y == 1 )
1202
+
1203
+ # Condition: singleton list of 0-d array of bool
1204
+ y = piecewise (x , [np .array (x < 7 )], [1 ])
1205
+ assert (y == 1 )
1206
+
1207
+ # Condition: singleton list of 1-d array of bool
1208
+ y = piecewise (x , [np .array ([x < 7 ])], [1 ])
1209
+ assert (y == 1 )
1210
+
1211
+ # Condition: scalar int
1212
+ y = piecewise (x , 1 , [1 ])
1213
+ assert (y == 1 )
1214
+
1215
+ # Condition: singleton list of int
1216
+ y = piecewise (x , [1 ], [1 ])
1217
+ assert (y == 1 )
1218
+
1219
+ # Condition: 1-d list of bools
1220
+ y = piecewise (x , [x < 7 , x >= 7 ], [1 , 2 ])
1221
+ assert (y == 1 )
1222
+
1223
+ # Condition: 1-d list of bools (test alternative)
1224
+ y = piecewise (x , [x >= 7 , x < 7 ], [1 , 2 ])
1225
+ assert (y == 2 )
1226
+
1227
+ # Condition: 1-d list of 0-d arrays of bools
1228
+ y = piecewise (x , [np .array (x < 7 ), np .array (x >= 7 )], [1 , 2 ])
1229
+ assert (y == 1 )
1230
+
1231
+ # Input: 0-d array
1232
+ x = np .array (5 )
1233
+
1234
+ y = piecewise (x , x < 7 , [1 ])
1235
+ assert (y .ndim == 0 )
1236
+ assert (y == 1 )
1237
+
1182
1238
def test_simple (self ):
1183
- # Condition is single bool list
1184
- x = piecewise ([0 , 0 ], [True , False ], [1 ])
1185
- assert_array_equal (x , [1 , 0 ])
1239
+ # Input: 1-d array
1240
+ x = np .array ([3 ,5 ])
1241
+
1242
+ # Condition: bare array of bool
1243
+ y = piecewise (x , x < 7 , [1 ])
1244
+ assert_array_equal (y , [1 , 1 ])
1186
1245
1187
- # List of conditions: single bool list
1188
- x = piecewise ([ 0 , 0 ], [[ True , False ]] , [1 ])
1189
- assert_array_equal (x , [1 , 0 ])
1246
+ # Make sure callables are called
1247
+ y = piecewise (x , x < 7 , [( lambda x : - x ) ])
1248
+ assert_array_equal (y , [- 3 , - 5 ])
1190
1249
1191
- # Conditions is single bool array
1192
- x = piecewise ([ 0 , 0 ], np . array ([ True , False ]) , [1 ])
1193
- assert_array_equal (x , [1 , 0 ])
1250
+ # Condition: singleton list of array of bool
1251
+ y = piecewise (x , [
A3DB
x < 7 ] , [1 ])
1252
+ assert_array_equal (y , [1 , 1 ])
1194
1253
1195
- # Condition is single int array
1196
- x = piecewise ([ 0 , 0 ], np .array ([1 , 0 ]), [1 ])
1197
- assert_array_equal (x , [1 , 0 ])
1254
+ # Condition: (1,2) array of bool
1255
+ y = piecewise (x , np .array ([x < 7 ]), [1 ])
1256
+ assert_array_equal (y , [1 , 1 ])
1198
1257
1199
- # List of conditions: int array
1200
- x = piecewise ([ 0 , 0 ], [ np . array ([ 1 , 0 ])] , [1 ])
1201
- assert_array_equal (x , [1 , 0 ])
1258
+ # Condition: list of array of bool
1259
+ y = piecewise (x , [ x >= 4 , x < 4 ] , [1 , 2 ])
1260
+ assert_array_equal (y , [2 , 1 ])
1202
1261
1262
+ y = piecewise (x , [x > 7 , x <= 7 ], [1 , 2 ])
1263
+ assert_array_equal (y , [2 , 2 ])
1203
1264
1204
- x = piecewise ([ 0 , 0 ], [[ False , True ]] , [lambda x : - 1 ])
1205
- assert_array_equal (x , [0 , - 1 ])
1265
+ y = piecewise (x , [ x < 4 , x >= 4 ] , [1 , 2 ])
1266
+ assert_array_equal (y , [1 , 2 ])
1206
1267
1207
- x = piecewise ([ 1 , 2 ], [[ True , False ], [ False , True ]] , [3 , 4 ])
1208
- assert_array_equal (x , [3 , 4 ])
1268
+ y = piecewise (x , np . array ([ x < 4 , x >= 4 ]) , [1 , 2 ])
1269
+ assert_array_equal (y , [1 , 2 ])
1209
1270
1210
1271
def test_default (self ):
1211
- # No value specified for x[1], should be 0
1212
- x = piecewise ([1 , 2 ], [True , False ], [2 ])
1213
- assert_array_equal (x , [2 , 0 ])
1272
+ # Input: scalar
1273
+ x = 5
1214
1274
1215
- # Should set x[1] to 3
1216
- x = piecewise ([1 , 2 ], [True , False ], [2 , 3 ])
1217
- assert_array_equal (x , [2 , 3 ])
1275
+ # built-in no-match: 0
1218
1276
1219
- def test_0d (self ):
1220
- x = np .array (3 )
1221
- y = piecewise (x , x > 3 , [4 , 0 ])
1222
- assert_ (y .ndim == 0 )
1223
- assert_ (y == 0 )
1277
+ # Condition: scalar bool
1278
+ y = piecewise (x , x > 7 , [1 ])
1279
+ assert_array_equal (y , 0 )
1280
+
1281
+ # Condition: scalar int
1282
+ y = piecewise (x , 0 , [1 ])
1283
+ assert_array_equal (y , 0 )
1284
+
1285
+ # custom no-match
1286
+
1287
+ y = piecewise (x , x < 7 , [1 , 2 ])
1288
+ assert_array_equal (y , [1 ])
1289
+
1290
+ y = piecewise (x , x > 7 , [1 , 2 ])
1291
+ assert_array_equal (y , [2 ])
1292
+
1293
+ # Condition: scalar int
1294
+ y = piecewise (x , 0 , [1 , 2 ])
1295
+ assert_array_equal (y , [2 ])
1296
+
1297
+ # Input: 1-d array
1298
+ x = np .array ([3 ,5 ])
1299
+
1300
+ # built-in no-match: 0
1301
+
1302
+ y = piecewise (x , x > 7 , [1 ])
1303
+ assert_array_equal (y , [0 ,0 ])
1304
+
1305
+ y = piecewise (x , x < 4 , [1 ])
1306
+ assert_array_equal (y , [1 , 0 ])
1307
+
1308
+ # custom no-match
1309
+
1310
+ y = piecewise (x , x < 7 , [1 , 2 ])
1311
+ assert_array_equal (y , [1 , 1 ])
1312
+
1313
+ y = piecewise (x , x > 7 , [1 , 2 ])
1314
+ assert_array_equal (y , [2 , 2 ])
1315
+
1316
+ y = piecewise (x , x < 4 , [1 , 2 ])
1317
+ assert_array_equal (y , [1 , 2 ])
1318
+
1319
+ # Condition: list of array of bool
1320
+ y = piecewise (x , [x < 4 ], [1 , 2 ])
1321
+ assert_array_equal (y , [1 , 2 ])
1224
1322
1323
+ # Condition: (1,2) array of bool
1324
+ y = piecewise (x , np .array ([x < 4 ]), [1 , 2 ])
1325
+ assert_array_equal (y , [1 , 2 ])
1225
1326
1226
1327
class TestBincount (TestCase ):
1227
1328
def test_simple (self ):
0 commit comments