feat: expose PyWindowFrame (#509) · psvri/arrow-datafusion-python@399fa75 · GitHub
[go: up one dir, main page]

Skip to content

Commit 399fa75

Browse files
authored
feat: expose PyWindowFrame (apache#509)
* feat: expose PyWindowFrame * fix: PyWindowFrame: return Err instead of panicking * test: test PyWindowFrame creation
1 parent 5ec45dd commit 399fa75

File tree

6 files changed

+200
-5
lines changed

6 files changed

+200
-5
lines changed

datafusion/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
SessionConfig,
3434
RuntimeConfig,
3535
ScalarUDF,
36+
WindowFrame,
3637
)
3738

3839
from .common import (
@@ -98,6 +99,7 @@
9899
"Expr",
99100
"AggregateUDF",
100101
"ScalarUDF",
102+
"WindowFrame",
101103
"column",
102104
"literal",
103105
"TableScan",

datafusion/tests/test_dataframe.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@
2121
import pytest
2222

2323
from datafusion import functions as f
24-
from datafusion import DataFrame, SessionContext, column, literal, udf
24+
from datafusion import (
25+
DataFrame,
26+
SessionContext,
27+
WindowFrame,
28+
column,
29+
literal,
30+
udf,
31+
)
2532

2633

2734
@pytest.fixture
@@ -304,6 +311,38 @@ def test_window_functions(df):
304311
assert table.sort_by("a").to_pydict() == expected
305312

306313

314+
@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        # Every rows/range combination of missing, zero and non-zero offsets.
        *[
            (frame_units, lower, upper)
            for frame_units in ("rows", "range")
            for lower in (None, 0, 1)
            for upper in (None, 0, 1)
        ],
        # "groups" is only exercised with both offsets present.
        ("groups", 0, 0),
    ],
)
def test_valid_window_frame(units, start_bound, end_bound):
    """Constructing a WindowFrame from a supported combination must not raise."""
    WindowFrame(units, start_bound, end_bound)
328+
329+
330+
@pytest.mark.parametrize(
    ("units", "start_bound", "end_bound"),
    [
        # Unknown units are rejected whatever the offsets are.
        ("invalid-units", lower, upper)
        for lower, upper in ((0, None), (None, 0), (None, None))
    ]
    + [
        # "groups" is rejected as soon as either offset is missing.
        ("groups", lower, upper)
        for lower, upper in ((None, 0), (0, None), (None, None))
    ],
)
def test_invalid_window_frame(units, start_bound, end_bound):
    """Constructing a WindowFrame from an unsupported combination raises RuntimeError."""
    with pytest.raises(RuntimeError):
        WindowFrame(units, start_bound, end_bound)
344+
345+
307346
def test_get_dataframe(tmp_path):
308347
ctx = SessionContext()
309348

src/functions.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,20 @@
1717

1818
use pyo3::{prelude::*, wrap_pyfunction};
1919

20+
use crate::context::PySessionContext;
2021
use crate::errors::DataFusionError;
2122
use crate::expr::conditional_expr::PyCaseBuilder;
2223
use crate::expr::PyExpr;
24+
use crate::window_frame::PyWindowFrame;
25+
use datafusion::execution::FunctionRegistry;
2326
use datafusion_common::Column;
2427
use datafusion_expr::expr::Alias;
2528
use datafusion_expr::{
2629
aggregate_function,
2730
expr::{AggregateFunction, ScalarFunction, Sort, WindowFunction},
2831
lit,
2932
window_function::find_df_window_func,
30-
BuiltinScalarFunction, Expr, WindowFrame,
33+
BuiltinScalarFunction, Expr,
3134
};
3235

