8000 add python binding for approx_distinct aggregate function (#1134) · tfeda/datafusion-python@3cc92fa · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 3cc92fa

Browse files
committed
add python binding for approx_distinct aggregate function (#1134)
GitOrigin-RevId: a3ffc529dd391ee47380f489be6b7c7c341b3b74
1 parent aa37b8a commit 3cc92fa

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

src/functions.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ define_unary_function!(avg);
224224
define_unary_function!(min);
225225
define_unary_function!(max);
226226
define_unary_function!(count);
227+
define_unary_function!(approx_distinct);
227228

228229
#[pyclass(name = "Volatility", module = "datafusion.functions")]
229230
#[derive(Clone)]
@@ -323,6 +324,7 @@ pub fn init(module: &PyModule) -> PyResult<()> {
323324
module.add_class::<PyVolatility>()?;
324325
module.add_function(wrap_pyfunction!(abs, module)?)?;
325326
module.add_function(wrap_pyfunction!(acos, module)?)?;
327+
module.add_function(wrap_pyfunction!(approx_distinct, module)?)?;
326328
module.add_function(wrap_pyfunction!(array, module)?)?;
327329
module.add_function(wrap_pyfunction!(ascii, module)?)?;
328330
module.add_function(wrap_pyfunction!(asin, module)?)?;

tests/test_aggregation.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pyarrow as pa
19+
import pytest
20+
from datafusion import ExecutionContext
21+
from datafusion import functions as f
22+
23+
24+
@pytest.fixture
25+
def df():
26+
ctx = ExecutionContext()
27+
28+
# create a RecordBatch and a new DataFrame from it
29+
batch = pa.RecordBatch.from_arrays(
30+
[pa.array([1, 2, 3]), pa.array([4, 4, 6])],
31+
names=["a", "b"],
32+
)
33+
return ctx.create_dataframe([[batch]])
34+
35+
36+
def test_built_in_aggregation(df):
37+
col_a = f.col("a")
38+
col_b = f.col("b")
39+
df = df.aggregate(
40+
[],
41+
[f.max(col_a), f.min(col_a), f.count(col_a), f.approx_distinct(col_b)],
42+
)
43+
result = df.collect()[0]
44+
assert result.column(0) == pa.array([3])
45+
assert result.column(1) == pa.array([1])
46+
assert result.column(2) == pa.array([3], type=pa.uint64())
47+
assert result.column(3) == pa.array([2], type=pa.uint64())

0 commit comments

Comments
 (0)
0