Add experimental support for executing SQL with Polars and Pandas by andygrove · Pull Request #190 · apache/datafusion-python · GitHub
[go: up one dir, main page]

Skip to content

Add experimental support for executing SQL with Polars and Pandas #190

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Feb 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
8000
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ from having to lock the GIL when running those operations.
Its query engine, DataFusion, is written in [Rust](https://www.rust-lang.org/), which makes strong assumptions
about thread safety and lack of memory leaks.

There is also experimental support for executing SQL against other DataFrame libraries, such as Polars, Pandas, and any
drop-in replacements for Pandas.

Technically, zero-copy is achieved via the [c data interface](https://arrow.apache.org/docs/format/CDataInterface.html).

## Example Usage
Expand Down
62 changes: 62 additions & 0 deletions datafusion/pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pandas as pd
import datafusion
from datafusion.expr import Projection, TableScan, Column


class SessionContext:
    """Experimental session that plans SQL with DataFusion and executes it with Pandas.

    DataFusion parses and plans the query; the resulting logical plan is then
    replayed against Pandas DataFrames loaded from registered Parquet files.
    """

    def __init__(self):
        # DataFusion does the SQL parsing/planning; the path map lets table
        # scans be re-executed via pandas.read_parquet.
        self.datafusion_ctx = datafusion.SessionContext()
        self.parquet_tables = {}

    def register_parquet(self, name, path):
        """Register the Parquet file at ``path`` as table ``name``."""
        self.parquet_tables[name] = path
        self.datafusion_ctx.register_parquet(name, path)

    def to_pandas_expr(self, expr):
        """Translate a DataFusion logical expression into a Pandas column name.

        Only plain column references are supported so far.
        """
        variant = expr.to_variant()
        if not isinstance(variant, Column):
            raise Exception("unsupported expression: {}".format(variant))
        return variant.name()

    def to_pandas_df(self, plan):
        """Recursively translate a DataFusion logical plan into a Pandas DataFrame."""
        # Translate children first so each operator can consume its inputs.
        child_dfs = [self.to_pandas_df(child) for child in plan.inputs()]

        node = plan.to_variant()
        if isinstance(node, Projection):
            columns = [self.to_pandas_expr(e) for e in node.projections()]
            return child_dfs[0][columns]
        if isinstance(node, TableScan):
            return pd.read_parquet(self.parquet_tables[node.table_name()])
        raise Exception(
            "unsupported logical operator: {}".format(type(node))
        )

    def sql(self, sql):
        """Plan ``sql`` with DataFusion, then execute the plan with Pandas."""
        plan = self.datafusion_ctx.sql(sql).logical_plan()
        return self.to_pandas_df(plan)
85 changes: 85 additions & 0 deletions datafusion/polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import polars
import datafusion
from datafusion.expr import Projection, TableScan, Aggregate
from datafusion.expr import Column, AggregateFunction


class SessionContext:
    """Experimental session that plans SQL with DataFusion and executes it with Polars.

    DataFusion parses and plans the query; the logical plan is then replayed
    against Polars DataFrames loaded from registered Parquet files.
    """

    def __init__(self):
        # DataFusion does the SQL parsing/planning; the path map lets table
        # scans be re-executed via polars.read_parquet.
        self.datafusion_ctx = datafusion.SessionContext()
        self.parquet_tables = {}

    def register_parquet(self, name, path):
        """Register the Parquet file at ``path`` as table ``name``."""
        self.parquet_tables[name] = path
        self.datafusion_ctx.register_parquet(name, path)

    def to_polars_expr(self, expr):
        """Translate a DataFusion logical expression into a Polars expression.

        Only plain column references are supported so far.
        """
        variant = expr.to_variant()
        if not isinstance(variant, Column):
            raise Exception("unsupported expression: {}".format(variant))
        return polars.col(variant.name())

    def _to_polars_agg(self, agg_expr):
        """Translate one aggregate expression; only COUNT is supported so far."""
        variant = agg_expr.to_variant()
        if not isinstance(variant, AggregateFunction):
            raise Exception(
                "Unsupported aggregate function {}".format(variant)
            )
        if variant.aggregate_type() != "COUNT":
            raise Exception(
                "Unsupported aggregate function {}".format(
                    variant.aggregate_type()
                )
            )
        return polars.count().alias("{}".format(variant))

    def to_polars_df(self, plan):
        """Recursively translate a DataFusion logical plan into a Polars DataFrame."""
        # Translate children first so each operator can consume its inputs.
        child_dfs = [self.to_polars_df(child) for child in plan.inputs()]

        node = plan.to_variant()
        if isinstance(node, Projection):
            columns = [self.to_polars_expr(e) for e in node.projections()]
            return child_dfs[0].select(*columns)
        if isinstance(node, Aggregate):
            keys = [self.to_polars_expr(e) for e in node.group_by_exprs()]
            aggs = [self._to_polars_agg(e) for e in node.aggregate_exprs()]
            return child_dfs[0].groupby(keys).agg(aggs)
        if isinstance(node, TableScan):
            return polars.read_parquet(self.parquet_tables[node.table_name()])
        raise Exception(
            "unsupported logical operator: {}".format(type(node))
        )

    def sql(self, sql):
        """Plan ``sql`` with DataFusion, then execute the plan with Polars."""
        plan = self.datafusion_ctx.sql(sql).logical_plan()
        return self.to_polars_df(plan)
39 changes: 33 additions & 6 deletions datafusion/tests/test_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

from datafusion import SessionContext
from datafusion.expr import Column, Literal, BinaryExpr, AggregateFunction
from datafusion.expr import (
Projection,
Filter,
Expand All @@ -41,6 +42,24 @@ def test_projection(test_ctx):
plan = plan.to_variant()
assert isinstance(plan, Projection)

expr = plan.projections()

col1 = expr[0].to_variant()
assert isinstance(col1, Column)
assert col1.name() == "c1"
assert col1.qualified_name() == "test.c1"

col2 = expr[1].to_variant()
assert isinstance(col2, Literal)
assert col2.data_type() == "Int64"
assert col2.value_i64() == 123

col3 = expr[2].to_variant()
assert isinstance(col3, BinaryExpr)
assert isinstance(col3.left().to_variant(), Column)
assert col3.op() == "<"
assert isinstance(col3.right().to_variant(), Literal)

plan = plan.input().to_variant()
assert isinstance(plan, TableScan)

Expand All @@ -64,15 +83,23 @@ def test_limit(test_ctx):
assert isinstance(plan, Limit)


def test_aggregate(test_ctx):
df = test_ctx.sql("select c1, COUNT(*) from test GROUP BY c1")
def test_aggregate_query(test_ctx):
    # A GROUP BY query should plan as a Projection wrapping an Aggregate.
    plan = test_ctx.sql(
        "select c1, count(*) from test group by c1"
    ).logical_plan()

    proj = plan.to_variant()
    assert isinstance(proj, Projection)

    agg = proj.input().to_variant()
    assert isinstance(agg, Aggregate)

    # The first grouping expression is the qualified column test.c1.
    group_col = agg.group_by_exprs()[0].to_variant()
    assert isinstance(group_col, Column)
    assert group_col.name() == "c1"
    assert group_col.qualified_name() == "test.c1"

    # COUNT(*) appears as an AggregateFunction expression.
    agg_fn = agg.aggregate_exprs()[0].to_variant()
    assert isinstance(agg_fn, AggregateFunction)


def test_sort(test_ctx):
Expand Down
12 changes: 11 additions & 1 deletion datafusion/tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@

from datafusion.expr import (
Expr,
Column,
Literal,
BinaryExpr,
AggregateFunction,
Projection,
TableScan,
Filter,
Expand All @@ -59,9 +63,15 @@ def test_class_module_is_datafusion():
]:
assert klass.__module__ == "datafusion"

for klass in [Expr, Projection, TableScan, Aggregate, Sort, Limit, Filter]:
# expressions
for klass in [Expr, Column, Literal, BinaryExpr, AggregateFunction]:
assert klass.__module__ == "datafusion.expr"

# operators
for klass in [Projection, TableScan, Aggregate, Sort, Limit, Filter]:
assert klass.__module__ == "datafusion.expr"

# schema
for klass in [DFField, DFSchema]:
assert klass.__module__ == "datafusion.common"

Expand Down
12 changes: 12 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,21 @@

# DataFusion Python Examples

Some of the examples rely on data which can be downloaded from the following site:

- https://www.nyc.gov/site/tlc/about/tlc-trip-record-data.page

Here is a direct link to the file used in the examples:

- https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2021-01.parquet

## Examples

- [Query a Parquet file using SQL](./sql-parquet.py)
- [Query a Parquet file using the DataFrame API](./dataframe-parquet.py)
- [Run a SQL query and store the results in a Pandas DataFrame](./sql-to-pandas.py)
- [Query PyArrow Data](./query-pyarrow-data.py)
- [Register a Python UDF with DataFusion](./python-udf.py)
- [Register a Python UDAF with DataFusion](./python-udaf.py)
- [Executing SQL on Polars](./sql-on-polars.py)
- [Executing SQL on Pandas](./sql-on-pandas.py)
26 changes: 26 additions & 0 deletions examples/sql-on-pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example: experimental execution of DataFusion SQL plans using Pandas.
from datafusion.pandas import SessionContext


ctx = SessionContext()
# Register a local Parquet file as table "taxi".
# NOTE(review): the hard-coded path assumes the NYC taxi dataset has been
# downloaded locally — see examples/README.md for the download link.
ctx.register_parquet(
    "taxi", "/mnt/bigdata/nyctaxi/yellow_tripdata_2021-01.parquet"
)
# DataFusion plans the query; the plan is executed against Pandas.
df = ctx.sql("select passenger_count from taxi")
print(df)
28 changes: 28 additions & 0 deletions examples/sql-on-polars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example: experimental execution of DataFusion SQL plans using Polars.
from datafusion.polars import SessionContext


ctx = SessionContext()
# Register a local Parquet file as table "taxi".
# NOTE(review): the hard-coded path assumes the NYC taxi dataset has been
# downloaded locally — see examples/README.md for the download link.
ctx.register_parquet(
    "taxi", "/mnt/bigdata/nyctaxi/yellow_tripdata_2021-01.parquet"
)
# DataFusion plans the aggregate query; the plan is executed against Polars.
df = ctx.sql(
    "select passenger_count, count(*) from taxi group by passenger_count"
)
print(df)
32 changes: 32 additions & 0 deletions src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,20 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField};

use crate::errors::py_runtime_err;
use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
use datafusion::scalar::ScalarValue;

pub mod aggregate;
pub mod aggregate_expr;
pub mod binary_expr;
pub mod column;
pub mod filter;
pub mod limit;
pub mod literal;
pub mod logical_node;
pub mod projection;
pub mod sort;
Expand All @@ -53,6 +62,22 @@ impl From<Expr> for PyExpr {

#[pymethods]
impl PyExpr {
/// Return the specific expression
fn to_variant(&self, py: Python) -> PyResult<PyObject> {
Python::with_gil(|_| match &self.expr {
Expr::Column(col) => Ok(PyColumn::from(col.clone()).into_py(py)),
Expr::Literal(value) => Ok(PyLiteral::from(value.clone()).into_py(py)),
Expr::BinaryExpr(expr) => Ok(PyBinaryExpr::from(expr.clone()).into_py(py)),
Expr::AggregateFunction(expr) => {
Ok(PyAggregateFunction::from(expr.clone()).into_py(py))
}
other => Err(py_runtime_err(format!(
"Cannot convert this Expr to a Python object: {:?}",
other
))),
})
}

fn __richcmp__(&self, other: PyExpr, op: CompareOp) -> PyExpr {
let expr = match op {
CompareOp::Lt => self.expr.clone().lt(other.expr),
Expand Down Expand Up @@ -144,7 +169,14 @@ impl PyExpr {

/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
// expressions
m.add_class::<PyExpr>()?;
m.add_class::<PyColumn>()?;
m.add_class::<PyLiteral>()?;
m.add_class::<PyBinaryExpr>()?;
m.add_class::<PyLiteral>()?;
m.add_class::<PyAggregateFunction>()?;
// operators
m.add_class::<table_scan::PyTableScan>()?;
m.add_class::<projection::PyProjection>()?;
m.add_class::<filter::PyFilter>()?;
Expand Down
Loading
0