8000 Add arrow cast (#962) · SanjayUG/datafusion-python@e36e8ab · GitHub
[go: up one dir, main page]

Skip to content

Commit e36e8ab

Browse files
authored
Add arrow cast (apache#962)
* feat: add data_type parameter to expr_fn macro for arrow_cast function * feat: add arrow_cast function to cast expressions to specified data types * docs: add casting section to user guide with examples for arrow_cast function * test: add unit test for arrow_cast function to validate casting to Float64 and Int32 * fix: update arrow_cast function to accept Expr type for data_type parameter * fix: update test_arrow_cast to use literal casting for data types * fix: update arrow_cast function to accept string type for data_type parameter * fix: update arrow_cast function to accept Expr type for data_type parameter * fix: update test_arrow_cast to use literal for data type parameters * fix: update arrow_cast function to use arg_1 for datatype parameter * fix: update arrow_cast function to accept string type for data_type parameter * Revert "fix: update arrow_cast function to accept string type for data_type parameter" This reverts commit eba0d32. * fix: update test_arrow_cast to cast literals to string type for arrow_cast function * Revert "fix: update test_arrow_cast to cast literals to string type for arrow_cast function" This reverts commit 856ff8c. * fix: update arrow_cast function to accept string type for data_type parameter * Revert "fix: update arrow_cast function to accept string type for data_type parameter" This reverts commit 9e1ced7. * fix: add utf8_literal function to create UTF8 literal expressions in tests * Revert "fix: add utf8_literal function to create UTF8 literal expressions in tests" This reverts commit 11ed674. * feat: add utf8_literal function to create UTF8 literal expressions * fix: update test_arrow_cast to use column 'b' * fix: enhance utf8_literal function to handle non-string values * Add description for utf8_literal vs literal * docs: clarify utf8_literal function documentation to explain use case * docs: add clarification comments for utf8_literal usage in arrow_cast tests * docs: implement ruff recommendation * fix ruff errors * docs: update examples to use utf8_literal in arrow_cast function * docs: correct typo in comment for utf8_literal usage in test_arrow_cast * docs: remove redundant comment in test_arrow_cast for clarity * refactor: rename utf8_literal to string_literal and add alias str_lit * docs: improve docstring for string_literal function for clarity * docs: update import statement to include str_lit alias for string_literal
1 parent 79c22d6 commit e36e8ab

File tree

6 files changed

+66
-3
lines changed

6 files changed

+66
-3
lines changed

docs/source/user-guide/common-operations/functions.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ DataFusion offers mathematical functions such as :py:func:`~datafusion.functions
3838

3939
.. ipython:: python
4040
41-
from datafusion import col, literal
41+
from datafusion import col, literal, string_literal, str_lit
4242
from datafusion import functions as f
4343
4444
df.select(
@@ -104,6 +104,17 @@ This also includes the functions for regular expressions like :py:func:`~datafus
104104
f.regexp_replace(col('"Name"'), literal("saur"), literal("fleur")).alias("flowers")
105105
)
106106
107+
Casting
108+
-------
109+
110+
Casting expressions to different data types using :py:func:`~datafusion.functions.arrow_cast`
111+
112+
.. ipython:: python
113+
114+
df.select(
115+
f.arrow_cast(col('"Total"'), string_literal("Float64")).alias("total_as_float"),
116+
f.arrow_cast(col('"Total"'), str_lit("Int32")).alias("total_as_int")
117+
)
107118
108119
Other
109120
-----

python/datafusion/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,19 @@ def literal(value):
107107
return Expr.literal(value)
108108

109109

110+
def string_literal(value):
111+
"""Create a UTF8 literal expression.
112+
113+
It differs from `literal` which creates a UTF8view literal.
114+
"""
115+
return Expr.string_literal(value)
116+
117+
118+
def str_lit(value):
119+
"""Alias for `string_literal`."""
120+
return string_literal(value)
121+
122+
110123
def lit(value):
111124
"""Create a literal expression."""
112125
return Expr.literal(value)

python/datafusion/expr.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,22 @@ def literal(value: Any) -> Expr:
380380
value = pa.scalar(value)
381381
return Expr(expr_internal.Expr.literal(value))
382382

383+
@staticmethod
384+
def string_literal(value: str) -> Expr:
385+
"""Creates a new expression representing a UTF8 literal value.
386+
387+
It is different from `literal` because it is pa.string() instead of
388+
pa.string_view()
389+
390+
This is needed for cases where DataFusion is expecting a UTF8 instead of
391+
UTF8View literal, like in:
392+
https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
393+
"""
394+
if isinstance(value, str):
395+
value = pa.scalar(value, type=pa.string())
396+
return Expr(expr_internal.Expr.literal(value))
397+
return Expr.literal(value)
398+
383399
@staticmethod
384400
def column(value: str) -> Expr:
385401
"""Creates a new expression representing a column."""

python/datafusion/functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
"array_to_string",
8383
"array_union",
8484
"arrow_typeof",
85+
"arrow_cast",
8586
"ascii",
8687
"asin",
8788
"asinh",
@@ -1109,6 +1110,11 @@ def arrow_typeof(arg: Expr) -> Expr:
11091110
return Expr(f.arrow_typeof(arg.expr))
11101111

11111112

1113+
def arrow_cast(expr: Expr, data_type: Expr) -> Expr:
1114+
"""Casts an expression to a specified data type."""
1115+
return Expr(f.arrow_cast(expr.expr, data_type.expr))
1116+
1117+
11121118
def random() -> Expr:
11131119
"""Returns a random value in the range ``0.0 <= x < 1.0``."""
11141120
return Expr(f.random())

python/tests/test_functions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from datafusion import SessionContext, column
2525
from datafusion import functions as f
26-
from datafusion import literal
26+
from datafusion import literal, string_literal
2727

2828
np.seterr(invalid="ignore")
2929

@@ -907,6 +907,22 @@ def test_temporal_functions(df):
907907
assert result.column(10) == pa.array([31, 26, 2], type=pa.float64())
908908

909909

910+
def test_arrow_cast(df):
911+
df = df.select(
912+
# we use `string_literal` to return utf8 instead of `literal` which returns
913+
# utf8view because datafusion.arrow_cast expects a utf8 instead of utf8view
914+
# https://github.com/apache/datafusion/blob/86740bfd3d9831d6b7c1d0e1bf4a21d91598a0ac/datafusion/functions/src/core/arrow_cast.rs#L179
915+
f.arrow_cast(column("b"), string_literal("Float64")).alias("b_as_float"),
916+
f.arrow_cast(column("b"), string_literal("Int32")).alias("b_as_int"),
917+
)
918+
result = df.collect()
919+
assert len(result) == 1
920+
result = result[0]
921+
922+
assert result.column(0) == pa.array([4.0, 5.0, 6.0], type=pa.float64())
923+
assert result.column(1) == pa.array([4, 5, 6], type=pa.int32())
924+
925+
910926
def test_case(df):
911927
df = df.select(
912928
f.case(column("b")).when(literal(4), literal(10)).otherwise(literal(8)),

src/functions.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,6 @@ macro_rules! expr_fn {
400400
}
401401
};
402402
}
403-
404403
/// Generates a [pyo3] wrapper for [datafusion::functions::expr_fn]
405404
///
406405
/// These functions take a single `Vec<PyExpr>` argument using `pyo3(signature = (*args))`.
@@ -575,6 +574,7 @@ expr_fn_vec!(r#struct); // Use raw identifier since struct is a keyword
575574
expr_fn_vec!(named_struct);
576575
expr_fn!(from_unixtime, unixtime);
577576
expr_fn!(arrow_typeof, arg_1);
577+
expr_fn!(arrow_cast, arg_1 datatype);
578578
expr_fn!(random);
579579

580580
// Array Functions
@@ -867,6 +867,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
867867
m.add_wrapped(wrap_pyfunction!(range))?;
868868
m.add_wrapped(wrap_pyfunction!(array_agg))?;
869869
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
870+
m.add_wrapped(wrap_pyfunction!(arrow_cast))?;
870871
m.add_wrapped(wrap_pyfunction!(ascii))?;
871872
m.add_wrapped(wrap_pyfunction!(asin))?;
872873
m.add_wrapped(wrap_pyfunction!(asinh))?;

0 commit comments

Comments
 (0)
0