8000 Add tests for recently added functionality (#199) · apache/datafusion-python@d62cbdf · GitHub
[go: up one dir, main page]

Skip to content

Commit d62cbdf

Browse files
authored
Add tests for recently added functionality (#199)
1 parent 3124278 commit d62cbdf

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

datafusion/tests/test_expr.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from datafusion import SessionContext
19+
from datafusion.expr import (
20+
Projection,
21+
Filter,
22+
Aggregate,
23+
Limit,
24+
Sort,
25+
TableScan,
26+
)
27+
import pytest
28+
29+
30+
@pytest.fixture
31+
def test_ctx():
32+
ctx = SessionContext()
33+
ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
34+
return ctx
35+
36+
37+
def test_projection(test_ctx):
38+
df = test_ctx.sql("select c1, 123, c1 < 123 from test")
39+
plan = df.logical_plan()
40+
41+
plan = plan.to_variant()
42+
assert isinstance(plan, Projection)
43+
44+
plan = plan.input().to_variant()
45+
assert isinstance(plan, TableScan)
46+
47+
48+
def test_filter(test_ctx):
49+
df = test_ctx.sql("select c1 from test WHERE c1 > 5")
50+
plan = df.logical_plan()
51+
52+
plan = plan.to_variant()
53+
assert isinstance(plan, Projection)
54+
55+
plan = plan.input().to_variant()
56+
assert isinstance(plan, Filter)
57+
58+
59+
def test_limit(test_ctx):
60+
df = test_ctx.sql("select c1 from test LIMIT 10")
61+
plan = df.logical_plan()
62+
63+
plan = plan.to_variant()
64+
assert isinstance(plan, Limit)
65+
66+
67+
def test_aggregate(test_ctx):
68+
df = test_ctx.sql("select c1, COUNT(*) from test GROUP BY c1")
69+
plan = df.logical_plan()
70+
71+
plan = plan.to_variant()
72+
assert isinstance(plan, Projection)
73+
74+
plan = plan.input().to_variant()
75+
assert isinstance(plan, Aggregate)
76+
77+
78+
def test_sort(test_ctx):
79+
df = test_ctx.sql("select c1 from test order by c1")
80+
plan = df.logical_plan()
81+
82+
plan = plan.to_variant()
83+
assert isinstance(plan, Sort)

src/sql/logical.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717

1818
use std::sync::Arc;
1919

20+
use crate::errors::py_runtime_err;
21+
use crate::expr::aggregate::PyAggregate;
22+
use crate::expr::filter::PyFilter;
23+
use crate::expr::limit::PyLimit;
24+
use crate::expr::projection::PyProjection;
25+
use crate::expr::sort::PySort;
26+
use crate::expr::table_scan::PyTableScan;
2027
use datafusion_expr::LogicalPlan;
2128
use pyo3::prelude::*;
2229

@@ -37,6 +44,22 @@ impl PyLogicalPlan {
3744

3845
#[pymethods]
3946
impl PyLogicalPlan {
47+
/// Return the specific logical operator
48+
fn to_variant(&self, py: Python) -> PyResult<PyObject> {
49+
Python::with_gil(|_| match self.plan.as_ref() {
50+
LogicalPlan::Projection(plan) => Ok(PyProjection::from(plan.clone()).into_py(py)),
51+
LogicalPlan::TableScan(plan) => Ok(PyTableScan::from(plan.clone()).into_py(py)),
52+
LogicalPlan::Filter(plan) => Ok(PyFilter::from(plan.clone()).into_py(py)),
53+
LogicalPlan::Limit(plan) => Ok(PyLimit::from(plan.clone()).into_py(py)),
54+
LogicalPlan::Sort(plan) => Ok(PySort::from(plan.clone()).into_py(py)),
55+
LogicalPlan::Aggregate(plan) => Ok(PyAggregate::from(plan.clone()).into_py(py)),
56+
other => Err(py_runtime_err(format!(
57+
"Cannot convert this plan to a LogicalNode: {:?}",
58+
other
59+
))),
60+
})
61+
}
62+
4063
/// Get the inputs to this plan
4164
pub fn inputs(&self) -> Vec<PyLogicalPlan> {
4265
let mut inputs = vec![];

0 commit comments

Comments
 (0)
0