2
2
3
3
import numpy as np
4
4
5
- from .xrutils import isnull
5
+ from .xrutils import is_scalar , isnull , notnull
6
6
7
7
8
- def _prepare_for_flox (group_idx , array ):
8
+ def _prepare_for_flox (group_idx , array , lexsort ):
9
9
"""
10
10
Sort the input array once to save time.
11
11
"""
12
12
assert array .shape [- 1 ] == group_idx .shape [0 ]
13
- issorted = (group_idx [:- 1 ] <= group_idx [1 :]).all ()
14
- if issorted :
15
- ordered_array = array
13
+
14
+ if lexsort :
15
+ # lexsort allows us to sort by label AND array value
16
+ # numpy's quantile uses partition, which could be a big win
17
+ # IF we can figure out how to do that.
18
+ # This trick was snagged from scipy.ndimage.median() :)
19
+ labels_broadcast = np .broadcast_to (group_idx , array .shape )
20
+ idxs = np .lexsort ((array , labels_broadcast ), axis = - 1 )
21
+ ordered_array = np .take_along_axis (array , idxs , axis = - 1 )
22
+ group_idx = np .take_along_axis (group_idx , idxs [(0 ,) * (idxs .ndim - 1 ) + (...,)], axis = - 1 )
16
23
else :
17
- perm = group_idx .argsort (kind = "stable" )
18
- group_idx = group_idx [..., perm ]
19
- ordered_array = array [..., perm ]
24
+ issorted = (group_idx [:- 1 ] <= group_idx [1 :]).all ()
25
+ if issorted :
26
+ ordered_array = array
27
+ else :
28
+ perm = group_idx .argsort (kind = "stable" )
29
+ group_idx = group_idx [..., perm ]
30
+ ordered_array = array [..., perm ]
20
31
return group_idx , ordered_array
21
32
22
33
23
- def _np_grouped_op (group_idx , array , op , axis = - 1 , size = None , fill_value = None , dtype = None , out = None ):
34
+ def _lerp (a , b , * , t , dtype , out = None ):
35
+ """
36
+ COPIED from numpy.
37
+
38
+ Compute the linear interpolation weighted by gamma on each point of
39
+ two same shape array.
40
+
41
+ a : array_like
42
+ Left bound.
43
+ b : array_like
44
+ Right bound.
45
+ t : array_like
46
+ The interpolation weight.
47
+ """
48
+ if out is None :
49
+ out = np .empty_like (a , dtype = dtype )
50
+ diff_b_a = np .subtract (b , a )
51
+ # asanyarray is a stop-gap until gh-13105
52
+ np .add (a , diff_b_a * t , out = out )
53
+ np .subtract (b , diff_b_a * (1 - t ), out = out , where = t >= 0.5 )
54
+ return out
55
+
56
+
57
def quantile_(array, inv_idx, *, q, axis, skipna, dtype=None, out=None):
    """
    Grouped quantile(s) using linear interpolation (numpy's method="linear").

    Parameters
    ----------
    array : np.ndarray
        Values laid out so that every group is a contiguous run along the
        last axis. The interpolation indexes by rank inside each run, so
        each run is assumed to be sorted by value with NaNs last (what
        ``np.lexsort`` produces) -- confirm against the caller.
    inv_idx : array_like of int
        Start offset of every group along the last axis.
    q : float or array_like of float
        Quantile(s) to compute, in [0, 1].
    axis : int
        Axis passed to ``reduceat`` / ``take_along_axis``. The group layout
        itself is always read from the last axis of ``array``.
    skipna : bool
        If True, only non-null values count towards a group's size.
        If False, any group containing a null value yields NaN.
    dtype, out : optional
        Forwarded to ``_lerp`` for the result buffer.

    Returns
    -------
    np.ndarray
        Shape ``(..., ngroups)`` for scalar ``q``, or
        ``(nq, ..., ngroups)`` when ``q`` is array-like.
    """
    # Append the end offset so np.diff gives the length of every run.
    inv_idx = np.concatenate((inv_idx, [array.shape[-1]]))
    starts = inv_idx[:-1]
    group_shape = (1,) * (array.ndim - 1) + (starts.size,)

    if skipna:
        # NOTE(review): a group without any valid value gets size 0, which
        # places virtual_index before the group's start -- confirm callers
        # handle all-null groups.
        sizes = np.add.reduceat(notnull(array), starts, axis=axis)
    else:
        sizes = np.reshape(np.diff(inv_idx), group_shape)
        # Look at the LAST element of every run, using its absolute position
        # ``inv_idx[1:] - 1``. (``sizes - 1`` is only an offset inside the
        # run and would point into an earlier run for all but the first
        # group.)
        last_idx = np.reshape(inv_idx[1:] - 1, group_shape)
        nanmask = isnull(np.take_along_axis(array, last_idx, axis=axis))

    qin = q
    q = np.atleast_1d(qin)
    q = np.reshape(q, (len(q),) + (1,) * array.ndim)

    # This is numpy's method="linear"
    # TODO: could support all the interpolations here
    # Fractional rank inside the run, shifted to an absolute position.
    virtual_index = q * (sizes - 1) + starts

    is_scalar_q = is_scalar(qin)
    if is_scalar_q:
        # Drop the leading quantile dimension for scalar q.
        virtual_index = virtual_index.squeeze(axis=0)
        idxshape = array.shape[:-1] + (sizes.shape[-1],)
        a_ = array
    else:
        idxshape = (q.shape[0],) + array.shape[:-1] + (sizes.shape[-1],)
        a_ = np.broadcast_to(array, (q.shape[0],) + array.shape)

    # Broadcast to (num quantiles, ..., num labels)
    lo_ = np.floor(virtual_index, casting="unsafe", out=np.empty(idxshape, dtype=np.int64))
    hi_ = np.ceil(virtual_index, casting="unsafe", out=np.empty(idxshape, dtype=np.int64))

    # get bounds
    loval = np.take_along_axis(a_, lo_, axis=axis)
    hival = np.take_along_axis(a_, hi_, axis=axis)

    # TODO: could support all the interpolations here
    gamma = np.broadcast_to(virtual_index, idxshape) - lo_
    result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype)
    if not skipna and np.any(nanmask):
        result[..., nanmask] = np.nan
    return result
