UDAF process all state variables (#799) · PhVHoang/datafusion-python@b6f06f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit b6f06f7

Browse files
authored
UDAF process all state variables (apache#799)
1 parent 2205e05 commit b6f06f7

File tree

4 files changed

+17
-13
lines changed

4 files changed

+17
-13
lines changed

docs/source/user-guide/common-operations/udf-and-udfa.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo
5050
import pyarrow.compute
5151
import datafusion
5252
from datafusion import col, udaf, Accumulator
53+
from typing import List
5354
5455
class MyAccumulator(Accumulator):
5556
"""
@@ -62,9 +63,9 @@ Additionally the :py:func:`~datafusion.udf.AggregateUDF.udaf` function allows yo
6263
# not nice since pyarrow scalars can't be summed yet. This breaks on `None`
6364
self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(values).as_py())
6465
65-
def merge(self, states: pyarrow.Array) -> None:
66+
def merge(self, states: List[pyarrow.Array]) -> None:
6667
# not nice since pyarrow scalars can't be summed yet. This breaks on `None`
67-
self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states).as_py())
68+
self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states[0]).as_py())
6869
6970
def state(self) -> pyarrow.Array:
7071
return pyarrow.array([self._sum.as_py()])

python/datafusion/tests/test_udaf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ def update(self, values: pa.Array) -> None:
3838
# This breaks on `None`
3939
self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
4040

41-
def merge(self, states: pa.Array) -> None:
41+
def merge(self, states: List[pa.Array]) -> None:
4242
# Not nice since pyarrow scalars can't be summed yet.
4343
# This breaks on `None`
44-
self._sum = pa.scalar(self._sum.as_py() + pc.sum(states).as_py())
44+
self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
4545

4646
def evaluate(self) -> pa.Scalar:
4747
return self._sum

python/datafusion/udf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def update(self, values: pyarrow.Array) -> None:
157157
pass
158158

159159
@abstractmethod
160-
def merge(self, states: pyarrow.Array) -> None:
160+
def merge(self, states: List[pyarrow.Array]) -> None:
161161
"""Merge a set of states."""
162162
pass
163163

src/udaf.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,21 @@ impl Accumulator for RustAccumulator {
7272

7373
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
7474
Python::with_gil(|py| {
75-
let state = &states[0];
76-
77-
// 1. cast states to Pyarrow array
78-
let state = state
79-
.into_data()
80-
.to_pyarrow(py)
81-
.map_err(|e| DataFusionError::Execution(format!("{e}")))?;
75+
// 1. cast states to Pyarrow arrays
76+
let py_states: Result<Vec<PyObject>> = states
77+
.iter()
78+
.map(|state| {
79+
state
80+
.into_data()
81+
.to_pyarrow(py)
82+
.map_err(|e| DataFusionError::Execution(format!("{e}")))
83+
})
84+
.collect();
8285

8386
// 2. call merge
8487
self.accum
8588
.bind(py)
86-
.call_method1("merge", (state,))
89+
.call_method1("merge", (py_states?,))
8790
.map_err(|e| DataFusionError::Execution(format!("{e}")))?;
8891

8992
Ok(())

0 commit comments

Comments
 (0)
0