diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index 54c685794..e9c142f0a 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -50,6 +50,7 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo import pyarrow.compute import datafusion from datafusion import col, udaf, Accumulator + from typing import List class MyAccumulator(Accumulator): """ @@ -62,9 +63,9 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo # not nice since pyarrow scalars can't be summed yet. This breaks on `None` self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(values).as_py()) - def merge(self, states: pyarrow.Array) -> None: + def merge(self, states: List[pyarrow.Array]) -> None: # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states).as_py()) + self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states[0]).as_py()) def state(self) -> pyarrow.Array: return pyarrow.array([self._sum.as_py()]) diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index 81194927c..76488e19b 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -38,10 +38,10 @@ def update(self, values: pa.Array) -> None: # This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - def merge(self, states: pa.Array) -> None: + def merge(self, states: List[pa.Array]) -> None: # Not nice since pyarrow scalars can't be summed yet. # This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py()) + self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) def evaluate(self) -> pa.Scalar: return self._sum diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 12563b3d2..bdbad661a 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -157,7 +157,7 @@ def update(self, values: pyarrow.Array) -> None: pass @abstractmethod - def merge(self, states: pyarrow.Array) -> None: + def merge(self, states: List[pyarrow.Array]) -> None: """Merge a set of states.""" pass diff --git a/src/udaf.rs b/src/udaf.rs index 7b5e03668..2041e5a74 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -72,18 +72,21 @@ impl Accumulator for RustAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { Python::with_gil(|py| { - let state = &states[0]; - - // 1. cast states to Pyarrow array - let state = state - .into_data() - .to_pyarrow(py) - .map_err(|e| DataFusionError::Execution(format!("{e}")))?; + // // 1. cast states to Pyarrow arrays + let py_states: Result> = states + .iter() + .map(|state| { + state + .into_data() + .to_pyarrow(py) + .map_err(|e| DataFusionError::Execution(format!("{e}"))) + }) + .collect(); // 2. call merge self.accum .bind(py) - .call_method1("merge", (state,)) + .call_method1("merge", (py_states?,)) .map_err(|e| DataFusionError::Execution(format!("{e}")))?; Ok(())