Describe the bug
When writing a user-defined aggregate function, any state column beyond the first is ignored during a merge operation. The offending line appears to be
datafusion-python/src/udaf.rs
Line 75 in 1d61548
To Reproduce
```python
from typing import List

import pyarrow
import pyarrow.compute
import datafusion
from datafusion import col, udaf, Accumulator


class MyAccumulator(Accumulator):
    """
    Interface of a user-defined accumulation.
    """

    def __init__(self):
        self._sum = 0.0
        self._num = 0

    def update(self, values: pyarrow.Array) -> None:
        self._sum = self._sum + pyarrow.compute.sum(values).as_py()
        self._num = self._num + len(values)

    def merge(self, states: pyarrow.Array) -> None:
        # `states` SHOULD be a list of Arrays but in actuality is just a single array
        self._sum = self._sum + pyarrow.compute.sum(states).as_py()
        # The line below is WRONG. It should be a sum of the second column of the state.
        self._num = self._num + len(states)

    def state(self) -> List[pyarrow.Scalar]:
        return [pyarrow.scalar(self._sum), pyarrow.scalar(self._num)]

    def evaluate(self) -> pyarrow.Scalar:
        return pyarrow.scalar(self._sum / self._num)


ctx = datafusion.SessionContext()

# Create a large enough array that we're getting it in batches
df = ctx.from_pydict(
    {
        "a": list(7 for _ in range(10000)),
    }
)

my_udaf = udaf(
    MyAccumulator,
    pyarrow.float64(),
    pyarrow.float64(),
    [pyarrow.float64(), pyarrow.int64()],
    "stable",
)

# This should return 7 regardless of what you set the `range` above to.
# If you decrease the range to 1000 you will get the right value because it
# does all the processing in a single batch.
df.aggregate([], [my_udaf(col("a")).alias("output")]).show()
```
Expected behavior
User defined aggregate functions should be able to process more than one state variable.
Additional context
I'm going to work on this unless somebody wants to tackle it first.