@@ -261,6 +261,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
261
261
fill_value = None
262
262
tolerance = None
263
263
264
+ # for constructing expected
265
+ array_func = _get_array_func (func )
266
+
264
267
for kwargs in finalize_kwargs :
265
268
flox_kwargs = dict (func = func , engine = engine , finalize_kwargs = kwargs , fill_value = fill_value )
266
269
with np .errstate (invalid = "ignore" , divide = "ignore" ):
@@ -280,7 +283,6 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
280
283
array_ [..., nanmask ] = np .nan
281
284
expected = getattr (np , func_ )(array_ , axis = - 1 , ** kwargs )
282
285
else :
283
- array_func = _get_array_func (func )
284
286
expected = array_func (array_ [..., ~ nanmask ], axis = - 1 , ** kwargs )
285
287
for _ in range (nby ):
286
288
expected = np .expand_dims (expected , - 1 )
@@ -290,15 +292,28 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
290
292
flox_kwargs ["method" ] = "blockwise"
291
293
292
294
actual , * groups = groupby_reduce (array , * by , ** flox_kwargs )
293
- assert actual .ndim == (array .ndim + nby - 1 )
294
- assert expected .ndim == (array .ndim + nby - 1 )
295
+ assert actual .ndim == expected .ndim == (array .ndim + nby - 1 )
295
296
expected_groups = tuple (np .array ([idx + 1.0 ]) for idx in range (nby ))
296
297
for actual_group , expect in zip (groups , expected_groups ):
297
298
assert_equal (actual_group , expect )
298
299
if "arg" in func :
299
300
assert actual .dtype .kind == "i"
300
301
assert_equal (expected , actual , tolerance )
301
302
303
+ if "nan" not in func and "arg" not in func :
304
+ # test non-NaN skipping behaviour when NaNs are present
305
+ nanned = array_ .copy ()
306
+ # remove nans in by to reduce complexity
307
+ # We are checking for consistent behaviour with NaNs in array
308
+ by_ = tuple (np .nan_to_num (b , nan = np .nanmin (b )) for b in by )
309
+ nanned [[1 , 4 , 5 ], ...] = np .nan
310
+ nanned .reshape (- 1 )[0 ] = np .nan
311
+ actual , * _ = groupby_reduce (nanned , * by_ , ** flox_kwargs )
312
+ expected = array_func (nanned , axis = - 1 , ** kwargs )
313
+ for _ in range (nby ):
314
+ expected = np .expand_dims (expected , - 1 )
315
+ assert_equal (expected , actual , tolerance )
316
+
302
317
if not has_dask or chunks is None or func in BLOCKWISE_FUNCS :
303
318
continue
304
319
0 commit comments