Describe the bug
When writing a user-defined aggregate function, every state column beyond the first is ignored during a merge operation. The offending line appears to be:
Line 75 in 1d61548
To Reproduce
```python
from typing import List

import pyarrow
import pyarrow.compute
import datafusion
from datafusion import col, udaf, Accumulator


class MyAccumulator(Accumulator):
    """
    Interface of a user-defined accumulation.
    """

    def __init__(self):
        self._sum = 0.0
        self._num = 0

    def update(self, values: pyarrow.Array) -> None:
        self._sum = self._sum + pyarrow.compute.sum(values).as_py()
        self._num = self._num + len(values)

    def merge(self, states: pyarrow.Array) -> None:
        self._sum = self._sum + pyarrow.compute.sum(states).as_py()
        # `states` SHOULD be a list of Arrays but in actuality is just a single array.
        # The line below is WRONG: it should sum the second column of the state.
        self._num = self._num + len(states)

    def state(self) -> List[pyarrow.Scalar]:
        return [pyarrow.scalar(self._sum), pyarrow.scalar(self._num)]

    def evaluate(self) -> pyarrow.Scalar:
        return pyarrow.scalar(self._sum / self._num)


ctx = datafusion.SessionContext()

# Create a large enough array that we're getting it in batches
df = ctx.from_pydict(
    {
        "a": list(7 for _ in range(10000)),
    }
)

my_udaf = udaf(
    MyAccumulator,
    pyarrow.float64(),
    pyarrow.float64(),
    [pyarrow.float64(), pyarrow.int64()],
    "stable",
)

# This should return 7 regardless of what you set the `range` above to.
# If you decrease the range to 1000 you will get the right value because
# all the processing happens in a single batch.
df.aggregate([], [my_udaf(col("a")).alias("output")]).show()
```
Expected behavior
User-defined aggregate functions should receive every state column during a merge, so accumulators with more than one state variable can be combined correctly.
Additional context
I'm going to work on this unless somebody wants to tackle it first.