File tree Expand file tree Collapse file tree 4 files changed +17
-13
lines changed
docs/source/user-guide/common-operations Expand file tree Collapse file tree 4 files changed +17
-13
lines changed Original file line number Diff line number Diff line change @@ -50,6 +50,7 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo
50
50
import pyarrow.compute
51
51
import datafusion
52
52
from datafusion import col, udaf, Accumulator
53
+ from typing import List
53
54
54
55
class MyAccumulator (Accumulator ):
55
56
"""
@@ -62,9 +63,9 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo
62
63
# not nice since pyarrow scalars can't be summed yet. This breaks on `None`
63
64
self ._sum = pyarrow.scalar(self ._sum.as_py() + pyarrow.compute.sum(values).as_py())
64
65
65
- def merge (self , states : pyarrow.Array) -> None :
66
+ def merge (self , states : List[ pyarrow.Array] ) -> None :
66
67
# not nice since pyarrow scalars can't be summed yet. This breaks on `None`
67
- self ._sum = pyarrow.scalar(self ._sum.as_py() + pyarrow.compute.sum(states).as_py())
68
+ self ._sum = pyarrow.scalar(self ._sum.as_py() + pyarrow.compute.sum(states[ 0 ] ).as_py())
68
69
69
70
def state (self ) -> pyarrow.Array:
70
71
return pyarrow.array([self ._sum.as_py()])
Original file line number Diff line number Diff line change @@ -38,10 +38,10 @@ def update(self, values: pa.Array) -> None:
38
38
# This breaks on `None`
39
39
self ._sum = pa .scalar (self ._sum .as_py () + pc .sum (values ).as_py ())
40
40
41
- def merge (self , states : pa .Array ) -> None :
41
+ def merge (self , states : List [ pa .Array ] ) -> None :
42
42
# Not nice since pyarrow scalars can't be summed yet.
43
43
# This breaks on `None`
44
- self ._sum = pa .scalar (self ._sum .as_py () + pc .sum (states ).as_py ())
44
+ self ._sum = pa .scalar (self ._sum .as_py () + pc .sum (states [ 0 ] ).as_py ())
45
45
46
46
def evaluate (self ) -> pa .Scalar :
47
47
return self ._sum
Original file line number Diff line number Diff line change @@ -157,7 +157,7 @@ def update(self, values: pyarrow.Array) -> None:
157
157
pass
158
158
159
159
@abstractmethod
160
- def merge (self , states : pyarrow .Array ) -> None :
160
+ def merge (self , states : List [ pyarrow .Array ] ) -> None :
161
161
"""Merge a set of states."""
162
162
pass
163
163
Original file line number Diff line number Diff line change @@ -72,18 +72,21 @@ impl Accumulator for RustAccumulator {
72
72
73
73
fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
74
74
Python :: with_gil ( |py| {
75
- let state = & states[ 0 ] ;
76
-
77
- // 1. cast states to Pyarrow array
78
- let state = state
79
- . into_data ( )
80
- . to_pyarrow ( py)
81
- . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) ) ?;
75
+ // // 1. cast states to Pyarrow arrays
76
+ let py_states: Result < Vec < PyObject > > = states
77
+ . iter ( )
78
+ . map ( |state| {
79
+ state
80
+ . into_data ( )
81
+ . to_pyarrow ( py)
82
+ . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) )
83
+ } )
84
+ . collect ( ) ;
82
85
83
86
// 2. call merge
84
87
self . accum
85
88
. bind ( py)
86
- . call_method1 ( "merge" , ( state , ) )
89
+ . call_method1 ( "merge" , ( py_states? , ) )
87
90
. map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) ) ?;
88
91
89
92
Ok ( ( ) )
You can’t perform that action at this time.
0 commit comments