8000 Add Column, Literal, BinaryExpr Python wrappers · andygrove/datafusion-python@025b12e · GitHub
[go: up one dir, main page]

Skip to content

Commit 025b12e

Browse files
committed
Add Column, Literal, BinaryExpr Python wrappers
1 parent 6c3cb97 commit 025b12e

File tree

7 files changed

+219
-1
lines changed

7 files changed

+219
-1
lines changed

datafusion/tests/test_expr.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from datafusion import SessionContext
19+
from datafusion.expr import Column, Literal, BinaryExpr, Projection
20+
import pytest
21+
22+
23+
@pytest.fixture
24+
def test_ctx():
25+
ctx = SessionContext()
26+
ctx.register_csv("test", "testing/data/csv/aggregate_test_100.csv")
27+
return ctx
28+
29+
30+
def test_logical_plan(test_ctx):
31+
df = test_ctx.sql("select c1, 123, c1 < 123 from test")
32+
plan = df.logical_plan()
33+
34+
projection = plan.to_logical_node()
35+
assert isinstance(projection, Projection)
36+
37+
expr = projection.projections()
38+
assert isinstance(expr[0].to_logical_expr(), Column)
39+
assert isinstance(expr[1].to_logical_expr(), Literal)
40+
assert isinstance(expr[2].to_logical_expr(), BinaryExpr)

datafusion/tests/test_imports.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@
3333

3434
from datafusion.expr import (
3535
Expr,
36+
Column,
37+
Literal,
38+
BinaryExpr,
3639
Projection,
3740
TableScan,
3841
)
@@ -55,7 +58,7 @@ def test_class_module_is_datafusion():
5558
]:
5659
assert klass.__module__ == "datafusion"
5760

58-
for klass in [Expr, Projection, TableScan]:
61+
for klass in [Expr, Column, Literal, BinaryExpr, Projection, TableScan]:
5962
assert klass.__module__ == "datafusion.expr"
6063

6164
for klass in [DFField, DFSchema]:

src/expr.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,15 @@ use datafusion::arrow::datatypes::DataType;
2222
use datafusion::arrow::pyarrow::PyArrowType;
2323
use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField};
2424

25+
use crate::errors::py_runtime_err;
26+
use crate::expr::binary_expr::PyBinaryExpr;
27+
use crate::expr::column::PyColumn;
28+
use crate::expr::literal::PyLiteral;
2529
use datafusion::scalar::ScalarValue;
2630

31+
pub mod binary_expr;
32+
pub mod column;
33+
pub mod literal;
2734
pub mod logical_node;
2835
pub mod projection;
2936
pub mod table_scan;
@@ -49,6 +56,19 @@ impl From<Expr> for PyExpr {
4956

5057
#[pymethods]
5158
impl PyExpr {
59+
/// Return a Python object representation of this logical expression
60+
fn to_logical_expr(&self, py: Python) -> PyResult<PyObject> {
61+
Python::with_gil(|_| match &self.expr {
62+
Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
63+
Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_py(py)),
64+
Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_py(py)),
65+
other => Err(py_runtime_err(format!(
66+
"Cannot convert this Expr to a Python object: {:?}",
67+
other
68+
))),
69+
})
70+
}
71+
5272
fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
5373
let expr = match op {
5474
CompareOp::Lt => self.expr.clone().lt(other.expr),
@@ -143,5 +163,9 @@ pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
143163
m.add_class::<PyExpr>()?;
144164
m.add_class::<table_scan::PyTableScan>()?;
145165
m.add_class::<projection::PyProjection>()?;
166+
m.add_class::<column::PyColumn>()?;
167+
m.add_class::<literal::PyLiteral>()?;
168+
m.add_class::<binary_expr::PyBinaryExpr>()?;
169+
m.add_class::<literal::PyLiteral>()?;
146170
Ok(())
147171
}

