8000 Add test for non-nan skipping agg with nans · xarray-contrib/flox@5cca384 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5cca384

Browse files
committed
Add test for non-nan skipping agg with nans
1 parent 44858c5 commit 5cca384

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

tests/test_core.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
261261
fill_value = None
262262
tolerance = None
263263

264+
# for constructing expected
265+
array_func = _get_array_func(func)
266+
264267
for kwargs in finalize_kwargs:
265268
flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
266269
with np.errstate(invalid="ignore", divide="ignore"):
@@ -280,7 +283,6 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
280283
array_[..., nanmask] = np.nan
281284
expected = getattr(np, func_)(array_, axis=-1, **kwargs)
282285
else:
283-
array_func = _get_array_func(func)
284286
expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs)
285287
for _ in range(nby):
286288
expected = np.expand_dims(expected, -1)
@@ -290,15 +292,28 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
290292
flox_kwargs["method"] = "blockwise"
291293

292294
actual, *groups = groupby_reduce(array, *by, **flox_kwargs)
293-
assert actual.ndim == (array.ndim + nby - 1)
294-
assert expected.ndim == (array.ndim + nby - 1)
295+
assert actual.ndim == expected.ndim == (array.ndim + nby - 1)
295296
expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby))
296297
for actual_group, expect in zip(groups, expected_groups):
297298
assert_equal(actual_group, expect)
298299
if "arg" in func:
299300
assert actual.dtype.kind == "i"
300301
assert_equal(expected, actual, tolerance)
301302

303+
if "nan" not in func and "arg" not in func:
304+
# test non-NaN skipping behaviour when NaNs are present
305+
nanned = array_.copy()
306+
# remove nans in by to reduce complexity
307+
# We are checking for consistent behaviour with NaNs in array
308+
by_ = tuple(np.nan_to_num(b, nan=np.nanmin(b)) for b in by)
309+
nanned[[1, 4, 5], ...] = np.nan
310+
nanned.reshape(-1)[0] = np.nan
311+
actual, *_ = groupby_reduce(nanned, *by_, **flox_kwargs)
312+
expected = array_func(nanned, axis=-1, **kwargs)
313+
for _ in range(nby):
314+
expected = np.expand_dims(expected, -1)
315+
assert_equal(expected, actual, tolerance)
316+
302317
if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
303318
continue
304319

0 commit comments

Comments
 (0)
0