UDAF `sum` workaround (#741) · datapythonista/datafusion-python@ec835ab · GitHub
[go: up one dir, main page]

Skip to content

Commit ec835ab

Browse files
UDAF sum workaround (apache#741)
* Provides a workaround for the half-migrated UDAF `sum`. Ref apache#730
* Provides compatibility for `sqlparser::ast::NullTreatment`. This is now exposed as part of the API of the `first_value` and `last_value` functions. If there's a more elegant way to achieve this, please let me know.
1 parent 32d6975 commit ec835ab

File tree

6 files changed

+53
-14
lines changed

6 files changed

+53
-14
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ parking_lot = "0.12"
5656
regex-syntax = "0.8.1"
5757
syn = "2.0.67"
5858
url = "2.2"
59+
sqlparser = "0.47.0"
5960

6061
[build-dependencies]
6162
pyo3-build-config = "0.21"

examples/tpch/_tests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def check_q17(df):
7474
("q10_returned_item_reporting", "q10"),
7575
pytest.param(
7676
"q11_important_stock_identification", "q11",
77-
marks=pytest.mark.xfail # https://github.com/apache/datafusion-python/issues/730
7877
),
7978
("q12_ship_mode_order_priority", "q12"),
8079
("q13_customer_distribution", "q13"),

src/common.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
2929
m.add_class::<data_type::DataTypeMap>()?;
3030
m.add_class::<data_type::PythonType>()?;
3131
m.add_class::<data_type::SqlType>()?;
32+
m.add_class::<data_type::NullTreatment>()?;
3233
m.add_class::<schema::SqlTable>()?;
3334
m.add_class::<schema::SqlSchema>()?;
3435
m.add_class::<schema::SqlView>()?;

src/common/data_type.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -757,3 +757,33 @@ pub enum SqlType {
757757
VARBINARY,
758758
VARCHAR,
759759
}
760+
761+
/// Specifies Ignore / Respect NULL within window functions.
762+
/// For example
763+
/// `FIRST_VALUE(column2) IGNORE NULLS OVER (PARTITION BY column1)`
764+
#[allow(non_camel_case_types)]
765+
#[allow(clippy::upper_case_acronyms)]
766+
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
767+
#[pyclass(name = "PythonType", module = "datafusion.common")]
768+
pub enum NullTreatment {
769+
IGNORE_NULLS,
770+
RESPECT_NULLS,
771+
}
772+
773+
impl From<NullTreatment> for sqlparser::ast::NullTreatment {
774+
fn from(null_treatment: NullTreatment) -> sqlparser::ast::NullTreatment {
775+
match null_treatment {
776+
NullTreatment::IGNORE_NULLS => sqlparser::ast::NullTreatment::IgnoreNulls,
777+
NullTreatment::RESPECT_NULLS => sqlparser::ast::NullTreatment::RespectNulls,
778+
}
779+
}
780+
}
781+
782+
/// Converts a `sqlparser` AST null-treatment value into the Python-facing
/// [`NullTreatment`] enum.
impl From<sqlparser::ast::NullTreatment> for NullTreatment {
    fn from(value: sqlparser::ast::NullTreatment) -> Self {
        // Local alias distinguishes the parser's enum from `Self`.
        use sqlparser::ast::NullTreatment as SqlNullTreatment;

        match value {
            SqlNullTreatment::IgnoreNulls => Self::IGNORE_NULLS,
            SqlNullTreatment::RespectNulls => Self::RESPECT_NULLS,
        }
    }
}

src/functions.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use pyo3::{prelude::*, wrap_pyfunction};
1919

20+
use crate::common::data_type::NullTreatment;
2021
use crate::context::PySessionContext;
2122
use crate::errors::DataFusionError;
2223
use crate::expr::conditional_expr::PyCaseBuilder;
@@ -73,15 +74,15 @@ pub fn var(y: PyExpr) -> PyExpr {
7374
}
7475

7576
#[pyfunction]
76-
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
77+
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
7778
pub fn first_value(
7879
args: Vec<PyExpr>,
7980
distinct: bool,
8081
filter: Option<PyExpr>,
8182
order_by: Option<Vec<PyExpr>>,
83+
null_treatment: Option<NullTreatment>,
8284
) -> PyExpr {
83-
// TODO: allow user to select null_treatment
84-
let null_treatment = None;
85+
let null_treatment = null_treatment.map(Into::into);
8586
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
8687
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
8788
functions_aggregate::expr_fn::first_value(
@@ -95,15 +96,15 @@ pub fn first_value(
9596
}
9697

9798
#[pyfunction]
98-
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None))]
99+
#[pyo3(signature = (*args, distinct = false, filter = None, order_by = None, null_treatment = None))]
99100
pub fn last_value(
100101
args: Vec<PyExpr>,
101102
distinct: bool,
102103
filter: Option<PyExpr>,
103104
order_by: Option<Vec<PyExpr>>,
105+
null_treatment: Option<NullTreatment>,
104106
) -> PyExpr {
105-
// TODO: allow user to select null_treatment
106-
let null_treatment = None;
107+
let null_treatment = null_treatment.map(Into::into);
107108
let args = args.into_iter().map(|x| x.expr).collect::<Vec<_>>();
108109
let order_by = order_by.map(|x| x.into_iter().map(|x| x.expr).collect::<Vec<_>>());
109110
functions_aggregate::expr_fn::last_value(
@@ -320,14 +321,20 @@ fn window(
320321
window_frame: Option<PyWindowFrame>,
321322
ctx: Option<PySessionContext>,
322323
) -> PyResult<PyExpr> {
323-
let fun = find_df_window_func(name).or_else(|| {
324-
ctx.and_then(|ctx| {
325-
ctx.ctx
326-
.udaf(name)
327-
.map(WindowFunctionDefinition::AggregateUDF)
328-
.ok()
324+
// workaround for https://github.com/apache/datafusion-python/issues/730
325+
let fun = if name == "sum" {
326+
let sum_udf = functions_aggregate::sum::sum_udaf();
327+
Some(WindowFunctionDefinition::AggregateUDF(sum_udf))
328+
} else {
329+
find_df_window_func(name).or_else(|| {
330+
ctx.and_then(|ctx| {
331+
ctx.ctx
332+
.udaf(name)
333+
.map(WindowFunctionDefinition::AggregateUDF)
334+
.ok()
335+
})
329336
})
330-
});
337+
};
331338
if fun.is_none() {
332339
return Err(DataFusionError::Common("window function not found".to_string()).into());
333340
}

0 commit comments

Comments
 (0)
0