8000 Feature/expose when function (#836) · PhVHoang/datafusion-python@003eea8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 003eea8

Browse files
authored
Feature/expose when function (apache#836)
1 parent 69ed7fe commit 003eea8

File tree

3 files changed

+39
-0
lines changed

3 files changed

+39
-0
lines changed

python/datafusion/functions.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@
245245
"var",
246246
"var_pop",
247247
"var_samp",
248+
"when",
248249
"window",
249250
]
250251

@@ -364,6 +365,16 @@ def case(expr: Expr) -> CaseBuilder:
364365
return CaseBuilder(f.case(expr.expr))
365366

366367

368+
def when(when: Expr, then: Expr) -> CaseBuilder:
369+
"""Create a case expression that has no base expression.
370+
371+
Create a :py:class:`~datafusion.expr.CaseBuilder` to match cases for the
372+
expression ``expr``. See :py:class:`~datafusion.expr.CaseBuilder` for
373+
detailed usage.
374+
"""
375+
return CaseBuilder(f.when(when.expr, then.expr))
376+
377+
367378
def window(
368379
name: str,
369380
args: list[Expr],

python/datafusion/tests/test_functions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,6 +836,25 @@ def test_case(df):
836836
assert result.column(2) == pa.array(["Hola", "Mundo", None])
837837

838838

839+
def test_when_with_no_base(df):
840+
df.show()
841+
df = df.select(
842+
column("b"),
843+
f.when(column("b") > literal(5), literal("too big"))
844+
.when(column("b") < literal(5), literal("too small"))
845+
.otherwise(literal("just right"))
846+
.alias("goldilocks"),
847+
f.when(column("a") == literal("Hello"), column("a")).end().alias("greeting"),
848+
)
849+
df.show()
850+
851+
result = df.collect()
852+
result = result[0]
853+
assert result.column(0) == pa.a 8000 rray([4, 5, 6])
854+
assert result.column(1) == pa.array(["too small", "just right", "too big"])
855+
assert result.column(2) == pa.array(["Hello", None, None])
856+
857+
839858
def test_regr_funcs_sql(df):
840859
# test case base on
841860
# https://github.com/apache/arrow-datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2330

src/functions.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,14 @@ fn case(expr: PyExpr) -> PyResult<PyCaseBuilder> {
533533
})
534534
}
535535

536+
/// Create a CASE WHEN statement with literal WHEN expressions for comparison to the base expression.
537+
#[pyfunction]
538+
fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
539+
Ok(PyCaseBuilder {
540+
case_builder: datafusion_expr::when(when.expr, then.expr),
541+
})
542+
}
543+
536544
/// Helper function to find the appropriate window function.
537545
///
538546
/// Search procedure:
@@ -910,6 +918,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
910918
m.add_wrapped(wrap_pyfunction!(char_length))?;
911919
m.add_wrapped(wrap_pyfunction!(coalesce))?;
912920
m.add_wrapped(wrap_pyfunction!(case))?;
921+
m.add_wrapped(wrap_pyfunction!(when))?;
913922
m.add_wrapped(wrap_pyfunction!(col))?;
914923
m.add_wrapped(wrap_pyfunction!(concat_ws))?;
915924
m.add_wrapped(wrap_pyfunction!(concat))?;

0 commit comments

Comments
 (0)
0