From 1d1fb8c0651e9a2644b89ed79a2b29686ed3381f Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 7 Aug 2024 11:02:59 -0400 Subject: [PATCH 1/4] Change udaf function to pass all state arrays instead of just the first value --- python/datafusion/udf.py | 2 +- src/udaf.rs | 25 ++++++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 12563b3d2..bdbad661a 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -157,7 +157,7 @@ def update(self, values: pyarrow.Array) -> None: pass @abstractmethod - def merge(self, states: pyarrow.Array) -> None: + def merge(self, states: List[pyarrow.Array]) -> None: """Merge a set of states.""" pass diff --git a/src/udaf.rs b/src/udaf.rs index 7b5e03668..035fa9471 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -72,18 +72,29 @@ impl Accumulator for RustAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { Python::with_gil(|py| { - let state = &states[0]; + // let state = &states[0]; + + // // 1. cast states to Pyarrow array + // let state = state + // .into_data() + // .to_pyarrow(py) + // .map_err(|e| DataFusionError::Execution(format!("{e}")))?; + let py_states: Result<Vec<PyObject>> = states + .iter() + .map(|state| { + state + .into_data() + .to_pyarrow(py) + .map_err(|e| DataFusionError::Execution(format!("{e}"))) + }) + .collect(); - // 1. cast states to Pyarrow array - let state = state - .into_data() - .to_pyarrow(py) - .map_err(|e| DataFusionError::Execution(format!("{e}")))?; + // let py_states = PyTuple::new_bound(py, py_states?.iter()); // 2.
call merge self.accum .bind(py) - .call_method1("merge", (state,)) + .call_method1("merge", (py_states?,)) .map_err(|e| DataFusionError::Execution(format!("{e}")))?; Ok(()) From b66ff021eb4a32f479bab0751e12aff5ee1044e2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 7 Aug 2024 11:38:17 -0400 Subject: [PATCH 2/4] Update unit test --- python/datafusion/tests/test_udaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index 81194927c..76488e19b 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -38,10 +38,10 @@ def update(self, values: pa.Array) -> None: # This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) - def merge(self, states: pa.Array) -> None: + def merge(self, states: List[pa.Array]) -> None: # Not nice since pyarrow scalars can't be summed yet. # This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py()) + self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) def evaluate(self) -> pa.Scalar: return self._sum From 0560cb8dc7257850d54c93bda7b4fdff4bfc5691 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 7 Aug 2024 11:44:26 -0400 Subject: [PATCH 3/4] Update user documentation --- docs/source/user-guide/common-operations/udf-and-udfa.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index 54c685794..e9c142f0a 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -50,6 +50,7 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo import pyarrow.compute import datafusion from datafusion import col, udaf, Accumulator + from typing import List class MyAccumulator(Accumulator): """ @@
-62,9 +63,9 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo # not nice since pyarrow scalars can't be summed yet. This breaks on `None` self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(values).as_py()) - def merge(self, states: pyarrow.Array) -> None: + def merge(self, states: List[pyarrow.Array]) -> None: # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states).as_py()) + self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states[0]).as_py()) def state(self) -> pyarrow.Array: return pyarrow.array([self._sum.as_py()]) From 3149faa7943daf75853e4a8c4cdc4924f36b5791 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 7 Aug 2024 12:31:44 -0400 Subject: [PATCH 4/4] Remove stale comments --- src/udaf.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/udaf.rs b/src/udaf.rs index 035fa9471..2041e5a74 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -72,13 +72,7 @@ impl Accumulator for RustAccumulator { fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { Python::with_gil(|py| { - // let state = &states[0]; - - // // 1. cast states to Pyarrow array - // let state = state - // .into_data() - // .to_pyarrow(py) - // .map_err(|e| DataFusionError::Execution(format!("{e}")))?; + // // 1. cast states to Pyarrow arrays let py_states: Result<Vec<PyObject>> = states .iter() .map(|state| { @@ -89,8 +83,6 @@ impl Accumulator for RustAccumulator { }) .collect(); - // let py_states = PyTuple::new_bound(py, py_states?.iter()); - // 2. call merge self.accum .bind(py)