97
+
98
+
99
+ def _np_grouped_op (
100
+ group_idx , array , op , axis = - 1 , size = None , fill_value = None , dtype = None , out = None , ** kwargs
101
+ ):
24
102
"""
25
103
most of this code is from shoyer's gist
26
104
https://gist.github.com/shoyer/f538ac78ae904c936844
@@ -38,16 +116,21 @@ def _np_grouped_op(group_idx, array, op, axis=-1, size=None, fill_value=None, dt
38
116
dtype = array .dtype
39
117
40
118
if out is None :
41
- out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
119
+ q = kwargs .get ("q" , None )
120
+ if q is None :
121
+ out = np .full (array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
122
+ else :
123
+ nq = len (np .atleast_1d (q ))
124
+ out = np .full ((nq ,) + array .shape [:- 1 ] + (size ,), fill_value = fill_value , dtype = dtype )
42
125
43
126
if (len (uniques ) == size ) and (uniques == np .arange (size , like = array )).all ():
44
127
# The previous version of this if condition
45
128
# ((uniques[1:] - uniques[:-1]) == 1).all():
46
129
# does not work when group_idx is [1, 2] for e.g.
47
130
# This happens during binning
48
- op . reduceat (array , inv_idx , axis = axis , dtype = dtype , out = out )
131
+ op (array , inv_idx , axis = axis , dtype = dtype , out = out , ** kwargs )
49
132
else :
50
- out [..., uniques ] = op . reduceat (array , inv_idx , axis = axis , dtype = dtype )
133
+ out [..., uniques ] = op (array , inv_idx , axis = axis , dtype = dtype , ** kwargs )
51
134
52
135
return out
53
136
@@ -65,14 +148,18 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs):
65
148
return result
66
149
67
150
68
# Reductions backed by ``ufunc.reduceat``. Each ``nan*`` variant wraps the
# plain reduction via ``_nan_grouped_op`` with the value given as ``fillna``.
# NOTE: ``sum``, ``max`` and ``min`` shadow the builtins inside this module.
sum = partial(_np_grouped_op, op=np.add.reduceat)
nansum = partial(_nan_grouped_op, func=sum, fillna=0)

prod = partial(_np_grouped_op, op=np.multiply.reduceat)
nanprod = partial(_nan_grouped_op, func=prod, fillna=1)

max = partial(_np_grouped_op, op=np.maximum.reduceat)
nanmax = partial(_nan_grouped_op, func=max, fillna=-np.inf)

min = partial(_np_grouped_op, op=np.minimum.reduceat)
nanmin = partial(_nan_grouped_op, func=min, fillna=np.inf)

# Quantile-based reductions: ``quantile_`` is the op, with ``skipna`` (and
# ``q=0.5`` for the medians) bound ahead of time.
quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False))
nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True))
median = partial(_np_grouped_op, op=partial(quantile_, q=0.5, skipna=False))
nanmedian = partial(_np_grouped_op, op=partial(quantile_, q=0.5, skipna=True))
76
163
# TODO: all, any
77
164
78
165
@@ -99,7 +186,7 @@ def nansum_of_squares(group_idx, array, *, axis=-1, size=None, fill_value=None,
99
186
100
187
101
188
def nanlen(group_idx, array, *args, **kwargs):
    """
    Count the non-null entries of ``array`` in every group.

    The validity mask is converted to integers and reduced with this
    module's grouped ``sum`` (not the builtin); extra positional and keyword
    arguments are forwarded to it unchanged.
    """
    valid = notnull(array).astype(int)
    return sum(group_idx, valid, *args, **kwargs)
103
190
104
191
105
192
def mean (group_idx , array , * , axis = - 1 , size = None , fill_value = None , dtype = None ):
0 commit comments