Expose array sort (#764) · PhVHoang/datafusion-python@f00b8ee · GitHub
[go: up one dir, main page]

Skip to content

Commit f00b8ee

Browse files
authored
Expose array sort (apache#764)
1 parent aa8aa9c commit f00b8ee

File tree

3 files changed

+34
-0
lines changed

3 files changed

+34
-0
lines changed

python/datafusion/functions.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,28 @@ def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr:
11261126
return array_replace_all(array, from_val, to_val)
11271127

11281128

1129+
def array_sort(array: Expr, descending: bool = False, null_first: bool = False) -> Expr:
1130+
"""Sort an array.
1131+
1132+
Args:
1133+
array: The input array to sort.
1134+
descending: If True, sorts in descending order.
1135+
null_first: If True, nulls will be returned at the beginning of the array.
1136+
"""
1137+
desc = "DESC" if descending else "ASC"
1138+
nulls_first = "NULLS FIRST" if null_first else "NULLS LAST"
1139+
return Expr(
1140+
f.array_sort(
1141+
array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr
1142+
)
1143+
)
1144+
1145+
1146+
def list_sort(array: Expr, descending: bool = False, null_first: bool = False) -> Expr:
1147+
"""This is an alias for ``array_sort``."""
1148+
return array_sort(array, descending=descending, null_first=null_first)
1149+
1150+
11291151
def array_slice(
11301152
array: Expr, begin: Expr, end: Expr, stride: Expr | None = None
11311153
) -> Expr:

python/datafusion/tests/test_functions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,14 @@ def py_flatten(arr):
453453
lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)),
454454
lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data],
455455
],
456+
[
457+
lambda col: f.array_sort(col, descending=True, null_first=True),
458+
lambda data: [np.sort(arr)[::-1] for arr in data],
459+
],
460+
[
461+
lambda col: f.list_sort(col, descending=False, null_first=False),
462+
lambda data: [np.sort(arr) for arr in data],
463+
],
456464
[
457465
lambda col: f.array_slice(col, literal(2), literal(4)),
458466
lambda data: [arr[1:4] for arr in data],

src/functions.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,8 @@ array_fn!(array_replace_n, array from to max);
673673
array_fn!(list_replace_n, array_replace_n, array from to max);
674674
array_fn!(array_replace_all, array from to);
675675
array_fn!(list_replace_all, array_replace_all, array from to);
676+
array_fn!(array_sort, array desc null_first);
677+
array_fn!(list_sort, array_sort, array desc null_first);
676678
array_fn!(array_intersect, first_array second_array);
677679
array_fn!(list_intersect, array_intersect, first_array second_array);
678680
array_fn!(array_union, array1 array2);
@@ -936,6 +938,8 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
936938
m.add_wrapped(wrap_pyfunction!(list_replace_n))?;
937939
m.add_wrapped(wrap_pyfunction!(array_replace_all))?;
938940
m.add_wrapped(wrap_pyfunction!(list_replace_all))?;
941+
m.add_wrapped(wrap_pyfunction!(array_sort))?;
942+
m.add_wrapped(wrap_pyfunction!(list_sort))?;
939943
m.add_wrapped(wrap_pyfunction!(array_slice))?;
940944
m.add_wrapped(wrap_pyfunction!(list_slice))?;
941945
m.add_wrapped(wrap_pyfunction!(flatten))?;

0 commit comments

Comments
 (0)
0