@@ -3,13 +3,14 @@
 import itertools
 import warnings
 from functools import partial, reduce
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable
 
 import numpy as np
 import pandas as pd
 import pytest
 from numpy_groupies.aggregate_numpy import aggregate
 
+from flox import xrutils
 from flox.aggregations import Aggregation
 from flox.core import (
     _convert_expected_groups_to_index,
@@ -53,6 +54,7 @@ def dask_array_ones(*args):
     "sum",
     "nansum",
     "argmax",
+    "nanfirst",
     pytest.param("nanargmax", marks=(pytest.mark.skip,)),
     "prod",
     "nanprod",
@@ -70,6 +72,7 @@ def dask_array_ones(*args):
     pytest.param("nanargmin", marks=(pytest.mark.skip,)),
     "any",
     "all",
+    "nanlast",
     pytest.param("median", marks=(pytest.mark.skip,)),
     pytest.param("nanmedian", marks=(pytest.mark.skip,)),
 )
@@ -78,6 +81,21 @@ def dask_array_ones(*args):
     from flox.core import T_Engine, T_ExpectedGroupsOpt, T_Func2
 
 
+def _get_array_func(func: str) -> Callable:
+    if func == "count":
+
+        def npfunc(x):
+            x = np.asarray(x)
+            return (~np.isnan(x)).sum()
+
+    elif func in ["nanfirst", "nanlast"]:
+        npfunc = getattr(xrutils, func)
+    else:
+        npfunc = getattr(np, func)
+
+    return npfunc
+
+
 def test_alignment_error():
     da = np.ones((12,))
     labels = np.ones((5,))
@@ -217,6 +235,8 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
         if "arg" in func and add_nan_by:
             array_[..., nanmask] = np.nan
             expected = getattr(np, "nan" + func)(array_, axis=-1, **kwargs)
+        elif func in ["nanfirst", "nanlast"]:
+            expected = getattr(xrutils, func)(array_[..., ~nanmask], axis=-1, **kwargs)
         else:
             expected = getattr(np, func)(array_[..., ~nanmask], axis=-1, **kwargs)
     for _ in range(nby):
@@ -486,6 +506,20 @@ def test_dask_reduce_axis_subset():
         )
 
 
+@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
+@pytest.mark.parametrize("axis", [(0, 1)])
+def test_first_last_disallowed(axis, func):
+    with pytest.raises(ValueError):
+        groupby_reduce(np.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=axis)
+
+
+@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
+@pytest.mark.parametrize("axis", [None, (0, 1, 2)])
+def test_first_last_disallowed_dask(axis, func):
+    with pytest.raises(ValueError):
+        groupby_reduce(dask.array.empty((2, 3, 2)), np.ones((2, 3, 2)), func=func, axis=axis)
+
+
 @requires_dask
 @pytest.mark.parametrize("func", ALL_FUNCS)
 @pytest.mark.parametrize(
@@ -495,8 +529,12 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
     if "arg" in func and engine == "flox":
         pytest.skip()
 
-    if not isinstance(axis, int) and "arg" in func and (axis is None or len(axis) > 1):
-        pytest.skip()
+    if not isinstance(axis, int):
+        if "arg" in func and (axis is None or len(axis) > 1):
+            pytest.skip()
+        if ("first" in func or "last" in func) and (axis is not None and len(axis) not in [1, 3]):
+            pytest.skip()
+
     if func in ["all", "any"]:
         fill_value = False
     else:
@@ -513,17 +551,45 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
     kwargs = dict(
         func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
     )
+    expected, _ = groupby_reduce(array, by, **kwargs)
+    if engine == "flox":
+        kwargs.pop("engine")
+        expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
+        assert_equal(expected_npg, expected)
+
+    if func in ["all", "any"]:
+        fill_value = False
+    else:
+        fill_value = 123
+
+    if "var" in func or "std" in func:
+        tolerance = {"rtol": 1e-14, "atol": 1e-16}
+    else:
+        tolerance = None
+    # tests against the numpy output to make sure dask compute matches
+    by = np.broadcast_to(labels2d, (3, *labels2d.shape))
+    rng = np.random.default_rng(12345)
+    array = rng.random(by.shape)
+    kwargs = dict(
+        func=func, axis=axis, expected_groups=[0, 2], fill_value=fill_value, engine=engine
+    )
+    expected, _ = groupby_reduce(array, by, **kwargs)
+    if engine == "flox":
+        kwargs.pop("engine")
+        expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
+        assert_equal(expected_npg, expected)
+
+    if ("first" in func or "last" in func) and (
+        axis is None or (not isinstance(axis, int) and len(axis) != 1)
+    ):
+        return
+
     with raise_if_dask_computes():
         actual, _ = groupby_reduce(
             da.from_array(array, chunks=(-1, 2, 3)),
             da.from_array(by, chunks=(-1, 2, 2)),
             **kwargs,
         )
-    expected, _ = groupby_reduce(array, by, **kwargs)
-    if engine == "flox":
-        kwargs.pop("engine")
-        expected_npg, _ = groupby_reduce(array, by, **kwargs, engine="numpy")
-        assert_equal(expected_npg, expected)
     assert_equal(actual, expected, tolerance)
 
 
@@ -751,23 +817,17 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine):
     if chunks is not None and not has_dask:
         pytest.skip()
 
-    if func == "count":
-
-        def npfunc(x):
-            x = np.asarray(x)
-            return (~np.isnan(x)).sum()
-
-    else:
-        npfunc = getattr(np, func)
-
+    npfunc = _get_array_func(func)
     by = np.array([1, 2, 3, 1, 2, 3])
     array = np.array([np.nan, 1, 1, np.nan, 1, 1])
     if chunks:
         array = dask.array.from_array(array, chunks)
     actual, _ = groupby_reduce(
         array, by, func=func, engine=engine, fill_value=fill_value, expected_groups=[0, 1, 2, 3]
     )
-    expected = np.array([fill_value, fill_value, npfunc([1.0, 1.0]), npfunc([1.0, 1.0])])
+    expected = np.array(
+        [fill_value, fill_value, npfunc([1.0, 1.0], axis=0), npfunc([1.0, 1.0], axis=0)]
+    )
     assert_equal(actual, expected)
 
 
@@ -832,6 +892,8 @@ def test_cohorts_nd_by(func, method, axis, engine):
 
     if axis is not None and method != "map-reduce":
         pytest.xfail()
+    if axis is None and ("first" in func or "last" in func):
+        pytest.skip()
 
     kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value)
     actual, groups = groupby_reduce(array, by, **kwargs)
@@ -897,7 +959,8 @@ def test_bool_reductions(func, engine):
         pytest.skip()
     groups = np.array([1, 1, 1])
     data = np.array([True, True, False])
-    expected = np.expand_dims(getattr(np, func)(data), -1)
+    npfunc = _get_array_func(func)
+    expected = np.expand_dims(npfunc(data, axis=0), -1)
     actual, _ = groupby_reduce(data, groups, func=func, engine=engine)
     assert_equal(expected, actual)
 