Vectorized grouped (nan)quantile (#329) · xarray-contrib/flox@c966323

Commit c966323

Vectorized grouped (nan)quantile (#329)
* Vectorized grouped (nan)quantile
* Choose by default
* Add fallback logging
* Fix quantile when NaNs are present
* Support median, nanmedian
* micro-optimization
* Fix typing
* notnull optimization
* Some optimization
* Minor cleanup
* Add test for non-nan skipping agg with nans
* Fix test
1 parent 561378d commit c966323

File tree: 6 files changed (+168, −28 lines)

flox/aggregate_flox.py

Lines changed: 104 additions & 17 deletions
@@ -2,25 +2,103 @@

 import numpy as np

-from .xrutils import isnull
+from .xrutils import is_scalar, isnull, notnull


-def _prepare_for_flox(group_idx, array):
+def _prepare_for_flox(group_idx, array, lexsort):
     """
     Sort the input array once to save time.
     """
     assert array.shape[-1] == group_idx.shape[0]
-    issorted = (group_idx[:-1] <= group_idx[1:]).all()
-    if issorted:
-        ordered_array = array
+
+    if lexsort:
+        # lexsort allows us to sort by label AND array value
+        # numpy's quantile uses partition, which could be a big win
+        # IF we can figure out how to do that.
+        # This trick was snagged from scipy.ndimage.median() :)
+        labels_broadcast = np.broadcast_to(group_idx, array.shape)
+        idxs = np.lexsort((array, labels_broadcast), axis=-1)
+        ordered_array = np.take_along_axis(array, idxs, axis=-1)
+        group_idx = np.take_along_axis(group_idx, idxs[(0,) * (idxs.ndim - 1) + (...,)], axis=-1)
     else:
-        perm = group_idx.argsort(kind="stable")
-        group_idx = group_idx[..., perm]
-        ordered_array = array[..., perm]
+        issorted = (group_idx[:-1] <= group_idx[1:]).all()
+        if issorted:
+            ordered_array = array
+        else:
+            perm = group_idx.argsort(kind="stable")
+            group_idx = group_idx[..., perm]
+            ordered_array = array[..., perm]
     return group_idx, ordered_array


-def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None):
+def _lerp(a, b, *, t, dtype, out=None):
+    """
+    COPIED from numpy.
+
+    Compute the linear interpolation weighted by gamma on each point of
+    two arrays of the same shape.
+
+    a : array_like
+        Left bound.
+    b : array_like
+        Right bound.
+    t : array_like
+        The interpolation weight.
+    """
+    if out is None:
+        out = np.empty_like(a, dtype=dtype)
+    diff_b_a = np.subtract(b, a)
+    # asanyarray is a stop-gap until gh-13105
+    np.add(a, diff_b_a * t, out=out)
+    np.subtract(b, diff_b_a * (1 - t), out=out, where=t >= 0.5)
+    return out
+
+
+def quantile_(array, inv_idx, *, q, axis, skipna, dtype=None, out=None):
+    inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
+
+    if skipna:
+        sizes = np.add.reduceat(notnull(array), inv_idx[:-1], axis=axis)
+    else:
+        sizes = np.reshape(np.diff(inv_idx), (1,) * (array.ndim - 1) + (inv_idx.size - 1,))
+        nanmask = isnull(np.take_along_axis(array, sizes - 1, axis=axis))
+
+    qin = q
+    q = np.atleast_1d(qin)
+    q = np.reshape(q, (len(q),) + (1,) * array.ndim)
+
+    # This is numpy's method="linear"
+    # TODO: could support all the interpolations here
+    virtual_index = q * (sizes - 1) + inv_idx[:-1]
+
+    is_scalar_q = is_scalar(qin)
+    if is_scalar_q:
+        virtual_index = virtual_index.squeeze(axis=0)
+        idxshape = array.shape[:-1] + (sizes.shape[-1],)
+        a_ = array
+    else:
+        idxshape = (q.shape[0],) + array.shape[:-1] + (sizes.shape[-1],)
+        a_ = np.broadcast_to(array, (q.shape[0],) + array.shape)
+
+    # Broadcast to (num quantiles, ..., num labels)
+    lo_ = np.floor(virtual_index, casting="unsafe", out=np.empty(idxshape, dtype=np.int64))
+    hi_ = np.ceil(virtual_index, casting="unsafe", out=np.empty(idxshape, dtype=np.int64))
+
+    # get bounds
+    loval = np.take_along_axis(a_, lo_, axis=axis)
+    hival = np.take_along_axis(a_, hi_, axis=axis)
+
+    # TODO: could support all the interpolations here
+    gamma = np.broadcast_to(virtual_index, idxshape) - lo_
+    result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
+    if not skipna and np.any(nanmask):
+        result[..., nanmask] = np.nan
+    return result
+
+
+def _np_grouped_op(
+    group_idx, array, op, axis=-1, size=None, fill_value=None, dtype=None, out=None, **kwargs
+):
     """
     most of this code is from shoyer's gist
     https://gist.github.com/shoyer/f538ac78ae904c936844
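
The essence of the lexsort trick plus the index arithmetic in quantile_, as a minimal self-contained sketch (all names below are illustrative, not part of the diff): sort once by (label, value) so that every group becomes a sorted contiguous run; any quantile is then two gathers and one linear interpolation.

    import numpy as np

    values = np.array([9.0, 1.0, 5.0, 3.0, 7.0, 2.0])
    labels = np.array([1, 0, 1, 0, 1, 0])

    # the last key passed to lexsort is the primary sort key
    idxs = np.lexsort((values, labels))
    svals, slabs = values[idxs], labels[idxs]  # [1, 2, 3, 5, 7, 9], [0, 0, 0, 1, 1, 1]

    inv_idx = np.searchsorted(slabs, np.unique(slabs))        # start of each group's run
    sizes = np.diff(np.concatenate([inv_idx, [len(svals)]]))  # elements per group

    q = 0.5
    virtual_index = q * (sizes - 1) + inv_idx  # numpy's method="linear"
    lo = np.floor(virtual_index).astype(np.int64)
    hi = np.ceil(virtual_index).astype(np.int64)
    gamma = virtual_index - lo
    medians = svals[lo] * (1 - gamma) + svals[hi] * gamma
    # medians -> array([2., 7.]), matching np.quantile applied to each group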
@@ -38,16 +116,21 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
         dtype = array.dtype

     if out is None:
-        out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
+        q = kwargs.get("q", None)
+        if q is None:
+            out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)
+        else:
+            nq = len(np.atleast_1d(q))
+            out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype)

     if (len(uniques) == size) and (uniques == np.arange(size, like=array)).all():
         # The previous version of this if condition
         # ((uniques[1:] - uniques[:-1]) == 1).all():
         # does not work when group_idx is [1, 2], for example.
         # This happens during binning
-        op.reduceat(array, inv_idx, axis=axis, dtype=dtype, out=out)
+        op(array, inv_idx, axis=axis, dtype=dtype, out=out, **kwargs)
     else:
-        out[..., uniques] = op.reduceat(array, inv_idx, axis=axis, dtype=dtype)
+        out[..., uniques] = op(array, inv_idx, axis=axis, dtype=dtype, **kwargs)

     return out

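Note the refactor above: `op` is no longer a ufunc but any callable with a `ufunc.reduceat`-style signature, `op(array, inv_idx, axis=..., dtype=..., out=...)`, so `np.add.reduceat` and `partial(quantile_, ...)` share one call site while `q` rides along in `**kwargs`. The reduceat convention being mimicked, with toy numbers (not from the diff):

    import numpy as np

    array = np.array([1.0, 2.0, 3.0, 10.0, 20.0])
    inv_idx = np.array([0, 3])  # start offset of each group's run

    np.add.reduceat(array, inv_idx)  # -> array([ 6., 30.]): sums over [0:3] and [3:]
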
@@ -65,14 +148,18 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
     return result


-sum = partial(_np_grouped_op, op=np.add)
+sum = partial(_np_grouped_op, op=np.add.reduceat)
 nansum = partial(_nan_grouped_op, func=sum, fillna=0)
-prod = partial(_np_grouped_op, op=np.multiply)
+prod = partial(_np_grouped_op, op=np.multiply.reduceat)
 nanprod = partial(_nan_grouped_op, func=prod, fillna=1)
-max = partial(_np_grouped_op, op=np.maximum)
+max = partial(_np_grouped_op, op=np.maximum.reduceat)
 nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)
-min = partial(_np_grouped_op, op=np.minimum)
+min = partial(_np_grouped_op, op=np.minimum.reduceat)
 nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)
+quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
+nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
+median = partial(_np_grouped_op, op=partial(quantile_, q=0.5, skipna=False))
+nanmedian = partial(_np_grouped_op, op=partial(quantile_, q=0.5, skipna=True))
 # TODO: all, any

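A hypothetical end-to-end sketch of the new aggregations through flox's public API (values are illustrative, and the exact output shape depends on the surrounding machinery; treat this as a usage illustration, not a specification):

    import numpy as np
    from flox.core import groupby_reduce

    array = np.array([9.0, 1.0, 5.0, 3.0, 7.0, 2.0])
    by = np.array([1, 0, 1, 0, 1, 0])

    # grouped medians, one per label; expect [2., 7.] for labels [0, 1]
    result, groups = groupby_reduce(array, by, func="median", engine="flox")

    # grouped quantiles take q via finalize_kwargs
    result, groups = groupby_reduce(
        array, by, func="nanquantile", finalize_kwargs={"q": 0.9}, engine="flox"
    )
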
@@ -99,7 +186,7 @@ def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None,


 def nanlen(group_idx, array, *args, **kwargs):
-    return sum(group_idx, (~isnull(array)).astype(int), *args, **kwargs)
+    return sum(group_idx, (notnull(array)).astype(int), *args, **kwargs)


 def mean(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):

flox/aggregate_npg.py

Lines changed: 9 additions & 4 deletions
@@ -8,6 +8,11 @@ def _get_aggregate(engine):
     return npg.aggregate_numpy if engine == "numpy" else npg.aggregate_numba


+def _casting_wrapper(func, grp, dtype):
+    """Used for generic aggregates. The group is passed as dtype=object; cast back to the original dtype to avoid subtle bugs."""
+    return func(grp.astype(dtype))
+
+
 def sum_of_squares(
     group_idx,
     array,
@@ -106,7 +111,7 @@ def median(group_idx, array, engine, *, axis=-1, size=None, fill_value=None, dty
     return npg.aggregate_numpy.aggregate(
         group_idx,
         array,
-        func=np.median,
+        func=partial(_casting_wrapper, np.median, dtype=array.dtype),
         axis=axis,
         size=size,
         fill_value=fill_value,
@@ -118,7 +123,7 @@ def nanmedian(group_idx, array, engine, *, axis=-1, size=None, fill_value=None,
     return npg.aggregate_numpy.aggregate(
         group_idx,
         array,
-        func=np.nanmedian,
+        func=partial(_casting_wrapper, np.nanmedian, dtype=array.dtype),
         axis=axis,
         size=size,
         fill_value=fill_value,
@@ -130,7 +135,7 @@ def quantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=None
     return npg.aggregate_numpy.aggregate(
         group_idx,
         array,
-        func=partial(np.quantile, q=q),
+        func=partial(_casting_wrapper, partial(np.quantile, q=q), dtype=array.dtype),
         axis=axis,
         size=size,
         fill_value=fill_value,
@@ -142,7 +147,7 @@ def nanquantile(group_idx, array, engine, *, q, axis=-1, size=None, fill_value=N
     return npg.aggregate_numpy.aggregate(
         group_idx,
         array,
-        func=partial(np.nanquantile, q=q),
+        func=partial(_casting_wrapper, partial(np.nanquantile, q=q), dtype=array.dtype),
         axis=axis,
         size=size,
         fill_value=fill_value,

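Why `_casting_wrapper` is needed, sketched standalone: numpy_groupies' generic-aggregate path hands each group to `func` as an object-dtype array, and NaN-aware numpy reductions do not accept object dtype.

    import numpy as np

    grp = np.array([1.0, np.nan, 3.0], dtype=object)
    # np.nanmedian(grp) raises TypeError: ufunc 'isnan' not supported for object arrays
    np.nanmedian(grp.astype(np.float64))  # -> 2.0
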
flox/aggregations.py

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import copy
+import logging
 import warnings
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict
@@ -16,6 +17,9 @@
 OptionalFuncTuple = tuple[Callable | str | None, ...]


+logger = logging.getLogger("flox")
+
+
 def _is_arg_reduction(func: str | Aggregation) -> bool:
     if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
         return True
@@ -62,6 +66,7 @@ def generic_aggregate(
         try:
             method = getattr(aggregate_flox, func)
         except AttributeError:
+            logger.debug(f"Couldn't find {func} for engine='flox'. Falling back to numpy")
             method = get_npg_aggregation(func, engine="numpy")

     elif engine == "numbagg":
@@ -78,6 +83,7 @@ def generic_aggregate(
             else:
                 method = getattr(aggregate_numbagg, func)
         except AttributeError:
+            logger.debug(f"Couldn't find {func} for engine='numbagg'. Falling back to numpy")
             method = get_npg_aggregation(func, engine="numpy")

     elif engine in ["numpy", "numba"]:

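The fallback messages above go to the stdlib logger named "flox"; a sketch of how a user would surface them:

    import logging

    logging.basicConfig()
    logging.getLogger("flox").setLevel(logging.DEBUG)
    # reductions that fall back now log, e.g.
    # "Couldn't find <func> for engine='flox'. Falling back to numpy"
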
flox/core.py

Lines changed: 17 additions & 4 deletions
@@ -36,7 +36,13 @@
     generic_aggregate,
 )
 from .cache import memoize
-from .xrutils import is_duck_array, is_duck_dask_array, isnull, module_available
+from .xrutils import (
+    is_duck_array,
+    is_duck_dask_array,
+    isnull,
+    module_available,
+    notnull,
+)

 if module_available("numpy", minversion="2.0.0"):
     from numpy.lib.array_utils import (  # type: ignore[import-not-found]
@@ -46,6 +52,7 @@
     from numpy.core.numeric import normalize_axis_tuple  # type: ignore[attr-defined]

 HAS_NUMBAGG = module_available("numbagg", minversion="0.3.0")
+_LEXSORT_FOR_FLOX = ["quantile", "nanquantile", "median", "nanmedian"]

 if TYPE_CHECKING:
     try:
@@ -156,7 +163,7 @@ def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
     if is_duck_dask_array(by):
         raise ValueError("Please provide expected_groups if not grouping by a numpy array.")
     flatby = by.reshape(-1)
-    expected = pd.unique(flatby[~isnull(flatby)])
+    expected = pd.unique(flatby[notnull(flatby)])
     return _convert_expected_groups_to_index((expected,), isbin=(False,), sort=sort)[0]


@@ -953,7 +960,9 @@ def chunk_reduce(
     if engine == "flox":
         # is_arg_reduction = any("arg" in f for f in func if isinstance(f, str))
         # if not is_arg_reduction:
-        group_idx, array = _prepare_for_flox(group_idx, array)
+        group_idx, array = _prepare_for_flox(
+            group_idx, array, lexsort=any(f in _LEXSORT_FOR_FLOX for f in funcs)
+        )

     final_array_shape += results["groups"].shape
     final_groups_shape += results["groups"].shape
@@ -1095,7 +1104,7 @@ def _find_unique_groups(x_chunk) -> np.ndarray:
     from dask.utils import deepmap

     unique_groups = _unique(np.asarray(tuple(flatten(deepmap(listify_groups, x_chunk)))))
-    unique_groups = unique_groups[~isnull(unique_groups)]
+    unique_groups = unique_groups[notnull(unique_groups)]

     if len(unique_groups) == 0:
         unique_groups = np.array([np.nan])
@@ -1959,6 +1968,10 @@ def _choose_engine(by, agg: Aggregation):

     not_arg_reduce = not _is_arg_reduction(agg)

+    if agg.name in _LEXSORT_FOR_FLOX:
+        logger.info(f"_choose_engine: Choosing 'flox' since {agg.name}")
+        return "flox"
+
     # numbagg only supports nan-skipping reductions
     # without dtype specified
     has_blockwise_nan_skipping = (agg.chunk[0] is None and "nan" in agg.name) or any(

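Sketch of the `_choose_engine` change in action: when the caller does not pass `engine=`, the quantile family is now always routed to the flox engine, since only that path prepares the lexsorted layout these aggregations need. Toy values:

    import numpy as np
    from flox.core import groupby_reduce

    array = np.arange(6, dtype=np.float64)
    by = np.array([0, 0, 0, 1, 1, 1])

    # no engine= given: _choose_engine short-circuits to "flox" for nanmedian
    result, groups = groupby_reduce(array, by, func="nanmedian")
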
flox/xrutils.py

Lines changed: 14 additions & 0 deletions
@@ -100,6 +100,20 @@ def is_scalar(value: Any, include_0d: bool = True) -> bool:
     )


+def notnull(data):
+    if not is_duck_array(data):
+        data = np.asarray(data)
+
+    scalar_type = data.dtype.type
+    if issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
+        # these types cannot represent missing values
+        return np.ones_like(data, dtype=bool)
+    else:
+        out = isnull(data)
+        np.logical_not(out, out=out)
+        return out
+
+
 def isnull(data):
     if not is_duck_array(data):
         data = np.asarray(data)

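What the dtype shortcut in notnull buys, sketched: for dtypes that cannot represent missing values, the elementwise check is skipped entirely.

    import numpy as np
    from flox.xrutils import notnull

    notnull(np.array([1, 2, 3]))           # integer dtype -> array([ True,  True,  True]), no NaN scan
    notnull(np.array([1.0, np.nan, 3.0]))  # float dtype   -> array([ True, False,  True])
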
tests/test_core.py

Lines changed: 18 additions & 3 deletions
@@ -261,6 +261,9 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         fill_value = None
         tolerance = None

+    # for constructing expected
+    array_func = _get_array_func(func)
+
     for kwargs in finalize_kwargs:
         flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value)
         with np.errstate(invalid="ignore", divide="ignore"):
@@ -280,7 +283,6 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 array_[..., nanmask] = np.nan
                 expected = getattr(np, func_)(array_, axis=-1, **kwargs)
             else:
-                array_func = _get_array_func(func)
                 expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs)
             for _ in range(nby):
                 expected = np.expand_dims(expected, -1)
@@ -290,15 +292,28 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
                 flox_kwargs["method"] = "blockwise"

             actual, *groups = groupby_reduce(array, *by, **flox_kwargs)
-            assert actual.ndim == (array.ndim + nby - 1)
-            assert expected.ndim == (array.ndim + nby - 1)
+            assert actual.ndim == expected.ndim == (array.ndim + nby - 1)
             expected_groups = tuple(np.array([idx + 1.0]) for idx in range(nby))
             for actual_group, expect in zip(groups, expected_groups):
                 assert_equal(actual_group, expect)
             if "arg" in func:
                 assert actual.dtype.kind == "i"
             assert_equal(expected, actual, tolerance)

+            if "nan" not in func and "arg" not in func:
+                # test non-NaN skipping behaviour when NaNs are present
+                nanned = array_.copy()
+                # remove nans in by to reduce complexity
+                # We are checking for consistent behaviour with NaNs in array
+                by_ = tuple(np.nan_to_num(b, nan=np.nanmin(b)) for b in by)
+                nanned[[1, 4, 5], ...] = np.nan
+                nanned.reshape(-1)[0] = np.nan
+                actual, *_ = groupby_reduce(nanned, *by_, **flox_kwargs)
+                expected_0 = array_func(nanned, axis=-1, **kwargs)
+                for _ in range(nby):
+                    expected_0 = np.expand_dims(expected_0, -1)
+                assert_equal(expected_0, actual, tolerance)
+
             if not has_dask or chunks is None or func in BLOCKWISE_FUNCS:
                 continue

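The semantics the new test block pins down, in plain numpy: non-NaN-skipping reductions must propagate NaN, while their nan-prefixed counterparts skip it.

    import numpy as np

    np.median(np.array([1.0, np.nan, 3.0]))     # -> nan (NaN propagates)
    np.nanmedian(np.array([1.0, np.nan, 3.0]))  # -> 2.0 (NaN skipped)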