src/expr/binary_expr.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_expr::BinaryExpr;
19+
use pyo3::prelude::*;
20+
21+
#[pyclass(name = "BinaryExpr", module = "datafusion.expr", subclass)]
22+
#[derive(Clone)]
23+
pub struct PyBinaryExpr {
24+
expr: BinaryExpr,
25+
}
26+
27+
impl From<PyBinaryExpr> for BinaryExpr {
28+
fn from(expr: PyBinaryExpr) -> Self {
29+
expr.expr
30+
}
31+
}
32+
33+
impl From<BinaryExpr> for PyBinaryExpr {
34+
fn from(expr: BinaryExpr) -> PyBinaryExpr {
35+
PyBinaryExpr { expr }
36+
}
37+
}
38+
39+
#[pymethods]
40+
impl PyBinaryExpr {
41+
fn __repr__(&self) -> PyResult<String> {
42+
Ok(format!("{}", self.expr))
43+
}
44+
}

src/expr/column.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_common::Column;
19+
use pyo3::prelude::*;
20+
21+
#[pyclass(name = "Column", module = "datafusion.expr", subclass)]
22+
#[derive(Clone)]
23+
pub struct PyColumn {
24+
pub col: Column,
25+
}
26+
27+
impl PyColumn {
28+
pub fn new(col: Column) -> Self {
29+
Self { col }
30+
}
31+
}
32+
33+
impl From<Column> for PyColumn {
34+
fn from(col: Column) -> PyColumn {
35+
PyColumn { col }
36+
}
37+
}
38+
39+
#[pymethods]
40+
impl PyColumn {
41+
fn __repr__(&self) -> PyResult<String> {
42+
Ok(format!("{}", self.col))
43+
}
44+
}

src/expr/literal.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_common::ScalarValue;
19+
use pyo3::prelude::*;
20+
21+
#[pyclass(name = "Literal", module = "datafusion.expr", subclass)]
22+
#[derive(Clone)]
23+
pub struct PyLiteral {
24+
pub value: ScalarValue,
25+
}
26+
27+
impl From<PyLiteral> for ScalarValue {
28+
fn from(lit: PyLiteral) -> ScalarValue {
29+
lit.value
30+
}
31+
}
32+
33+
impl From<ScalarValue> for PyLiteral {
34+
fn from(value: ScalarValue) -> PyLiteral {
35+
PyLiteral { value }
36+
}
37+
}
38+
39+
#[pymethods]
40+
impl PyLiteral {
41+
fn __repr__(&self) -> PyResult<String> {
42+
Ok(format!("{}", self.value))
43+
}
44+
}

src/sql/logical.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
use std::sync::Arc;
1919

20+
use crate::errors::py_runtime_err;
21+
use crate::expr::projection::PyProjection;
22+
use crate::expr::table_scan::PyTableScan;
2023
use datafusion_expr::LogicalPlan;
2124
use pyo3::prelude::*;
2225

@@ -37,6 +40,18 @@ impl PyLogicalPlan {
3740

3841
#[pymethods]
3942
impl PyLogicalPlan {
43+
/// Return a Python object representation of this logical operator
44+
fn to_logical_node(&self, py: Python) -> PyResult<PyObject> {
45+
Python::with_gil(|_| match self.plan.as_ref() {
46+
LogicalPlan::Projection(plan) => Ok(PyProjection::from(plan.clone()).into_py(py)),
47+
LogicalPlan::TableScan(plan) => Ok(PyTableScan::from(plan.clone()).into_py(py)),
48+
other => Err(py_runtime_err(format!(
49+
"Cannot convert this plan to a LogicalNode: {:?}",
50+
other
51+
))),
52+
})
53+
}
54+
4055
/// Get the inputs to this plan
4156
pub fn inputs(&self) -> Vec<PyLogicalPlan> {
4257
let mut inputs = vec![];
@@ -46,6 +61,10 @@ impl PyLogicalPlan {
4661
inputs
4762
}
4863

64+
fn __repr__(&self) -> PyResult<String> {
65+
Ok(format!("{:?}", self.plan))
66+
}
67+
4968
pub fn display(&self) -> String {
5069
format!("{}", self.plan.display())
5170
}

0 commit comments

Comments
 (0)
0