3336
#[pyfunction]
@@ -130,13 +133,24 @@ fn window(
130133
args: Vec<PyExpr>,
131134
partition_by: Option<Vec<PyExpr>>,
132135
order_by: Option<Vec<PyExpr>>,
136+
window_frame: Option<PyWindowFrame>,
137+
ctx: Option<PySessionContext>,
133138
) -> PyResult<PyExpr> {
134-
let fun = find_df_window_func(name);
139+
let fun = find_df_window_func(name).or_else(|| {
140+
ctx.and_then(|ctx| {
141+
ctx.ctx
142+
.udaf(name)
143+
.map(|fun| datafusion_expr::WindowFunction::AggregateUDF(fun))
144+
.ok()
145+
})
146+
});
135147
if fun.is_none() {
136148
return Err(DataFusionError::Common("window function not found".to_string()).into());
137149
}
138150
let fun = fun.unwrap();
139-
let window_frame = WindowFrame::new(order_by.is_some());
151+
let window_frame = window_frame
152+
.unwrap_or_else(|| PyWindowFrame::new("rows", None, Some(0)).unwrap())
153+
.into();
140154
Ok(PyExpr {
141155
expr: datafusion_expr::Expr::WindowFunction(WindowFunction {
142156
fun,

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ mod udaf;
5454
#[allow(clippy::borrow_deref_ref)]
5555
mod udf;
5656
pub mod utils;
57+
mod window_frame;
5758

5859
#[cfg(feature = "mimalloc")]
5960
#[global_allocator]
@@ -83,6 +84,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
8384
m.add_class::<context::PySessionContext>()?;
8485
m.add_class::<dataframe::PyDataFrame>()?;
8586
m.add_class::<udf::PyScalarUDF>()?;
87+
m.add_class::<window_frame::PyWindowFrame>()?;
8688
m.add_class::<udaf::PyAggregateUDF>()?;
8789
m.add_class::<config::PyConfig>()?;
8890
m.add_class::<sql::logical::PyLogicalPlan>()?;

src/udaf.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
use std::sync::Arc;
1919

20-
use pyo3::{prelude::*, types::PyTuple};
20+
use pyo3::{prelude::*, types::PyBool, types::PyTuple};
2121

2222
use datafusion::arrow::array::{Array, ArrayRef};
2323
use datafusion::arrow::datatypes::DataType;
@@ -93,6 +93,34 @@ impl Accumulator for RustAccumulator {
9393
fn size(&self) -> usize {
9494
std::mem::size_of_val(self)
9595
}
96+
97+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
98+
Python::with_gil(|py| {
99+
// 1. cast args to Pyarrow array
100+
let py_args = values
101+
.iter()
102+
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
103+
.collect::<Vec<_>>();
104+
let py_args = PyTuple::new(py, py_args);
105+
106+
// 2. call function
107+
self.accum
108+
.as_ref(py)
109+
.call_method1("retract_batch", py_args)
110+
.map_err(|e| DataFusionError::Execution(format!("{e}")))?;
111+
112+
Ok(())
113+
})
114+
}
115+
116+
fn supports_retract_batch(&self) -> bool {
117+
Python::with_gil(|py| {
118+
let x: Result<&PyAny, PyErr> =
119+
self.accum.as_ref(py).call_method0("supports_retract_batch");
120+
let x: &PyAny = x.unwrap_or(PyBool::new(py, false));
121+
x.extract().unwrap_or(false)
122+
})
123+
}
96124
}
97125

98126
pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {

src/window_frame.rs

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_common::{DataFusionError, ScalarValue};
19+
use datafusion_expr::window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
20+
use pyo3::prelude::*;
21+
use std::fmt::{Display, Formatter};
22+
23+
use crate::errors::py_datafusion_err;
24+
25+
#[pyclass(name = "WindowFrame", module = "datafusion", subclass)]
26+
#[derive(Clone)]
27+
pub struct PyWindowFrame {
28+
frame: WindowFrame,
29+
}
30+
31+
impl From<PyWindowFrame> for WindowFrame {
32+
fn from(frame: PyWindowFrame) -> Self {
33+
frame.frame
34+
}
35+
}
36+
37+
/// Wraps a DataFusion frame so it can be handed to Python.
impl From<WindowFrame> for PyWindowFrame {
    fn from(inner: WindowFrame) -> Self {
        Self { frame: inner }
    }
}
42+
43+
impl Display for PyWindowFrame {
44+
fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
45+
write!(
46+
f,
47+
"OVER ({} BETWEEN {} AND {})",
48+
self.frame.units, self.frame.start_bound, self.frame.end_bound
49+
)
50+
}
51+
}
52+
53+
#[pymethods]
54+
impl PyWindowFrame {
55+
#[new(unit, start_bound, end_bound)]
56+
pub fn new(units: &str, start_bound: Option<u64>, end_bound: Option<u64>) -> PyResult<Self> {
57+
let units = units.to_ascii_lowercase();
58+
let units = match units.as_str() {
59+
"rows" => WindowFrameUnits::Rows,
60+
"range" => WindowFrameUnits::Range,
61+
"groups" => WindowFrameUnits::Groups,
62+
_ => {
63+
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
64+
"{:?}",
65+
units,
66+
))));
67+
}
68+
};
69+
let start_bound = match start_bound {
70+
Some(start_bound) => {
71+
WindowFrameBound::Preceding(ScalarValue::UInt64(Some(start_bound)))
72+
}
73+
None => match units {
74+
WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
75+
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
76+
WindowFrameUnits::Groups => {
77+
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
78+
"{:?}",
79+
units,
80+
))));
81+
}
82+
},
83+
};
84+
let end_bound = match end_bound {
85+
Some(end_bound) => WindowFrameBound::Following(ScalarValue::UInt64(Some(end_bound))),
86+
None => match units {
87+
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
88+
WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
89+
WindowFrameUnits::Groups => {
90+
return Err(py_datafusion_err(DataFusionError::NotImplemented(format!(
91+
"{:?}",
92+
units,
93+
))));
94+
}
95+
},
96+
};
97+
Ok(PyWindowFrame {
98+
frame: WindowFrame {
99+
units,
100+
start_bound,
101+
end_bound,
102+
},
103+
})
104+
}
105+
106+
/// Get a String representation of this window frame
107+
fn __repr__(&self) -> String {
108+
format!("{}", self)
109+
}
110+
}

0 commit comments

Comments
 (0)
0