8000 feat: add `cardinality` function to calculate total elements in an ar… · kylebarron/datafusion-python@e015482 · GitHub
[go: up one dir, main page]

Skip to content

Commit e015482

Browse files
authored
feat: add cardinality function to calculate total elements in an array (apache#937)
1 parent 0bc2f31 commit e015482

File tree

4 files changed

+40
-0
lines changed

4 files changed

+40
-0
lines changed

docs/source/user-guide/common-operations/expressions.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,20 @@ This function returns a boolean indicating whether the array is empty.
9696
9797
In this example, the `is_empty` column will contain `True` for the first row and `False` for the second row.
9898

99+
To get the total number of elements in an array, you can use the function :py:func:`datafusion.functions.cardinality`.
100+
This function returns an integer indicating the total number of elements in the array.
101+
102+
.. ipython:: python
103+
104+
from datafusion import SessionContext, col
105+
from datafusion.functions import cardinality
106+
107+
ctx = SessionContext()
108+
df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5, 6]]})
109+
df.select(cardinality(col("a")).alias("num_elements"))
110+
111+
In this example, the `num_elements` column will contain `3` for both rows.
112+
99113
Structs
100114
-------
101115

python/datafusion/functions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@
132132
"find_in_set",
133133
"first_value",
134134
"flatten",
135+
"cardinality",
135136
"floor",
136137
"from_unixtime",
137138
"gcd",
@@ -1516,6 +1517,11 @@ def flatten(array: Expr) -> Expr:
15161517
return Expr(f.flatten(array.expr))
15171518

15181519

1520+
def cardinality(array: Expr) -> Expr:
1521+
"""Returns the total number of elements in the array."""
1522+
return Expr(f.cardinality(array.expr))
1523+
1524+
15191525
# aggregate functions
15201526
def approx_distinct(
15211527
expression: Expr,

python/tests/test_functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,24 @@ def test_array_function_flatten():
540540
)
541541

542542

543+
def test_array_function_cardinality():
544+
data = [[1, 2, 3], [4, 4, 5, 6]]
545+
ctx = SessionContext()
546+
batch = pa.RecordBatch.from_arrays([np.array(data, dtype=object)], names=["arr"])
547+
df = ctx.create_dataframe([[batch]])
548+
549+
stmt = f.cardinality(column("arr"))
550+
py_expr = [len(arr) for arr in data] # Expected lengths: [3, 3]
551+
# assert py_expr lengths
552+
553+
query_result = df.select(stmt).collect()[0].column(0)
554+
555+
for a, b in zip(query_result, py_expr):
556+
np.testing.assert_array_equal(
557+
np.array([a.as_py()], dtype=int), np.array([b], dtype=int)
558+
)
559+
560+
543561
@pytest.mark.parametrize(
544562
("stmt", "py_expr"),
545563
[

src/functions.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ array_fn!(array_intersect, first_array second_array);
594594
array_fn!(array_union, array1 array2);
595595
array_fn!(array_except, first_array second_array);
596596
array_fn!(array_resize, array size value);
597+
array_fn!(cardinality, array);
597598
array_fn!(flatten, array);
598599
array_fn!(range, start stop step);
599600

@@ -1030,6 +1031,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
10301031
m.add_wrapped(wrap_pyfunction!(array_sort))?;
10311032
m.add_wrapped(wrap_pyfunction!(array_slice))?;
10321033
m.add_wrapped(wrap_pyfunction!(flatten))?;
1034+
m.add_wrapped(wrap_pyfunction!(cardinality))?;
10331035

10341036
// Window Functions
10351037
m.add_wrapped(wrap_pyfunction!(lead))?;

0 commit comments

Comments
 (0)
0